diff --git a/lib/pcg/include/pcg/file_format/v1/graphs/v1_kwarg_dataflow_graph.h b/lib/pcg/include/pcg/file_format/v1/graphs/v1_kwarg_dataflow_graph.h index a923d18ce6..f8cc29c900 100644 --- a/lib/pcg/include/pcg/file_format/v1/graphs/v1_kwarg_dataflow_graph.h +++ b/lib/pcg/include/pcg/file_format/v1/graphs/v1_kwarg_dataflow_graph.h @@ -1,14 +1,20 @@ #ifndef _FLEXFLOW_LIB_PCG_INCLUDE_PCG_FILE_FORMAT_V1_GRAPHS_V1_KWARG_DATAFLOW_GRAPH_H #define _FLEXFLOW_LIB_PCG_INCLUDE_PCG_FILE_FORMAT_V1_GRAPHS_V1_KWARG_DATAFLOW_GRAPH_H +#include "pcg/file_format/v1/graphs/v1_graph_edge.dtg.h" #include "pcg/file_format/v1/graphs/v1_kwarg_dataflow_graph.dtg.h" #include "utils/bidict/algorithms/bidict_from_enumerating.h" #include "utils/containers/enumerate.h" +#include "utils/containers/generate_map.h" #include "utils/containers/sorted.h" +#include "utils/containers/transform.h" +#include "utils/containers/unordered_set_of.h" #include "utils/containers/values.h" #include "utils/graph/kwarg_dataflow_graph/algorithms/get_all_kwarg_dataflow_edges.h" #include "utils/graph/kwarg_dataflow_graph/kwarg_dataflow_graph_view.h" #include "utils/graph/node/algorithms.h" +#include "utils/graph/open_kwarg_dataflow_graph/algorithms/open_kwarg_dataflow_graph_data.dtg.h" +#include "utils/graph/open_kwarg_dataflow_graph/algorithms/view_from_open_kwarg_dataflow_graph_data.h" #include "utils/integer_conversions.h" namespace FlexFlow { @@ -41,6 +47,43 @@ V1KwargDataflowGraph }; } +template +std::pair, + std::unordered_map> + from_v1_including_node_numbering(V1KwargDataflowGraph const &v1) { + std::unordered_map node_map = + generate_map(v1.nodes, [](nonnegative_int n) { + return Node{n.size_t_from_nonnegative_int()}; + }); + std::unordered_set node_set = unordered_set_of(values(node_map)); + + std::unordered_set> edges = + transform(v1.edges, [](V1GraphEdge const &e) { + Node srcNode = Node{e.srcNode.size_t_from_nonnegative_int()}; + Node dstNode = Node{e.dstNode.size_t_from_nonnegative_int()}; + return OpenKwargDataflowEdge{KwargDataflowEdge{ + /*src=*/KwargDataflowOutput{srcNode, e.srcSlot}, + /*dst=*/KwargDataflowInput{dstNode, e.dstSlot}, + }}; + }); + + OpenKwargDataflowGraphData graph_data = + OpenKwargDataflowGraphData{ + /*nodes=*/node_set, + /*edges=*/edges, + /*inputs=*/{}, + /*outputs=*/{}, + }; + return std::pair{view_from_open_kwarg_dataflow_graph_data(graph_data), + node_map}; +} + +template +KwargDataflowGraphView + from_v1(V1KwargDataflowGraph const &v1) { + return from_v1_including_node_numbering(v1).first; +} + } // namespace FlexFlow #endif diff --git a/lib/pcg/include/pcg/file_format/v1/graphs/v1_labelled_kwarg_dataflow_graph.h b/lib/pcg/include/pcg/file_format/v1/graphs/v1_labelled_kwarg_dataflow_graph.h index dbe660c3a6..9fe2d53db1 100644 --- a/lib/pcg/include/pcg/file_format/v1/graphs/v1_labelled_kwarg_dataflow_graph.h +++ b/lib/pcg/include/pcg/file_format/v1/graphs/v1_labelled_kwarg_dataflow_graph.h @@ -4,9 +4,11 @@ #include "pcg/file_format/v1/graphs/v1_kwarg_dataflow_graph.h" #include "pcg/file_format/v1/graphs/v1_labelled_kwarg_dataflow_graph.dtg.h" #include "utils/bidict/algorithms/bidict_from_enumerating.h" +#include "utils/containers/map_keys.h" #include "utils/containers/map_values.h" #include "utils/containers/transform.h" #include "utils/graph/kwarg_dataflow_graph/algorithms/get_outgoing_kwarg_dataflow_outputs_for_node.h" +#include "utils/graph/labelled_kwarg_dataflow_graph/algorithms/kwarg_dataflow_graph_view_with_labelling.h" #include "utils/graph/labelled_kwarg_dataflow_graph/labelled_kwarg_dataflow_graph_view.h" #include "utils/graph/node/algorithms.h" @@ -50,6 +52,29 @@ V1LabelledKwargDataflowGraph to_v1( return to_v1_including_node_numbering(g).first; } +template +std::pair, + std::unordered_map> + from_v1_including_node_numbering( + V1LabelledKwargDataflowGraph const + &v1) { + auto [graph_view, node_map] = from_v1_including_node_numbering(v1.graph); + + std::unordered_map node_labels = map_keys( + v1.node_labels, [&](nonnegative_int n) { return node_map.at(n); }); + std::unordered_map, OutputLabel> value_labels; + + return std::pair{kwarg_dataflow_graph_view_with_labelling( + graph_view, node_labels, value_labels), + node_map}; +} + +template +LabelledKwargDataflowGraphView from_v1( + V1LabelledKwargDataflowGraph const &v1) { + return from_v1_including_node_numbering(v1).first; +} + } // namespace FlexFlow #endif diff --git a/lib/pcg/include/pcg/file_format/v1/v1_computation_graph.h b/lib/pcg/include/pcg/file_format/v1/v1_computation_graph.h index c0e9966425..a3ddbdf7be 100644 --- a/lib/pcg/include/pcg/file_format/v1/v1_computation_graph.h +++ b/lib/pcg/include/pcg/file_format/v1/v1_computation_graph.h @@ -12,6 +12,8 @@ V1ComputationGraph to_v1(ComputationGraph const &); std::pair> to_v1_including_node_numbering(ComputationGraph const &); +ComputationGraph from_v1(V1ComputationGraph const &); + } // namespace FlexFlow #endif diff --git a/lib/pcg/include/pcg/file_format/v1/v1_mapped_operator_task_group.dtg.toml b/lib/pcg/include/pcg/file_format/v1/v1_mapped_operator_task_group.dtg.toml new file mode 100644 index 0000000000..2e4300745d --- /dev/null +++ b/lib/pcg/include/pcg/file_format/v1/v1_mapped_operator_task_group.dtg.toml @@ -0,0 +1,19 @@ +namespace = "FlexFlow" +name = "V1MappedOperatorTaskGroup" +type = "struct" +features = [ + "eq", + "hash", + "fmt", + "json", +] + +includes = [ + "pcg/machine_space_coordinate.dtg.h", + "pcg/mapped_parallel_computation_graph/operator_atomic_task_shard_binding.dtg.h", + "utils/bidict/bidict.h", +] + +[[fields]] +name = "shard_bindings" +type = "::FlexFlow::bidict<::FlexFlow::MachineSpaceCoordinate, ::FlexFlow::OperatorAtomicTaskShardBinding>" diff --git a/lib/pcg/include/pcg/file_format/v1/v1_mapped_operator_task_group.h b/lib/pcg/include/pcg/file_format/v1/v1_mapped_operator_task_group.h new file mode 100644 index 0000000000..8e386e156f --- /dev/null +++ b/lib/pcg/include/pcg/file_format/v1/v1_mapped_operator_task_group.h @@ -0,0 +1,14 @@ +#ifndef _FLEXFLOW_LIB_PCG_INCLUDE_PCG_FILE_FORMAT_V1_V1_MAPPED_OPERATOR_TASK_GROUP_H +#define _FLEXFLOW_LIB_PCG_INCLUDE_PCG_FILE_FORMAT_V1_V1_MAPPED_OPERATOR_TASK_GROUP_H + +#include "pcg/file_format/v1/v1_mapped_operator_task_group.dtg.h" +#include "pcg/mapped_parallel_computation_graph/mapped_operator_task_group.h" + +namespace FlexFlow { + +V1MappedOperatorTaskGroup to_v1(MappedOperatorTaskGroup const &); +MappedOperatorTaskGroup from_v1(V1MappedOperatorTaskGroup const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/pcg/include/pcg/file_format/v1/v1_mapped_parallel_computation_graph.dtg.toml b/lib/pcg/include/pcg/file_format/v1/v1_mapped_parallel_computation_graph.dtg.toml new file mode 100644 index 0000000000..8dc336e4ea --- /dev/null +++ b/lib/pcg/include/pcg/file_format/v1/v1_mapped_parallel_computation_graph.dtg.toml @@ -0,0 +1,29 @@ +namespace = "FlexFlow" +name = "V1MappedParallelComputationGraph" +type = "struct" +features = [ + "eq", + "hash", + "fmt", + "json", +] + +includes = [ + "", + "pcg/file_format/v1/v1_parallel_computation_graph.dtg.h", + "pcg/file_format/v1/v1_mapped_operator_task_group.dtg.h", + "pcg/parallel_computation_graph/parallel_layer_guid_t.dtg.h", +] + +src_includes = [ + "utils/hash/unordered_map.h", + "utils/fmt/unordered_map.h", +] + +[[fields]] +name = "pcg" +type = "::FlexFlow::V1ParallelComputationGraph" + +[[fields]] +name = "mapped_tasks" +type = "std::unordered_map<::FlexFlow::parallel_layer_guid_t, ::FlexFlow::V1MappedOperatorTaskGroup>" diff --git a/lib/pcg/include/pcg/file_format/v1/v1_mapped_parallel_computation_graph.h b/lib/pcg/include/pcg/file_format/v1/v1_mapped_parallel_computation_graph.h new file mode 100644 index 0000000000..f78efc4591 --- /dev/null +++ b/lib/pcg/include/pcg/file_format/v1/v1_mapped_parallel_computation_graph.h @@ -0,0 +1,15 @@ +#ifndef _FLEXFLOW_LIB_PCG_INCLUDE_PCG_FILE_FORMAT_V1_V1_MAPPED_PARALLEL_COMPUTATION_GRAPH_H +#define _FLEXFLOW_LIB_PCG_INCLUDE_PCG_FILE_FORMAT_V1_V1_MAPPED_PARALLEL_COMPUTATION_GRAPH_H + +#include "pcg/file_format/v1/v1_mapped_parallel_computation_graph.dtg.h" +#include "pcg/mapped_parallel_computation_graph/mapped_parallel_computation_graph.dtg.h" + +namespace FlexFlow { + +V1MappedParallelComputationGraph to_v1(MappedParallelComputationGraph const &); +MappedParallelComputationGraph + from_v1(V1MappedParallelComputationGraph const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/pcg/include/pcg/file_format/v1/v1_parallel_computation_graph.h b/lib/pcg/include/pcg/file_format/v1/v1_parallel_computation_graph.h index aceb59f5af..1ec9ee0e8c 100644 --- a/lib/pcg/include/pcg/file_format/v1/v1_parallel_computation_graph.h +++ b/lib/pcg/include/pcg/file_format/v1/v1_parallel_computation_graph.h @@ -8,6 +8,8 @@ namespace FlexFlow { V1ParallelComputationGraph to_v1(ParallelComputationGraph const &); +ParallelComputationGraph from_v1(V1ParallelComputationGraph const &); + } // namespace FlexFlow #endif diff --git a/lib/pcg/src/pcg/file_format/v1/graphs/v1_kwarg_dataflow_graph.cc b/lib/pcg/src/pcg/file_format/v1/graphs/v1_kwarg_dataflow_graph.cc index 9e4a46b87a..cc10bbf4cb 100644 --- a/lib/pcg/src/pcg/file_format/v1/graphs/v1_kwarg_dataflow_graph.cc +++ b/lib/pcg/src/pcg/file_format/v1/graphs/v1_kwarg_dataflow_graph.cc @@ -12,4 +12,11 @@ template V1KwargDataflowGraph to_v1(KwargDataflowGraphView const &, std::unordered_map const &); +template std::pair, + std::unordered_map> + from_v1_including_node_numbering(V1KwargDataflowGraph const &); + +template KwargDataflowGraphView + from_v1(V1KwargDataflowGraph const &); + } // namespace FlexFlow diff --git a/lib/pcg/src/pcg/file_format/v1/graphs/v1_labelled_kwarg_dataflow_graph.cc b/lib/pcg/src/pcg/file_format/v1/graphs/v1_labelled_kwarg_dataflow_graph.cc index 4e7b9b651f..4e50949e3f 100644 --- a/lib/pcg/src/pcg/file_format/v1/graphs/v1_labelled_kwarg_dataflow_graph.cc +++ b/lib/pcg/src/pcg/file_format/v1/graphs/v1_labelled_kwarg_dataflow_graph.cc @@ -18,4 +18,14 @@ template std::pair< template V1LabelledKwargDataflowGraph to_v1( LabelledKwargDataflowGraphView const &); +template std::pair< + LabelledKwargDataflowGraphView, + std::unordered_map> + from_v1_including_node_numbering( + V1LabelledKwargDataflowGraph const &); + +template LabelledKwargDataflowGraphView + from_v1( + V1LabelledKwargDataflowGraph const &); + } // namespace FlexFlow diff --git a/lib/pcg/src/pcg/file_format/v1/v1_computation_graph.cc b/lib/pcg/src/pcg/file_format/v1/v1_computation_graph.cc index 852ca73a36..e52b5708e5 100644 --- a/lib/pcg/src/pcg/file_format/v1/v1_computation_graph.cc +++ b/lib/pcg/src/pcg/file_format/v1/v1_computation_graph.cc @@ -1,6 +1,8 @@ #include "pcg/file_format/v1/v1_computation_graph.h" #include "pcg/file_format/v1/graphs/v1_labelled_kwarg_dataflow_graph.h" #include "utils/bidict/algorithms/transform_values.h" +#include "utils/graph/instances/unordered_set_labelled_open_kwarg_dataflow_graph.h" +#include "utils/graph/labelled_kwarg_dataflow_graph/labelled_kwarg_dataflow_graph.h" namespace FlexFlow { @@ -25,4 +27,15 @@ std::pair> return {v1_cg, v1_node_ids}; } +ComputationGraph from_v1(V1ComputationGraph const &v1) { + return ComputationGraph{ + LabelledKwargDataflowGraph:: + create_copy_of< + UnorderedSetLabelledOpenKwargDataflowGraph>( + from_v1(v1.raw_graph))}; +} + } // namespace FlexFlow diff --git a/lib/pcg/src/pcg/file_format/v1/v1_mapped_operator_task_group.cc b/lib/pcg/src/pcg/file_format/v1/v1_mapped_operator_task_group.cc new file mode 100644 index 0000000000..465dd01fb6 --- /dev/null +++ b/lib/pcg/src/pcg/file_format/v1/v1_mapped_operator_task_group.cc @@ -0,0 +1,13 @@ +#include "pcg/file_format/v1/v1_mapped_operator_task_group.h" + +namespace FlexFlow { + +V1MappedOperatorTaskGroup to_v1(MappedOperatorTaskGroup const &g) { + return V1MappedOperatorTaskGroup{g.get_shard_bindings()}; +} + +MappedOperatorTaskGroup from_v1(V1MappedOperatorTaskGroup const &v1) { + return MappedOperatorTaskGroup{v1.shard_bindings}; +} + +} // namespace FlexFlow diff --git a/lib/pcg/src/pcg/file_format/v1/v1_mapped_parallel_computation_graph.cc b/lib/pcg/src/pcg/file_format/v1/v1_mapped_parallel_computation_graph.cc new file mode 100644 index 0000000000..0236a8834c --- /dev/null +++ b/lib/pcg/src/pcg/file_format/v1/v1_mapped_parallel_computation_graph.cc @@ -0,0 +1,26 @@ +#include "pcg/file_format/v1/v1_mapped_parallel_computation_graph.h" +#include "pcg/file_format/v1/v1_mapped_operator_task_group.h" +#include "pcg/file_format/v1/v1_parallel_computation_graph.h" +#include "utils/containers/map_values.h" + +namespace FlexFlow { + +V1MappedParallelComputationGraph + to_v1(MappedParallelComputationGraph const &mpcg) { + return V1MappedParallelComputationGraph{ + to_v1(mpcg.pcg), + map_values(mpcg.mapped_tasks, + [](MappedOperatorTaskGroup const &g) { return to_v1(g); }), + }; +} + +MappedParallelComputationGraph + from_v1(V1MappedParallelComputationGraph const &v1) { + return MappedParallelComputationGraph{ + from_v1(v1.pcg), + map_values(v1.mapped_tasks, + [](V1MappedOperatorTaskGroup const &g) { return from_v1(g); }), + }; +} + +} // namespace FlexFlow diff --git a/lib/pcg/src/pcg/file_format/v1/v1_parallel_computation_graph.cc b/lib/pcg/src/pcg/file_format/v1/v1_parallel_computation_graph.cc index e14d15d66a..a5afa3ebdc 100644 --- a/lib/pcg/src/pcg/file_format/v1/v1_parallel_computation_graph.cc +++ b/lib/pcg/src/pcg/file_format/v1/v1_parallel_computation_graph.cc @@ -1,5 +1,7 @@ #include "pcg/file_format/v1/v1_parallel_computation_graph.h" #include "pcg/file_format/v1/graphs/v1_labelled_kwarg_dataflow_graph.h" +#include "utils/graph/instances/unordered_set_labelled_open_kwarg_dataflow_graph.h" +#include "utils/graph/labelled_kwarg_dataflow_graph/labelled_kwarg_dataflow_graph.h" namespace FlexFlow { @@ -10,4 +12,17 @@ V1ParallelComputationGraph to_v1(ParallelComputationGraph const &g) { }; } +ParallelComputationGraph from_v1(V1ParallelComputationGraph const &v1) { + return ParallelComputationGraph{ + LabelledKwargDataflowGraph:: + create_copy_of< + UnorderedSetLabelledOpenKwargDataflowGraph>( + from_v1(v1.raw_graph))}; +} + } // namespace FlexFlow diff --git a/lib/pcg/test/src/pcg/file_format/v1/v1_computation_graph.cc b/lib/pcg/test/src/pcg/file_format/v1/v1_computation_graph.cc index 7af3f648d9..2ae643bd0f 100644 --- a/lib/pcg/test/src/pcg/file_format/v1/v1_computation_graph.cc +++ b/lib/pcg/test/src/pcg/file_format/v1/v1_computation_graph.cc @@ -1,6 +1,8 @@ #include "pcg/file_format/v1/v1_computation_graph.h" +#include "pcg/computation_graph.h" #include "pcg/computation_graph_builder.h" #include +#include using namespace ::FlexFlow; @@ -25,6 +27,14 @@ TEST_SUITE(FF_TEST_SUITE) { }(); V1ComputationGraph v1_cg = to_v1(cg); - nlohmann::json j = v1_cg; + + SUBCASE("serializes to JSON") { + nlohmann::json j = v1_cg; + } + + SUBCASE("round-trips via from_v1") { + ComputationGraph result = from_v1(v1_cg); + CHECK(computation_graphs_are_isomorphic(cg, result)); + } } } diff --git a/lib/pcg/test/src/pcg/file_format/v1/v1_mapped_parallel_computation_graph.cc b/lib/pcg/test/src/pcg/file_format/v1/v1_mapped_parallel_computation_graph.cc new file mode 100644 index 0000000000..78da5430b7 --- /dev/null +++ b/lib/pcg/test/src/pcg/file_format/v1/v1_mapped_parallel_computation_graph.cc @@ -0,0 +1,81 @@ +#include "pcg/file_format/v1/v1_mapped_parallel_computation_graph.h" +#include "op-attrs/parallel_tensor_space_coordinate.dtg.h" +#include "op-attrs/tensor_slot_name.dtg.h" +#include "pcg/device_type.dtg.h" +#include "pcg/file_format/v1/v1_mapped_operator_task_group.h" +#include "pcg/machine_space_coordinate.dtg.h" +#include "pcg/mapped_parallel_computation_graph/mapped_operator_task_group.h" +#include "pcg/mapped_parallel_computation_graph/mapped_parallel_computation_graph.dtg.h" +#include "pcg/mapped_parallel_computation_graph/operator_atomic_task_shard_binding.dtg.h" +#include "pcg/parallel_computation_graph/parallel_computation_graph.h" +#include "pcg/parallel_computation_graph/parallel_layer_added_result.dtg.h" +#include "utils/bidict/bidict.h" +#include +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("V1MappedParallelComputationGraph") { + ParallelComputationGraph pcg = empty_parallel_computation_graph(); + + TensorShape input_shape = TensorShape{ + TensorDims{ + FFOrdered{ + 12_p, + 16_p, + }, + }, + DataType::FLOAT, + }; + + ParallelLayerAddedResult result = pcg_add_input_layer(pcg, input_shape); + parallel_layer_guid_t layer = result.parallel_layer; + + MachineSpaceCoordinate coord = MachineSpaceCoordinate{ + /*node_idx=*/0_n, + /*device_idx=*/0_n, + /*device_type=*/DeviceType::GPU, + }; + + OperatorAtomicTaskShardBinding binding = OperatorAtomicTaskShardBinding{ + /*tensor_coords=*/{ + { + TensorSlotName::OUTPUT, + ParallelTensorSpaceCoordinate{ + /*sum_component=*/0_n, + /*discard_copy_component=*/0_n, + /*shard_components=*/FFOrdered{0_n, 0_n}, + }, + }, + }, + }; + + MappedOperatorTaskGroup task_group = MappedOperatorTaskGroup{ + bidict{ + {coord, binding}, + }, + }; + + MappedParallelComputationGraph mpcg = MappedParallelComputationGraph{ + /*pcg=*/pcg, + /*mapped_tasks=*/{{layer, task_group}}, + }; + + V1MappedParallelComputationGraph v1_mpcg = to_v1(mpcg); + + SUBCASE("serializes to JSON") { + nlohmann::json j = v1_mpcg; + } + + SUBCASE("MappedOperatorTaskGroup round-trips via from_v1") { + MappedOperatorTaskGroup result = from_v1(to_v1(task_group)); + CHECK(result == task_group); + } + + SUBCASE("MappedParallelComputationGraph round-trips via from_v1") { + MappedParallelComputationGraph result = from_v1(v1_mpcg); + CHECK(pcgs_are_isomorphic(mpcg.pcg, result.pcg)); + } + } +} diff --git a/lib/pcg/test/src/pcg/file_format/v1/v1_parallel_computation_graph.cc b/lib/pcg/test/src/pcg/file_format/v1/v1_parallel_computation_graph.cc index ec6a4ab006..033626ab5c 100644 --- a/lib/pcg/test/src/pcg/file_format/v1/v1_parallel_computation_graph.cc +++ b/lib/pcg/test/src/pcg/file_format/v1/v1_parallel_computation_graph.cc @@ -1,6 +1,8 @@ #include "pcg/file_format/v1/v1_parallel_computation_graph.h" +#include "pcg/parallel_computation_graph/parallel_computation_graph.h" #include "pcg/parallel_computation_graph/parallel_computation_graph_builder.h" #include +#include using namespace ::FlexFlow; @@ -29,6 +31,14 @@ TEST_SUITE(FF_TEST_SUITE) { }(); V1ParallelComputationGraph v1_pcg = to_v1(pcg); - nlohmann::json j = v1_pcg; + + SUBCASE("serializes to JSON") { + nlohmann::json j = v1_pcg; + } + + SUBCASE("round-trips via from_v1") { + ParallelComputationGraph result = from_v1(v1_pcg); + CHECK(pcgs_are_isomorphic(pcg, result)); + } } } diff --git a/lib/utils/include/utils/graph/labelled_kwarg_dataflow_graph/algorithms/kwarg_dataflow_graph_view_with_labelling.h b/lib/utils/include/utils/graph/labelled_kwarg_dataflow_graph/algorithms/kwarg_dataflow_graph_view_with_labelling.h new file mode 100644 index 0000000000..782e63889b --- /dev/null +++ b/lib/utils/include/utils/graph/labelled_kwarg_dataflow_graph/algorithms/kwarg_dataflow_graph_view_with_labelling.h @@ -0,0 +1,73 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_LABELLED_KWARG_DATAFLOW_GRAPH_ALGORITHMS_KWARG_DATAFLOW_GRAPH_VIEW_WITH_LABELLING_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_LABELLED_KWARG_DATAFLOW_GRAPH_ALGORITHMS_KWARG_DATAFLOW_GRAPH_VIEW_WITH_LABELLING_H + +#include "utils/graph/kwarg_dataflow_graph/kwarg_dataflow_graph_view.h" +#include "utils/graph/labelled_kwarg_dataflow_graph/i_labelled_kwarg_dataflow_graph_view.h" +#include "utils/graph/labelled_kwarg_dataflow_graph/labelled_kwarg_dataflow_graph_view.h" + +namespace FlexFlow { + +template +struct KwargDataflowGraphLabellingWrapper final + : public ILabelledKwargDataflowGraphView { +public: + KwargDataflowGraphLabellingWrapper() = delete; + KwargDataflowGraphLabellingWrapper( + KwargDataflowGraphView const &unlabelled, + std::unordered_map const &node_labels, + std::unordered_map, OutputLabel> const + &output_labels) + : unlabelled(unlabelled), node_labels(node_labels), + output_labels(output_labels) {} + + std::unordered_set query_nodes(NodeQuery const &q) const override { + return this->unlabelled.query_nodes(q); + } + + std::unordered_set> + query_edges(KwargDataflowEdgeQuery const &q) const override { + return this->unlabelled.query_edges(q); + } + + std::unordered_set> query_outputs( + KwargDataflowOutputQuery const &q) const override { + return this->unlabelled.query_outputs(q); + } + + NodeLabel at(Node const &n) const override { + return this->node_labels.at(n); + } + + OutputLabel at(KwargDataflowOutput const &v) const override { + return this->output_labels.at(v); + } + + KwargDataflowGraphLabellingWrapper *clone() const override { + return new KwargDataflowGraphLabellingWrapper{ + this->unlabelled, + this->node_labels, + this->output_labels, + }; + } + +private: + KwargDataflowGraphView unlabelled; + std::unordered_map node_labels; + std::unordered_map, OutputLabel> output_labels; +}; + +template +LabelledKwargDataflowGraphView + kwarg_dataflow_graph_view_with_labelling( + KwargDataflowGraphView const &g, + std::unordered_map const &node_labels, + std::unordered_map, OutputLabel> const + &value_labels) { + return LabelledKwargDataflowGraphView:: + template create< + KwargDataflowGraphLabellingWrapper>( + g, node_labels, value_labels); +} +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/nonnegative_int/nonnegative_int.h b/lib/utils/include/utils/nonnegative_int/nonnegative_int.h index c775cfc9ed..1972cc6786 100644 --- a/lib/utils/include/utils/nonnegative_int/nonnegative_int.h +++ b/lib/utils/include/utils/nonnegative_int/nonnegative_int.h @@ -17,6 +17,7 @@ class nonnegative_int { explicit nonnegative_int(unsigned long long int value); explicit operator int() const noexcept; + explicit operator size_t() const noexcept; bool operator<(nonnegative_int const &other) const; bool operator==(nonnegative_int const &other) const; @@ -56,6 +57,9 @@ class nonnegative_int { nonnegative_int operator%(nonnegative_int const &other) const; nonnegative_int &operator%=(nonnegative_int const &other); + int int_from_nonnegative_int() const; + size_t size_t_from_nonnegative_int() const; + friend std::ostream &operator<<(std::ostream &os, nonnegative_int const &n); friend int format_as(nonnegative_int const &); diff --git a/lib/utils/src/utils/nonnegative_int/nonnegative_int.cc b/lib/utils/src/utils/nonnegative_int/nonnegative_int.cc index 7593a8e9ec..8d1c4383a9 100644 --- a/lib/utils/src/utils/nonnegative_int/nonnegative_int.cc +++ b/lib/utils/src/utils/nonnegative_int/nonnegative_int.cc @@ -24,6 +24,10 @@ nonnegative_int::operator int() const noexcept { return this->value_; } +nonnegative_int::operator size_t() const noexcept { + return static_cast(this->value_); +} + bool nonnegative_int::operator<(nonnegative_int const &other) const { return this->value_ < other.value_; } @@ -151,6 +155,14 @@ nonnegative_int &nonnegative_int::operator%=(nonnegative_int const &other) { return *this; } +int nonnegative_int::int_from_nonnegative_int() const { + return this->value_; +} + +size_t nonnegative_int::size_t_from_nonnegative_int() const { + return static_cast(this->value_); +} + std::ostream &operator<<(std::ostream &os, nonnegative_int const &n) { os << n.value_; return os; diff --git a/lib/utils/test/src/utils/nonnegative_int/nonnegative_int.cc b/lib/utils/test/src/utils/nonnegative_int/nonnegative_int.cc index 58fb151313..8c5ecd3e2c 100644 --- a/lib/utils/test/src/utils/nonnegative_int/nonnegative_int.cc +++ b/lib/utils/test/src/utils/nonnegative_int/nonnegative_int.cc @@ -315,6 +315,24 @@ TEST_SUITE(FF_TEST_SUITE) { CHECK(result == correct); } + TEST_CASE("nonnegative_int::int_from_nonnegative_int()") { + nonnegative_int input = nonnegative_int{3}; + + int result = input.int_from_nonnegative_int(); + int correct = 3; + + CHECK(result == correct); + } + + TEST_CASE("nonnegative_int::size_t_from_nonnegative_int()") { + nonnegative_int input = nonnegative_int{3}; + + size_t result = input.size_t_from_nonnegative_int(); + size_t correct = 3; + + CHECK(result == correct); + } + TEST_CASE("adl_serializer") { SUBCASE("to_json") { nonnegative_int input = nonnegative_int{5};