From d269b406b681fcffc4cac3f1759fd1e41633c342 Mon Sep 17 00:00:00 2001 From: wmdi Date: Mon, 16 Oct 2023 16:02:34 -0400 Subject: [PATCH 01/37] compiler build --- lib/compiler/CMakeLists.txt | 1 + .../include/compiler/machine_mapping.h | 24 +- .../include/compiler/unity_algorithm.h | 17 +- lib/compiler/src/graph_utils.cc | 2 +- lib/compiler/src/machine_mapping.cc | 126 +- lib/compiler/src/old/basic_graph.h | 158 - lib/compiler/src/old/dominators.h | 494 --- lib/compiler/src/old/graph.cc | 1255 ------ lib/compiler/src/old/graph.h | 248 -- lib/compiler/src/old/graph_structures.h | 269 -- lib/compiler/src/old/node.h | 47 - .../src/old/parallel_dim_mapping_record.h | 4 - lib/compiler/src/old/search_helper.cc | 525 --- lib/compiler/src/old/search_helper.h | 122 - lib/compiler/src/old/simplification.cc | 189 - lib/compiler/src/old/simplification.h | 34 - lib/compiler/src/old/split_types.cc | 36 - lib/compiler/src/old/split_types.h | 32 - lib/compiler/src/old/substitution.cc | 3733 ----------------- lib/compiler/src/old/substitution.h | 309 -- lib/compiler/src/unity_algorithm.cc | 16 +- lib/pcg/include/pcg/machine_view.h | 5 +- .../include/pcg/parallel_computation_graph.h | 10 + .../include/substitutions/substitution.h | 8 + lib/utils/include/utils/graph/algorithms.h | 3 + .../utils/graph/labelled/node_labelled.h | 1 + .../utils/graph/labelled/node_labelled_open.h | 8 +- .../include/utils/graph/labelled/open_views.h | 10 + .../utils/graph/labelled/output_labelled.h | 5 +- .../graph/labelled/output_labelled_open.h | 10 +- .../include/utils/graph/labelled/views.h | 3 +- 31 files changed, 137 insertions(+), 7567 deletions(-) delete mode 100644 lib/compiler/src/old/basic_graph.h delete mode 100644 lib/compiler/src/old/dominators.h delete mode 100644 lib/compiler/src/old/graph.cc delete mode 100644 lib/compiler/src/old/graph.h delete mode 100644 lib/compiler/src/old/graph_structures.h delete mode 100644 lib/compiler/src/old/node.h delete mode 100644 
lib/compiler/src/old/parallel_dim_mapping_record.h delete mode 100644 lib/compiler/src/old/search_helper.cc delete mode 100644 lib/compiler/src/old/search_helper.h delete mode 100644 lib/compiler/src/old/simplification.cc delete mode 100644 lib/compiler/src/old/simplification.h delete mode 100644 lib/compiler/src/old/split_types.cc delete mode 100644 lib/compiler/src/old/split_types.h delete mode 100644 lib/compiler/src/old/substitution.cc delete mode 100644 lib/compiler/src/old/substitution.h diff --git a/lib/compiler/CMakeLists.txt b/lib/compiler/CMakeLists.txt index daa96b08bc..45c369fcdf 100644 --- a/lib/compiler/CMakeLists.txt +++ b/lib/compiler/CMakeLists.txt @@ -14,6 +14,7 @@ ff_add_library( optional pcg spdlog + substitutions ) add_subdirectory(ffi) diff --git a/lib/compiler/include/compiler/machine_mapping.h b/lib/compiler/include/compiler/machine_mapping.h index 4089260735..e8d7457fbf 100644 --- a/lib/compiler/include/compiler/machine_mapping.h +++ b/lib/compiler/include/compiler/machine_mapping.h @@ -5,10 +5,12 @@ #include "pcg/machine_specification.h" #include "pcg/machine_view.h" #include "pcg/parallel_computation_graph.h" -#include "sub_parallel_computation_graph.h" +#include "substitutions/sub_parallel_computation_graph.h" namespace FlexFlow { +using SubParallelComputationGraphView = OutputLabelledOpenMultiDiGraphView; + struct MachineMapping { static MachineMapping combine(MachineMapping const &, MachineMapping const &); static bool nodes_are_disjoint(MachineMapping const &m1, @@ -20,14 +22,13 @@ FF_VISITABLE_STRUCT(MachineMapping, machine_views); struct OptimalCostState { SerialParallelDecomposition subgraph; - MachineSpecification resource; - req> source_machine_view, sink_machine_view; + req resource; + // req> given_machine_views; + // req> frontier_machine_views; }; FF_VISITABLE_STRUCT(OptimalCostState, subgraph, - resource, - source_machine_view, - sink_machine_view); + resource); struct OptimalCostResult { static OptimalCostResult 
sequential_combine(OptimalCostResult const &s1, @@ -37,7 +38,7 @@ struct OptimalCostResult { static OptimalCostResult infinity(); float runtime; - MachineMapping machine_mapping; + req machine_mapping; }; FF_VISITABLE_STRUCT(OptimalCostResult, runtime, machine_mapping); @@ -67,4 +68,13 @@ OptimalCostResult } // namespace FlexFlow +namespace std { + +template <> +struct hash> { + size_t operator()(std::unordered_map const &g) const; +}; + +}; + #endif diff --git a/lib/compiler/include/compiler/unity_algorithm.h b/lib/compiler/include/compiler/unity_algorithm.h index 57f1c8c063..fc068d48c5 100644 --- a/lib/compiler/include/compiler/unity_algorithm.h +++ b/lib/compiler/include/compiler/unity_algorithm.h @@ -4,18 +4,15 @@ #include "cost_estimate.h" #include "machine_mapping.h" #include "pcg/computation_graph.h" -#include "sub_parallel_computation_graph.h" +#include "substitutions/sub_parallel_computation_graph.h" namespace FlexFlow { -struct Substitution {}; - struct Strategy { ParallelComputationGraph pcg; MachineMapping machine_mapping; req runtime; }; -FF_VISITABLE_STRUCT(Strategy, pcg, machine_mapping, runtime); struct StrategyRuntimeCmp { bool operator()(Strategy const &, Strategy const &); @@ -30,7 +27,7 @@ struct OptimizerConfig { Strategy graph_optimize(ComputationGraph &cg, - ICostEstimator const &cost_estimator, + CostEstimator const &cost_estimator, MachineSpecification const &resources, std::function( Operator const &, MachineSpecification const &)> const @@ -39,4 +36,14 @@ Strategy } // namespace FlexFlow +VISITABLE_STRUCT(FlexFlow::Strategy, pcg, machine_mapping, runtime); +namespace std { + +template <> +struct hash { + size_t operator()(FlexFlow::Strategy const &) const; +}; + +}; + #endif diff --git a/lib/compiler/src/graph_utils.cc b/lib/compiler/src/graph_utils.cc index 4f22490ffa..d7f15e0796 100644 --- a/lib/compiler/src/graph_utils.cc +++ b/lib/compiler/src/graph_utils.cc @@ -4,7 +4,7 @@ namespace FlexFlow { SerialParallelDecomposition 
get_serial_parallel_decomposition(ParallelComputationGraph const &pcg) { - return get_serial_parallel_decomposition(as_digraph(pcg)); + return get_serial_parallel_decomposition(pcg.value()); } std::vector diff --git a/lib/compiler/src/machine_mapping.cc b/lib/compiler/src/machine_mapping.cc index 2f6af8a62b..fb04f57eac 100644 --- a/lib/compiler/src/machine_mapping.cc +++ b/lib/compiler/src/machine_mapping.cc @@ -45,9 +45,12 @@ bool OptimalCostRuntimeCmp::operator()(OptimalCostResult const &lhs, optional OptimalCostCache::load(OptimalCostState const &state) const { - if (contains_key(cache, state)) { - return make_optional(cache.at(state)); - } + auto it = cache.find(state); + // if (contains_key(cache, state)) { + // // auto result = cache.at(state); + // OptimalCostResult result = OptimalCostResult::infinity(); + // return make_optional(result); + // } return nullopt; } @@ -88,31 +91,10 @@ GraphSplit return {get_nodes(pre_decomposition), get_nodes(post_decomposition)}; } -std::pair - apply_split(SubParallelComputationGraph const &g, GraphSplit const &split) { - OpenMultiDiGraphView g1 = get_subgraph(g, split.first); - OpenMultiDiGraphView g2 = get_subgraph(g, split.second); - - if (get_edge_splits(g, split).size() > 0) { - // Sequential split - if (get_open_sinks(g1).size() <= get_open_sources(g2).size()) { - // get_open_sinks(*g1).size() should be 1 in perfect sp graphs - return {get_subgraph(g, split.first), - get_subgraph(g, split.second)}; - } else { - return {get_subgraph(g, split.first), - get_subgraph(g, split.first)}; - } - } else { - // Parallel split - return {get_subgraph(g, split.first), - get_subgraph(g, split.second)}; - } -} - -float estimate_cost(SubParallelComputationGraph const &g, +float estimate_cost(SubParallelComputationGraphView const &g, CostEstimator const &estimator, - MachineMapping const &device_mapping) { + MachineMapping const &device_mapping, + std::unordered_map const &frontier_machine_views) { NOT_IMPLEMENTED(); } @@ -122,26 
+104,26 @@ void minimize_runtime(OptimalCostResult &m1, OptimalCostResult const &m2) { struct OptimalCost { OptimalCost( - SubParallelComputationGraph const &g, + SubParallelComputationGraphView const &g, CostEstimator const &cost_estimator, MachineSpecification const &resource, - optional const &source_machine_view, // assume perfect SP - optional const &sink_machine_view, + std::unordered_map const &given_machine_views, + std::unordered_map const &frontier_machine_views, std::function( Operator const &, MachineSpecification const &)> const &allowed_machine_views, OptimalCostCache &cached_subgraph_costs) : g(g), cost_estimator(cost_estimator), resource(resource), - source_machine_view(source_machine_view), - sink_machine_view(sink_machine_view), + given_machine_views(restrict_keys(given_machine_views, get_nodes(g))), + frontier_machine_views(restrict_keys(frontier_machine_views, get_edges(g))), allowed_machine_views(allowed_machine_views), cached_subgraph_costs(cached_subgraph_costs) {} - SubParallelComputationGraph const &g; + SubParallelComputationGraphView const &g; CostEstimator const &cost_estimator; MachineSpecification const &resource; - optional const &source_machine_view; - optional const &sink_machine_view; + std::unordered_map const &given_machine_views; + std::unordered_map const &frontier_machine_views; std::function( Operator const &, MachineSpecification const &)> const &allowed_machine_views; @@ -149,7 +131,7 @@ struct OptimalCost { template OptimalCostResult operator()(T const &t) const { - OptimalCostState state{g, resource, source_machine_view, sink_machine_view}; + OptimalCostState state{t, resource/*, given_machine_views, frontier_machine_views*/}; optional cached_result = cached_subgraph_costs.load(state); @@ -168,44 +150,40 @@ struct OptimalCost { SerialParallelDecomposition pre_decompn = decomposed.first; SerialParallelDecomposition post_decompn = decomposed.second; - auto subgraphs = apply_split(g, get_graph_split(pre_decompn, 
post_decompn)); - SubParallelComputationGraph pre_graph = subgraphs.first, - post_graph = subgraphs.second; + GraphSplit graph_split = get_graph_split(pre_decompn, post_decompn); + SubParallelComputationGraphView pre_graph = get_subgraph(g, graph_split.first); + SubParallelComputationGraphView post_graph = get_subgraph(g, graph_split.second); - std::unordered_set pre_graph_sinks = get_closed_sinks(pre_graph); std::unordered_set post_graph_sources = get_closed_sources(post_graph); - assert(pre_graph_sinks.size() + post_graph_sources.size() == - 1); // assume perfect SP - - Node const &split_point = - get_only(set_union(pre_graph_sinks, post_graph_sources)); + assert(post_graph_sources.size() == 1); // assume perfect SP + Node split_point = get_only(post_graph_sources); + OutputMultiDiEdge split_edge = get_only(get_open_outputs(pre_graph)); + OptimalCostResult optimal_result = OptimalCostResult::infinity(); for (MachineView const &mv : allowed_machine_views(g.at(split_point), resource)) { - optional pre_sink_mv = - contains(pre_graph_sinks, split_point) ? make_optional(mv) : nullopt; - optional post_source_mv = - contains(post_graph_sources, split_point) ? 
make_optional(mv) - : nullopt; + auto new_given_machine_views = merge_maps(given_machine_views, std::unordered_map{{split_point, mv}}); + auto new_frontier_machine_views = merge_maps(frontier_machine_views, + std::unordered_map{{split_edge, mv}}); minimize_runtime(optimal_result, OptimalCostResult::sequential_combine( visit(OptimalCost(pre_graph, cost_estimator, resource, - source_machine_view, - pre_sink_mv, + given_machine_views, + new_frontier_machine_views, allowed_machine_views, cached_subgraph_costs), pre_decompn), visit(OptimalCost(post_graph, cost_estimator, resource, - post_source_mv, - sink_machine_view, + new_given_machine_views, + frontier_machine_views, allowed_machine_views, cached_subgraph_costs), post_decompn))); @@ -219,23 +197,24 @@ struct OptimalCost { SerialParallelDecomposition decompn1 = decomposed.first; SerialParallelDecomposition decompn2 = decomposed.second; - auto subgraphs = apply_split(g, get_graph_split(decompn1, decompn2)); - SubParallelComputationGraph g1 = subgraphs.first, g2 = subgraphs.second; + GraphSplit graph_split = get_graph_split(decompn1, decompn2); + SubParallelComputationGraphView g1 = get_subgraph(g, graph_split.first), + g2 = get_subgraph(g, graph_split.second); OptimalCostResult optimal_result = OptimalCostResult::sequential_combine( visit(OptimalCost(g1, cost_estimator, resource, - source_machine_view, - sink_machine_view, + given_machine_views, + frontier_machine_views, allowed_machine_views, cached_subgraph_costs), decompn1), visit(OptimalCost(g2, cost_estimator, resource, - source_machine_view, - sink_machine_view, + given_machine_views, + frontier_machine_views, allowed_machine_views, cached_subgraph_costs), decompn2)); @@ -246,16 +225,16 @@ struct OptimalCost { visit(OptimalCost(g1, cost_estimator, resource_split.first, - source_machine_view, - sink_machine_view, + given_machine_views, + frontier_machine_views, allowed_machine_views, cached_subgraph_costs), decompn1), visit(OptimalCost(g2, cost_estimator, 
resource_split.second, - source_machine_view, - sink_machine_view, + given_machine_views, + frontier_machine_views, allowed_machine_views, cached_subgraph_costs), decompn2))); @@ -265,24 +244,17 @@ struct OptimalCost { } OptimalCostResult optimal_cost(Node const &node) const { - if (source_machine_view) { - assert(get_closed_sources(g).empty()); + if (contains_key(given_machine_views, node)) { assert(contains(allowed_machine_views(g.at(node), resource), source_machine_view.value())); - MachineMapping mv_map{{{node, source_machine_view.value()}}}; - return {estimate_cost(g, cost_estimator, mv_map), mv_map}; - } else if (sink_machine_view) { - assert(get_closed_sinks(g).empty()); - assert(contains(allowed_machine_views(g.at(node), resource), - sink_machine_view.value())); - MachineMapping mv_map{{{node, sink_machine_view.value()}}}; - return {estimate_cost(g, cost_estimator, mv_map), mv_map}; + MachineMapping mv_map{given_machine_views}; + return {estimate_cost(g, cost_estimator, mv_map, frontier_machine_views), mv_map}; } else { OptimalCostResult optimal_result = OptimalCostResult::infinity(); for (auto mv : allowed_machine_views(g.at(node), resource)) { MachineMapping mv_map{{{node, mv}}}; minimize_runtime(optimal_result, - {estimate_cost(g, cost_estimator, mv_map), mv_map}); + {estimate_cost(g, cost_estimator, mv_map, frontier_machine_views), mv_map}); } return optimal_result; } @@ -300,8 +272,8 @@ OptimalCostResult return visit(OptimalCost(pcg_to_subpcg(g), cost_estimator, resources, - nullopt, - nullopt, + {}, + {}, allowed_machine_views, cached_subgraph_costs), get_serial_parallel_decomposition(g)); diff --git a/lib/compiler/src/old/basic_graph.h b/lib/compiler/src/old/basic_graph.h deleted file mode 100644 index fca575e42a..0000000000 --- a/lib/compiler/src/old/basic_graph.h +++ /dev/null @@ -1,158 +0,0 @@ -#ifndef _BASIC_GRAPH_H -#define _BASIC_GRAPH_H - -#include "utils/hash-utils.h" -#include -#include - -namespace FlexFlow { -namespace PCG { -namespace 
Utils { - -template -struct GraphStructure; -/* -{ - using graph_type = ...; - using node_type = - using tGraph = G; - using tNode = N; - using tEdge = E; - - std::unordered_set get_nodes(G const &) const; - std::unordered_set get_incoming_edges(G const &, N const &) const; - std::unordered_set get_outgoing_edges(G const &, N const &) const; - N get_src(G const &, E const &) const; - N get_dst(G const &, E const &) const; -}; -*/ - -template -struct BasicGraph { - using N = T; - using E = std::pair; - - std::unordered_set nodes; - std::unordered_map> in_edges, out_edges; - - BasicGraph() : BasicGraph({}, {}) {} - - BasicGraph(std::unordered_set const &nodes, std::unordered_set edges) - : nodes(), in_edges(), out_edges() { - this->add_nodes(nodes); - this->add_edges(edges); - } - - void add_edge(N const &src, N const &dst) { - nodes.insert(src); - nodes.insert(dst); - out_edges[src].insert({src, dst}); - in_edges[dst].insert({src, dst}); - } - - void add_edge(E const &e) { - nodes.insert(e.first); - nodes.insert(e.second); - out_edges[e.first].insert(e); - in_edges[e.second].insert(e); - } - - bool has_edge(N const &src, N const &dst) const { - auto iter = this->in_edges.find(dst); - if (iter == this->in_edges.end()) { - return false; - } - - std::unordered_set const &dst_in_edges = iter->second; - return dst_in_edges.find({src, dst}) != dst_in_edges.end(); - } - - bool has_edge(E const &e) const { - return this->has_edge(e.first, e.second); - } - - void remove_edge(N const &src, N const &dst) { - out_edges[src].erase({src, dst}); - in_edges[dst].erase({src, dst}); - } - - void remove_edge(E const &e) { - out_edges[e.first].erase(e); - in_edges[e.second].erase(e); - } - - void add_node(N const &n) { - nodes.insert(n); - } - - template > - void add_nodes(Container const &nodes) { - for (auto const &n : nodes) { - this->add_node(n); - } - } - - template > - void add_edges(Container const &edges) { - for (auto const &e : edges) { - this->add_edge(e); - } - } - - bool 
operator==(BasicGraph const &other) const { - return this->nodes == other.nodes && this->in_edges == other.in_edges && - this->out_edges == other.out_edges; - } -}; - -template -struct GraphStructure> { - using graph_type = BasicGraph; - using vertex_type = T; - using edge_type = std::pair; - - std::unordered_set get_nodes(graph_type const &g) const { - std::unordered_set nodes(g.nodes); - return nodes; - } - - std::unordered_set get_incoming_edges(graph_type const &g, - vertex_type const &n) const { - std::unordered_set edges; - if (g.in_edges.find(n) != g.in_edges.end()) { - edges.insert(g.in_edges.at(n).begin(), g.in_edges.at(n).end()); - } - return edges; - } - - std::unordered_set get_outgoing_edges(graph_type const &g, - vertex_type const &n) const { - std::unordered_set edges; - if (g.out_edges.find(n) != g.out_edges.end()) { - edges.insert(g.out_edges.at(n).begin(), g.out_edges.at(n).end()); - } - return edges; - } - - vertex_type get_src(graph_type const &g, edge_type const &e) const { - return e.first; - } - - vertex_type get_dst(graph_type const &g, edge_type const &e) const { - return e.second; - } - - void set_src(graph_type const &g, edge_type &e, vertex_type const &n) const { - e.first = n; - } - - void set_dst(graph_type const &g, edge_type &e, vertex_type const &n) const { - e.second = n; - } -}; - -} // namespace Utils -} // namespace PCG -} // namespace FlexFlow - -#endif // _BASIC_GRAPH_H diff --git a/lib/compiler/src/old/dominators.h b/lib/compiler/src/old/dominators.h deleted file mode 100644 index 70449ee001..0000000000 --- a/lib/compiler/src/old/dominators.h +++ /dev/null @@ -1,494 +0,0 @@ -#ifndef _DOMINATORS_H -#define _DOMINATORS_H - -#include "basic_graph.h" -#include "graph_structures.h" -#include "tl/optional.hpp" -#include "utils/dot_file.h" -#include "utils/record_formatter.h" -#include -#include -#include -#include - -namespace FlexFlow { -namespace PCG { -namespace Utils { - -template > -std::unordered_set nodes(G const &g) { - 
Structure s; - - return s.get_nodes(g); -} - -template > -bool has_edge(G const &g, - typename Structure::vertex_type const &src, - typename Structure::vertex_type const &dst) { - Structure s; - - for (auto const &e : s.get_outgoing_edges(g, src)) { - if (s.get_dst(g, e) == dst) { - return true; - } - } - - return false; -} - -template > -std::unordered_set - outgoing_edges(G const &g, typename Structure::vertex_type const &n) { - Structure s; - return s.get_outgoing_edges(g, n); -} - -template > -std::pair - get_basic_edge(G const &g, typename Structure::edge_type const &e) { - Structure s; - - return {s.get_src(g, e), s.get_dst(g, e)}; -} - -template > -std::vector get_edges(G const &g) { - using N = typename Structure::vertex_type; - using E = typename Structure::edge_type; - - Structure s; - - std::vector edges; - - for (N const &n : s.get_nodes(g)) { - for (E const &e : s.get_outgoing_edges(g, n)) { - edges.push_back(e); - } - } - - return edges; -} - -template > -void successors(G const &g, - typename Structure::vertex_type const &node, - std::unordered_set *succ) { - Structure s; - for (auto const &edge : s.get_outgoing_edges(g, node)) { - succ->insert(s.get_dst(g, edge)); - } -} - -template > -std::unordered_set - successors(G const &g, typename Structure::vertex_type const &node) { - // using N = typename Structure::vertex_type; - - std::unordered_set succ; - successors(g, node, &succ); - - return succ; -} - -template > -tl::optional - successor(G const &g, typename Structure::vertex_type const &node) { - auto succs = successors(g, node); - if (succs.size() == 1) { - return *succs.begin(); - } else { - return tl::nullopt; - } -} - -template > -void predecessors(G const &g, - typename Structure::vertex_type const &node, - std::unordered_set *pred) { - Structure s; - for (auto const &edge : s.get_incoming_edges(g, node)) { - pred->insert(s.get_src(g, edge)); - } -} - -template > -std::unordered_set - predecessors(G const &g, typename Structure::vertex_type 
const &node) { - // using N = typename Structure::vertex_type; - - std::unordered_set pred; - predecessors(g, node, &pred); - - return pred; -} - -template > -tl::optional - predecessor(G const &g, typename Structure::vertex_type const &node) { - auto preds = predecessors(g, node); - if (preds.size() == 1) { - return *preds.begin(); - } else { - return tl::nullopt; - } -} - -template > -std::unordered_set roots(G const &g) { - using N = typename Structure::vertex_type; - - Structure s; - - std::unordered_set nodes = s.get_nodes(g); - std::unordered_set roots; - for (auto const &node : nodes) { - if (s.get_incoming_edges(g, node).empty()) { - roots.insert(node); - } - } - - return roots; -} - -template > -std::unordered_set leaves(G const &g) { - return roots>(g); -} - -template > -void topo_sort(G const &g, - std::vector *ordering) { - using N = typename Structure::vertex_type; - - Structure s; - std::unordered_map> predecessors; - - std::queue q; - for (auto const &node : s.get_nodes(g)) { - predecessors[node]; - for (auto const &edge : s.get_incoming_edges(g, node)) { - predecessors.at(node).insert(s.get_src(g, edge)); - } - } - - for (auto it = predecessors.begin(); it != predecessors.end();) { - if (it->second.empty()) { - q.push(it->first); - it = predecessors.erase(it); - } else { - it++; - } - } - - std::unordered_set node_successors; - while (!q.empty()) { - N const ¤t = q.front(); - - ordering->push_back(current); - - node_successors.clear(); - successors(g, current, &node_successors); - for (auto const &succ : node_successors) { - if (predecessors.find(succ) != predecessors.end()) { - predecessors.at(succ).erase(current); - if (predecessors.at(succ).empty()) { - predecessors.erase(succ); - q.push(succ); - } - } - } - - q.pop(); - } -} - -template > -std::unordered_map> - dominators(G const &g) { - using N = typename Structure::vertex_type; - // using E = typename Structure::edge_type; - - // Structure s; - - std::vector nodes; - topo_sort(g, &nodes); - 
std::unordered_map> dom; - - std::unordered_set pred_part; - for (auto const &node : nodes) { - pred_part.clear(); - predecessors(g, node, &pred_part); - for (auto const &p : pred_part) { - if (dom.find(node) == dom.end()) { - dom[node] = dom.at(p); - } else { - auto &node_dom_set = dom.at(node); - auto const &p_dom_set = dom.at(p); - for (auto it = node_dom_set.begin(); it != node_dom_set.end();) { - if (p_dom_set.find(*it) == p_dom_set.end()) { - it = node_dom_set.erase(it); - } else { - it++; - } - } - } - } - dom[node].insert(node); - } - - return dom; -} - -template > -std::unordered_map> - post_dominators(G const &g) { - return dominators>(g); -} - -template > -std::unordered_map - imm_dominators(G const &g) { - using N = typename Structure::vertex_type; - // using E = typename Structure::edge_type; - - std::vector topo; - topo_sort(g, &topo); - std::unordered_map topo_rank; - for (int i = 0; i < (int)topo.size(); i++) { - topo_rank[topo[i]] = i; - } - std::unordered_map> dom = - dominators(g); - - std::unordered_map imm_dom; - for (auto const &kv : dom) { - N const &n = kv.first; - std::unordered_set const &n_doms = kv.second; - - // if a node is only dominated by itself, set the dominator to itself to - // signify that it has no immediate dominator - if (n_doms.size() == 1) { - imm_dom[n] = n; - continue; - } - - N const *n_imm_dom = nullptr; - int current_topo_rank = std::numeric_limits::min(); - for (auto const &d : n_doms) { - if (topo_rank.at(d) > current_topo_rank && d != n) { - n_imm_dom = &d; - current_topo_rank = topo_rank.at(d); - } - } - imm_dom[n] = *n_imm_dom; - } - - return imm_dom; -} - -template > -void dfs(G const &g, - typename Structure::vertex_type const &n, - std::function const - &visitor) { - using N = typename Structure::vertex_type; - using E = typename Structure::edge_type; - - Structure s; - - /* auto i_visitor = std::bind(visitor, g, s, n); */ - auto i_visitor = [&](N const &nn) { return visitor(g, s, n, nn); }; - - std::queue q; 
- std::unordered_set visited; - - auto visit = [&](N const &n) { - if (visited.find(n) == visited.end()) { - q.push(n); - visited.insert(n); - } - }; - - visit(n); - - while (!q.empty()) { - N current = q.front(); - q.pop(); - - i_visitor(current); - - for (E const &edge : s.get_outgoing_edges(g, current)) { - N const &dst = s.get_dst(g, edge); - visit(dst); - } - } - - return; -} - -template > -std::unordered_set - descendants(G const &g, typename Structure::vertex_type const &n) { - using N = typename Structure::vertex_type; - using E = typename Structure::edge_type; - - std::unordered_set descendants; - - auto dfs_visitor = [&](G const &gg, - Structure const &ss, - N const &dfs_src, - N const ¤t_node) { - descendants.insert(current_node); - }; - - dfs(g, n, dfs_visitor); - - return descendants; -} - -template > -std::vector> - weakly_connected_components(G const &g) { - using N = typename Structure::vertex_type; - using E = typename Structure::edge_type; - - std::vector> result; - std::unordered_set seen; - - for (N const &n : nodes>(g)) { - if (seen.find(n) != seen.end()) { - continue; - } - - auto component = descendants>(g, n); - seen.insert(component.begin(), component.end()); - result.emplace_back(component); - } - - return result; -} - -template > -std::unordered_map - imm_post_dominators(G const &g) { - return imm_dominators>(g); -} - -template > -BasicGraph transitive_reduction(G const &g) { - using N = typename Structure::vertex_type; - using E = typename Structure::edge_type; - - Structure s; - BasicGraph reduction; - - std::unordered_set nodes = s.get_nodes(g); - - reduction.add_nodes(nodes); - - std::unordered_set> to_delete; - - auto dfs_visitor = [&](N const &src, - G const &gg, - Structure const &ss, - N const &dfs_src, - N const &nn) { - if (nn != dfs_src && to_delete.find({src, nn}) == to_delete.end() && - has_edge(gg, src, nn)) { - to_delete.insert({src, nn}); - } - }; - - for (N const &n : nodes) { - /* auto n_dfs_visitor = 
std::bind(dfs_visitor, n); */ - auto n_dfs_visitor = - [&](G const &gg, Structure const &ss, N const &dfs_src, N const &nn) { - return dfs_visitor(n, gg, ss, dfs_src, nn); - }; - - for (N const &child : successors(g, n)) { - dfs(g, child, n_dfs_visitor); - } - } - - for (E const &e : get_edges(g)) { - std::pair basic_edge = get_basic_edge(g, e); - - if (to_delete.find(basic_edge) == to_delete.end()) { - reduction.add_edge(basic_edge); - } - } - - return reduction; -} - -template -void inplace_transitive_reduction(BasicGraph &g) { - using Structure = GraphStructure>; - using G = BasicGraph; - using E = std::pair; - - std::unordered_set to_delete; - - auto dfs_visitor = [&](N const &src, - G const &gg, - Structure const &ss, - N const &dfs_src, - N const &nn) { - if (nn != dfs_src && to_delete.find({src, nn}) == to_delete.end() && - has_edge(gg, src, nn)) { - to_delete.insert({src, nn}); - } - }; - - for (N const &n : g.nodes) { - auto n_dfs_visitor = - [&](G const &gg, Structure const &ss, N const &dfs_src, N const &nn) { - return dfs_visitor(n, gg, ss, dfs_src, nn); - }; - - for (N const &child : successors(g, n)) { - dfs(g, child, n_dfs_visitor); - } - } - - for (E const &e : to_delete) { - g.remove_edge(e); - } -}; - -template > -void export_as_dot( - DotFile &dotfile, - G const &g, - std::function const - &pretty) { - using N = typename Structure::vertex_type; - using E = typename Structure::edge_type; - - GraphStructure s; - - for (N const &n : s.get_nodes(g)) { - dotfile.add_record_node(n, pretty(n)); - - for (E const &edge : s.get_incoming_edges(g, n)) { - dotfile.add_edge(s.get_src(g, edge), s.get_dst(g, edge)); - } - } - - dotfile.close(); -} - -} // namespace Utils -} // namespace PCG -} // namespace FlexFlow - -#endif // _DOMINATORS_H diff --git a/lib/compiler/src/old/graph.cc b/lib/compiler/src/old/graph.cc deleted file mode 100644 index 191b1028b7..0000000000 --- a/lib/compiler/src/old/graph.cc +++ /dev/null @@ -1,1255 +0,0 @@ -/* Copyright 2023 CMU, 
Facebook, LANL, MIT, NVIDIA, and Stanford (alphabetical) - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "graph.h" -#include "dominators.h" -#include "op-attrs/op-attrs.h" -#include "utils/disjoint_set.h" -#include "utils/unique.h" -#include - -// using FlexFlow::utils::Node; -// using FlexFlow::opmeta::OperatorParameters; - -namespace FlexFlow { - -ParallelComputationGraph::Graph(std::string const &logger_name) - : Graph(spdlog::get(logger_name)) {} - -ParallelComputationGraph::Graph(std::shared_ptr const &logger) - : logger(logger) {} - -Graph::Graph(utils::AdjacencyMultiDiGraph const &g, - utils::bidict const &nodeMap, - std::shared_ptr const &logger) - : g(g), nodeMap(nodeMap), logger(logger) {} - -/* using namespace Legion; */ -/* using FlexFlow::MachineView; */ - -/* LegionRuntime::Logger::Category log_graph("graph"); */ -/* LegionRuntime::Logger::Category log_simplify("graph_simplify"); */ - -void Graph::add_edge(Node const &srcOp, - Node const &dstOp, - int srcIdx, - int dstIdx) { - this->g.add_edge({srcOp, dstOp, (std::size_t)srcIdx, (std::size_t)dstIdx}); -} - -Node Graph::add_node(PCGOperatorAttrs const ¶ms) { - Node n = this->g.add_node(); - this->nodeMap.equate(n, params); - return n; -} - -void Graph::add_edge(utils::MultiDiEdge const &e) { - this->g.add_edge(e); -} - -void Graph::remove_edge(utils::MultiDiEdge const &e, - bool remove_node_if_unused) { - this->g.remove_edge(e); - utils::remove_node_if_unused(this->g, e.src); - 
utils::remove_node_if_unused(this->g, e.dst); -} - -bool Graph::has_edge(utils::MultiDiEdge const &e) const { - return utils::contains_edge(this->g, e); -} - -void Graph::print_dot() const { - this->print_dot(std::cout); -} - -void Graph::print_dot(std::ostream &s) const { - auto directed = unsafe_view_as_digraph(this->g); - - DotFile dot(s); - - export_as_dot(dot, directed, [&](utils::Node const &node) -> RecordFormatter { - RecordFormatter rf; - rf << node.to_string(); - tl::optional sub_rf = as_dot(this->nodeMap.at_l(node)); - if (sub_rf.has_value()) { - rf << sub_rf.value(); - } - - return rf; - }); - s << std::endl; -} - -bool Graph::has_loop() { - return !utils::is_acyclic(this->g).value_or(true); -} - -/* Node Graph::find_bottleneck_node(Node const &sink_node, */ -/* Node const &source_node) const { */ -/* using FlexFlow::PCG::Utils::GraphStructure; */ -/* using FlexFlow::PCG::Utils::imm_post_dominators; */ -/* using FlexFlow::PCG::Utils::MultisourceGraphStructure; */ -/* using FlexFlow::PCG::Utils::roots; */ - -/* Node source(source_node); */ -/* std::unordered_map ipd; */ -/* std::unordered_set graph_roots = roots(*this); */ -/* if (source_node != Node::INVALID_NODE) { */ -/* ipd = imm_post_dominators(*this); */ -/* } else if (graph_roots.size() == 1) { */ -/* ipd = imm_post_dominators(*this); */ -/* source = *graph_roots.begin(); */ -/* } else { */ -/* ipd = imm_post_dominators>(*this); */ -/* } */ - -/* Node bn_node = ipd.at(source); */ -/* if (bn_node == source || bn_node == sink_node) { */ -/* return Node::INVALID_NODE; */ -/* } */ - -/* return bn_node; */ -/* } */ - -Graph Graph::subgraph(std::unordered_set const &nodes) const { - AdjacencyMultiDiGraph sub_g = subgraph(this->g, nodes); - - bidict sub_nodeMap; - for (auto const &kv : this->nodeMap) { - if (contains(nodes, kv.first)) { - sub_nodeMap.equate(kv.first, kv.second); - } - } - - return {sub_g, sub_nodeMap, this->logger}; -} - -void Graph::remove_node(Node const &node, bool purge_edges) { - 
assert(purge_edges == true); - utils::remove_node(this->g, node); - this->nodeMap.erase_l(node); -} - -/*static*/ -Graph Graph::singleton(PCGOperatorAttrs const ¶ms) { - Graph g; - g.add_node(params); - return g; -} - -bool Graph::empty() const { - return utils::empty(this->g); -} - -void Graph::replace_subgraph(std::unordered_set const ¤tNodes, - Graph const &replaceWith) { - assert(currentNodes.size() > 0); - if (replaceWith.empty()) { - Graph subgraph = this->subgraph(currentNodes); - assert(!subgraph.empty()); - Node source_node = subgraph.find_source_node(); - Node noop = - this->model->get_or_create_noop_node(source_node.ptr->inputs[0]); - this->replace_subgraph_with_nonempty(currentNodes, - Graph::singleton(this->model, noop)); - this->contract_out_node(noop); - } else { - this->replace_subgraph_with_nonempty(currentNodes, replaceWith); - } -} - -void Graph::replace_subgraph_with_nonempty( - std::unordered_set const ¤tNodes, Graph const &replaceWith) { - using FlexFlow::PCG::Utils::get_edges; - using FlexFlow::PCG::Utils::nodes; - - Node new_sink_node = replaceWith.find_sink_node(); - - Graph old_subgraph = this->subgraph(currentNodes); - Node old_sink_node = old_subgraph.find_sink_node(); - Node old_source_node = old_subgraph.find_source_node(); - - std::unordered_set all_nodes = nodes(*this); - - for (Edge const &old_inner_edge : get_edges(old_subgraph)) { - this->remove_edge(old_inner_edge, false); - } - for (Edge const &new_inner_edge : get_edges(replaceWith)) { - this->add_edge(new_inner_edge); - } - - std::unordered_set old_in_edges = this->inEdges.at(old_source_node); - if (!old_in_edges.empty()) { - Node new_source_node = replaceWith.find_source_node(); - for (Edge const &old_in_edge : old_in_edges) { - Edge new_in_edge(old_in_edge); - new_in_edge.dstOp = new_source_node; - this->remove_edge(old_in_edge, false); - this->add_edge(new_in_edge); - } - } - - std::unordered_set old_out_edges = this->outEdges.at(old_sink_node); - for (Edge const 
&old_out_edge : old_out_edges) { - Edge new_out_edge(old_out_edge); - new_out_edge.srcOp = new_sink_node; - this->remove_edge(old_out_edge, false); - this->add_edge(new_out_edge); - } - - for (Node const &node : currentNodes) { - this->remove_node(node); - } - - assert(this->check_correctness()); -} - -void Graph::contract_out_node(Node const &node) { - contract_node(this->g, node); - this->nodeMap.erase_l(node); -} - -/* std::pair, std::unique_ptr> */ -/* Graph::split_at_node(Node const &bottleneck) const { */ -/* using FlexFlow::PCGe:Utils::topo_sort; */ - -/* auto first_graph = std::unique_ptr(new Graph(this->model)); */ -/* auto second_graph = std::unique_ptr(new Graph(this->model)); */ - -/* std::unordered_set used_nodes; */ -/* { */ -/* std::vector topo_sorted; */ -/* topo_sort(*this, &topo_sorted); */ - -/* for (auto const &node : topo_sorted) { */ -/* if (node == bottleneck) { */ -/* break; */ -/* } */ - -/* used_nodes.insert(node); */ -/* } */ -/* used_nodes.insert(bottleneck); */ - -/* assert(used_nodes.size() < topo_sorted.size()); */ -/* } */ - -/* for (auto const &it : this->inEdges) { */ -/* auto const &inList = it.second; */ -/* if (used_nodes.find(it.first) != used_nodes.end()) { */ -/* // Add all in-edges of used_nodes in to the first_graph */ -/* for (auto const &it2 : inList) { */ -/* first_graph->add_edge(it2); */ -/* } */ -/* } else { */ -/* // Add all in-edges of not_used_nodes into the second_graph */ -/* for (auto const &it2 : inList) { */ -/* second_graph->add_edge(it2); */ -/* } */ -/* } */ -/* } */ - -/* return {std::move(first_graph), std::move(second_graph)}; */ -/* } */ - -void Graph::remove_input_nodes() { - using FlexFlow::PCG::Utils::nodes; - - for (auto const &n : nodes(*this)) { - if (n.ptr->op_type == OP_INPUT) { - this->remove_node(n, true /*purge_edges*/); - } - } -} - -Node Graph::clone_node(Node const &n) { - Node cloned = n; - cloned.original_guid = n.guid; - cloned.guid = this->model->node_global_guid++; - 
this->add_node(cloned); - return cloned; -} - -Node Graph::declone_node(Node const &n) { - assert(n.original_guid.has_value()); - Node decloned = n; - decloned.guid = n.original_guid.value(); - decloned.original_guid = tl::nullopt; - this->add_node(decloned); - return decloned; -} - -std::pair> - Graph::deduplicate_input_node(Node const &n) { - using FlexFlow::PCG::Utils::nodes; - using FlexFlow::PCG::Utils::outgoing_edges; - - assert(n.original_guid.has_value()); - std::unordered_set old_all_nodes = nodes(*this); - Node decloned = this->declone_node(n); - - std::unordered_set old_nodes; - std::unordered_set new_edges; - for (Node const &nn : old_all_nodes) { - if (nn.original_guid == n.original_guid) { - old_nodes.insert(nn); - for (Edge const &e : outgoing_edges(*this, nn)) { - Edge decloned_edge(e); - decloned_edge.replace_node(nn, decloned); - new_edges.insert(decloned_edge); - } - this->remove_node(nn, true /*purge_edges*/); - } - } - - for (Edge const &e : new_edges) { - this->add_edge(e); - } - - return {decloned, old_nodes}; -} - -std::unordered_map Graph::deduplicate_input_nodes() { - using FlexFlow::PCG::Utils::nodes; - - std::unordered_map deduplication_map; - - bool done; - while (true) { - done = true; - for (Node const &n : nodes(*this)) { - if (n.original_guid.has_value()) { - done = false; - auto kv = this->deduplicate_input_node(n); - for (auto const &r : kv.second) { - deduplication_map[r] = kv.first; - } - break; - } - } - if (done) { - break; - } - } - - return deduplication_map; -} - -void Graph::duplicate_input_node(Node const &n) { - using FlexFlow::PCG::Utils::outgoing_edges; - using FlexFlow::PCG::Utils::successors; - - assert(n.ptr->op_type == OP_INPUT); - - std::unordered_map clones; - - for (auto const &s : successors(*this, n)) { - clones[s] = this->clone_node(n); - } - - for (auto const &e : outgoing_edges(*this, n)) { - Edge cloned(e); - cloned.srcOp = clones.at(e.dstOp); - this->add_edge(cloned); - } - this->remove_node(n, true 
/*purge_edges*/); -} - -void Graph::duplicate_input_nodes() { - using FlexFlow::PCG::Utils::nodes; - - for (auto const &n : nodes(*this)) { - if (n.ptr->op_type == OP_INPUT) { - this->duplicate_input_node(n); - } - } -} - -std::pair, std::unique_ptr> - Graph::split_horizontal(Node const &source_node, - Node const &sink_node) const { - using FlexFlow::PCG::Utils::weakly_connected_components; - - Graph trimmed_graph(*this); - assert(sink_node != - Node::INVALID_NODE); // sink node should never be invalid node - if (source_node != Node::INVALID_NODE) { - trimmed_graph.remove_node(source_node, true /*purge_edges*/); - } - trimmed_graph.remove_node(sink_node, true /*purge_edges*/); - std::vector> wccs = - weakly_connected_components(trimmed_graph); - assert(wccs.size() >= 2); - std::unordered_set first_branch = wccs.back(); - wccs.pop_back(); - std::unordered_set rest; - for (auto const &wcc : wccs) { - rest.insert(wcc.begin(), wcc.end()); - } - if (source_node != Node::INVALID_NODE) { - first_branch.insert(source_node); - rest.insert(source_node); - } - first_branch.insert(sink_node); - rest.insert(sink_node); - - auto first_graph = - std::unique_ptr(new Graph(this->subgraph(first_branch))); - auto second_graph = std::unique_ptr(new Graph(this->subgraph(rest))); - - return {std::move(first_graph), std::move(second_graph)}; -} - -GraphCostResult GraphCostResult::invalid() { - return {std::numeric_limits::infinity(), {}}; -} - -bool GraphCostResult::operator<(GraphCostResult const &other) const { - return this->cost < other.cost; -} - -std::ostream &operator<<(std::ostream &s, GraphCostResult const &r) { - s << "GraphCostResult{cost=" << r.cost << "}"; - return s; -} - -std::ostream &operator<<(std::ostream &s, GraphOptimizeResult const &r) { - s << "GraphOptimizeResult{cost=" << r.cost << "}"; - return s; -} - -template <> -GraphCostResult sequence_cost(GraphCostResult const &first, - GraphCostResult const &second) { - GraphCostResult result(first); - result.cost += 
second.cost; - result.views.insert(second.views.cbegin(), second.views.cend()); - return result; -} - -template <> -float sequence_cost(float const &first, float const &second) { - return first + second; -} - -template <> -GraphOptimizeResult - sequence_cost(GraphOptimizeResult const &first, - GraphOptimizeResult const &second) { - GraphOptimizeResult result; - result.cost = first.cost + second.cost; - result.views.insert(first.views.cbegin(), first.views.cend()); - result.views.insert(second.views.cbegin(), second.views.cend()); - - result.graph = second.graph; - Node second_src = result.graph.value().find_source_node(); - result.graph.value().replace_subgraph({second_src}, first.graph.value()); - return result; -} - -template <> -GraphCostResult parallel_cost(GraphCostResult const &first, - GraphCostResult const &second) { - GraphCostResult result; - result.cost = std::max(first.cost, second.cost); - result.views.insert(first.views.cbegin(), first.views.cend()); - result.views.insert(second.views.cbegin(), second.views.cend()); - - return result; -} - -template <> -float parallel_cost(float const &first, float const &second) { - return std::max(first, second); -} - -float Graph::optimal_cost() const { - return this->generic_optimal_cost(); -} - -std::unordered_map Graph::optimal_views() const { - return this->generic_optimal_cost().views; -} - -Graph Graph::reduced() const { - using FlexFlow::PCG::Utils::BasicGraph; - using FlexFlow::PCG::Utils::get_edges; - using FlexFlow::PCG::Utils::transitive_reduction; - - BasicGraph transitive_skeleton = transitive_reduction(*this); - - Graph reduced_graph(this->model); - - for (Edge const &e : get_edges(*this)) { - if (transitive_skeleton.has_edge(e.srcOp, e.dstOp)) { - reduced_graph.add_edge(e); - } - } - - return reduced_graph; -} - -/** - * @brief A generic cost function for a graph capable of finding both the cost - * and the optimal views - * - * @note A templated function is used here because while the caching 
behaviors - * of the cost and the optimal views are different, much of the code between the - * two versions is almost identical. By using a few template specializations we - * can avoid duplicating all this code. - * - * @tparam T the result type (can be either float or GraphCostResult) - * @return T the cost of the graph (along with any additional data in the return - * type) - */ -template -T Graph::generic_optimal_cost() const { - using FlexFlow::PCG::Utils::GraphStructure; - - Graph reduced_graph = this->reduced(); - // GraphStructure s; - // if (source_node.ptr->op_type == OP_INPUT) { - // for (auto const &e : s.get_outgoing_edges(reduced_graph, source_node)) { - // reduced_graph.remove_edge(e, false/*remove_node_if_unused*/); - // } - // reduced_graph.remove_node(source_node); - // } - - Node sink_node = reduced_graph.find_sink_node(); - this->search->logger->info() << "Found sink node: " << sink_node.to_string(); - - MachineResource resource(model->config); - - std::vector valid_views = - search->get_valid_machine_views(sink_node, resource, true); - - T optimal = search->infinity(); - - this->search->logger->info() - << "Exploring " << valid_views.size() << " valid views"; - for (MachineView const &sink_view : valid_views) { - this->search->logger->info() << " Exploring valid view " << sink_view; - T new_cost = - search->graph_cost(&reduced_graph, - {Node::INVALID_NODE, MachineView::NO_VIEW}, - {sink_node, sink_view}, - resource, - true); - if (new_cost < optimal) { - optimal = new_cost; - } - } - - return optimal; -} - -size_t Graph::hash(void) const { - // Graph hash should be additive and independent to the ordering of the nodes - size_t total_hash = 0; - for (auto const &it : inEdges) { - auto const &inList = it.second; - size_t node_hash = std::hash()((size_t)it.first.ptr); - for (auto const &e : inList) { - size_t edge_hash = 17; - edge_hash = edge_hash * 31 + std::hash()((size_t)e.srcOp.ptr); - edge_hash = edge_hash * 31 + std::hash()(e.srcIdx); - 
edge_hash = edge_hash * 31 + std::hash()(e.dstIdx); - node_hash *= edge_hash; - } - total_hash += node_hash; - } - return total_hash; -} - -size_t dp_state_hash(Graph const *graph, - Node const &sink_node, - MachineView const &sink_view, - Node const &source_node, - MachineView const &source_view, - MachineResource const &resource) { - size_t key = graph->hash(); - hash_combine(key, sink_node.ptr); - hash_combine(key, sink_view.hash()); - hash_combine(key, source_node.ptr); - hash_combine(key, resource.hash()); - return key; -} - -GraphOptimalViewSerialized - Graph::graph_optimize_task(Task const *task, - std::vector const ®ions, - Context ctx, - Runtime *runtime) { - FFModel *model = *((FFModel **)task->args); - if (model->config.search_num_nodes.has_value()) { - model->config.numNodes = model->config.search_num_nodes.value(); - } - if (model->config.search_num_workers.has_value()) { - model->config.workersPerNode = model->config.search_num_workers.value(); - } - model->all_valid_views.clear(); - model->register_all_machine_views(model->config.numNodes, - model->config.workersPerNode, - model->config.cpusPerNode, - model->all_valid_views); - Memory gpu_mem = Machine::MemoryQuery(Machine::get_machine()) - .only_kind(Memory::GPU_FB_MEM) - .best_affinity_to(task->target_proc) - .first(); - MachineModel *machine; - if (model->config.machine_model_version == 0) { - machine = - (MachineModel *)new SimpleMachineModel(model->config.numNodes, - model->config.workersPerNode, - gpu_mem.capacity()); - } else if (model->config.machine_model_version == 1 and - !model->config.machine_model_file.empty()) { - machine = (MachineModel *)new EnhancedMachineModel( - model->config.machine_model_file, gpu_mem.capacity()); - } else { - assert(false && - "machine model creation error: currently only support " - "machine-model-version = 0 or 1. 
When machine-model-version = 1, " - "machine-model-file should not be empty."); - } - model->simulator = - make_unique(model, model->handlers[0], gpu_mem, machine); - std::unique_ptr best_graph; - std::unordered_map optimal_views; - if (model->config.only_data_parallel) { - Graph *graph = new Graph(model); - std::unordered_map op_to_node_map; - for (FlexFlow::Op const *dstOp : model->operators) { - Node dstNode; - dstNode.ptr = dstOp; - dstNode.guid = model->node_global_guid++; - op_to_node_map[dstOp] = dstNode; - for (int j = 0; j < dstOp->numInputs; j++) { - FlexFlow::Op const *srcOp = dstOp->inputs[j]->owner_op; - assert(op_to_node_map.find(srcOp) != op_to_node_map.end()); - Node srcNode = op_to_node_map[srcOp]; - graph->add_edge(srcNode, dstNode, dstOp->inputs[j]->owner_idx, j); - } - } - best_graph = std::unique_ptr(graph); - MachineView data_parallel_view; - data_parallel_view.device_type = MachineView::GPU; - data_parallel_view.ndims = 1; - data_parallel_view.dim[0] = - model->config.numNodes * model->config.workersPerNode; - data_parallel_view.stride[0] = 1; - data_parallel_view.start_device_id = 0; - for (auto const &node : best_graph->inEdges) { - optimal_views[node.first] = data_parallel_view; - } - } else { - model->graph_optimize(model->config.search_budget, - model->config.only_data_parallel, - best_graph, - optimal_views); - } - /* Serializer sez; */ - /* // First serialize graph */ - /* sez.serialize(best_graph->inEdges.size()); */ - /* std::unordered_map todos; */ - /* std::vector opList; */ - /* for (auto const &it : best_graph->inEdges) { */ - /* auto const &inList = it.second; */ - /* todos[it.first] = (int)inList.size(); */ - /* if (todos[it.first] == 0) { */ - /* opList.push_back(it.first); */ - /* } */ - /* } */ - /* size_t node_idx = 0; */ - /* while (node_idx < opList.size()) { */ - /* Node cur_node = opList[node_idx++]; */ - /* auto const &outList = best_graph->outEdges[cur_node]; */ - /* for (auto const &e : outList) { */ - /* 
todos[e.dstOp]--; */ - /* if (todos[e.dstOp] == 0) { */ - /* opList.push_back(e.dstOp); */ - /* } */ - /* } */ - /* auto const &inList = best_graph->inEdges[cur_node]; */ - /* sez.serialize(inList.size()); */ - /* for (auto const &e : inList) { */ - /* sez.serialize(e.srcOp.guid); */ - /* assert(e.dstOp.guid == cur_node.guid); */ - /* sez.serialize(e.srcIdx); */ - /* sez.serialize(e.dstIdx); */ - /* } */ - /* sez.serialize((size_t)10101010); // safe guard for the end of inedges */ - /* Op const *op = cur_node.ptr; */ - /* assert(op != NULL); */ - /* sez.serialize(cur_node.guid); */ - /* sez.serialize(op->op_type); */ - /* switch (op->op_type) { */ - /* case OP_INPUT: { */ - /* assert(op->numOutputs == 1); */ - /* NoOp *noop = (NoOp *)op; */ - /* sez.serialize(noop->op_type); */ - /* sez.serialize(noop->input_tensor_guid); */ - /* sez.serialize(noop->outputs[0]->data_type); */ - /* sez.serialize(noop->outputs[0]->num_dims); */ - /* for (int i = 0; i < noop->outputs[0]->num_dims; i++) { */ - /* sez.serialize(noop->outputs[0]->dims[i]); */ - /* } */ - /* break; */ - /* } */ - /* case OP_NOOP: { */ - /* break; */ - /* } */ - /* case OP_CONCAT: { */ - /* Concat *concat = (Concat *)op; */ - /* sez.serialize(concat->legion_axis); */ - /* break; */ - /* } */ - /* case OP_SPLIT: { */ - /* Split *split = (Split *)op; */ - /* sez.serialize(split->legion_axis); */ - /* sez.serialize(split->numOutputs); */ - /* for (int i = 0; i < split->numOutputs; i++) { */ - /* sez.serialize(split->outputs[i]->dims[split->legion_axis].size); */ - /* } */ - /* break; */ - /* } */ - /* case OP_EMBEDDING: { */ - /* Embedding *embed = (Embedding *)op; */ - /* sez.serialize(embed->layer_guid.id); */ - /* sez.serialize(embed->num_entries); */ - /* sez.serialize(embed->out_channels); */ - /* sez.serialize(embed->aggr); */ - /* sez.serialize(embed->data_type); */ - /* break; */ - /* } */ - /* case OP_EW_ADD: */ - /* case OP_EW_SUB: */ - /* case OP_EW_MUL: */ - /* case OP_EW_MAX: */ - /* case 
OP_EW_MIN: { */ - /* sez.serialize(op->op_type); */ - /* break; */ - /* } */ - /* case OP_MULTIHEAD_ATTENTION: { */ - /* MultiHeadAttention *attn = (MultiHeadAttention *)op; */ - /* sez.serialize(attn->layer_guid.id); */ - /* sez.serialize(attn->oProjSize); */ - /* sez.serialize(attn->num_heads); */ - /* sez.serialize(attn->qProjSize); */ - /* sez.serialize(attn->vProjSize); */ - /* sez.serialize(attn->dropout); */ - /* sez.serialize(attn->bias); */ - /* sez.serialize(attn->add_bias_kv); */ - /* sez.serialize(attn->add_zero_attn); */ - /* break; */ - /* } */ - /* case OP_SOFTMAX: { */ - /* Softmax *softmax = (Softmax *)op; */ - /* sez.serialize(softmax->dim); */ - /* break; */ - /* } */ - /* case OP_REPARTITION: { */ - /* Repartition *repart = (Repartition *)op; */ - /* sez.serialize(repart->repartition_dim); */ - /* sez.serialize(repart->repartition_degree); */ - /* break; */ - /* } */ - /* case OP_REPLICATE: { */ - /* Replicate *replicate = (Replicate *)op; */ - /* sez.serialize(replicate->replicate_dim); */ - /* sez.serialize(replicate->replicate_degree); */ - /* break; */ - /* } */ - /* case OP_REDUCTION: { */ - /* Reduction *reduction = (Reduction *)op; */ - /* sez.serialize(reduction->reduction_dim); */ - /* sez.serialize(reduction->reduction_degree); */ - /* break; */ - /* } */ - /* case OP_COMBINE: { */ - /* Combine *combine = (Combine *)op; */ - /* sez.serialize(combine->combine_dim); */ - /* sez.serialize(combine->combine_degree); */ - /* break; */ - /* } */ - /* case OP_FUSED_PARALLEL: { */ - /* FusedParallelOp *fused = (FusedParallelOp *)op; */ - /* sez.serialize(fused->num_parallel_ops); */ - /* for (int i = 0; i < fused->num_parallel_ops; i++) { */ - /* sez.serialize(fused->parallel_ops[i]); */ - /* } */ - /* break; */ - /* } */ - /* default: { */ - /* op->serialize(sez); */ - /* } */ - /* } */ - /* sez.serialize((size_t)12345678); // safe guard for the end of an op */ - /* } */ - /* assert(node_idx == best_graph->inEdges.size()); */ - /* // Second, 
serialize optimal machine view */ - /* printf("opotimal_views.size = %zu\n", optimal_views.size()); */ - /* sez.serialize(optimal_views.size()); */ - /* for (auto const &it : optimal_views) { */ - /* sez.serialize((size_t)98765432); // safe guard */ - /* sez.serialize(it.first.guid); */ - /* sez.serialize(it.second); */ - /* } */ - /* assert(sez.get_used_bytes() < GraphOptimalViewSerialized::buffer_size); */ - /* GraphOptimalViewSerialized ret; */ - /* ret.total_bytes = sez.get_used_bytes(); */ - /* memcpy(ret.data, sez.get_buffer(), ret.total_bytes); */ - /* // Deallocate best_graph */ - /* // delete best_graph; */ - /* return ret; */ -} - -}; // namespace FlexFlow - -namespace FlexFlow { - -using PCG::Edge; -using PCG::Graph; -using PCG::GraphCostResult; -using PCG::Node; - -void FFModel::register_all_machine_views( - int num_nodes, - int gpus_per_node, - int cpus_per_node, - std::vector &valid_views) { - // Single-parallelism-dimension views - for (int i = 1; i <= num_nodes * gpus_per_node; i++) { - if (num_nodes * gpus_per_node % i == 0) { - MachineView view; - view.device_type = MachineView::GPU; - view.ndims = 1; - view.dim[0] = i; - view.stride[0] = 1; - view.start_device_id = 0; - valid_views.push_back(view); - } - } - // Two-dimensional views - /* for (int i = 1; i <= num_nodes; i++) { */ - /* for (int j = 1; j <= gpus_per_node; j++) { */ - /* MachineView view; */ - /* view.device_type = MachineView::GPU; */ - /* view.ndims = 2; */ - /* view.dim[0] = i; */ - /* view.stride[0] = 1; */ - /* view.dim[1] = j; */ - /* view.stride[1] = 1; */ - /* view.start_device_id = 0; */ - /* valid_views.push_back(view); */ - /* } */ - /* } */ -} - -float FFModel::graph_cost(Graph const *graph, - Node const &sink_node, - MachineView const &sink_view, - Node const &source_node, - MachineView const &source_view, - MachineResource const &resources, - bool include_sink_compute_time, - bool constructing_optimal_view) { - assert(!graph->inEdges.empty()); - - return 
this->search->graph_cost(graph, - {source_node, source_view}, - {sink_node, sink_view}, - resources, - include_sink_compute_time); -} - -void FFModel::construct_optimal_view( - Graph const *graph, - Node const &sink_node, - MachineView const &sink_view, - Node const &source_node, - MachineView const &source_view, - MachineResource const &resources, - bool include_sink_compute_time, - float optimal_cost, - std::unordered_map &optimal_views) { - GraphCostResult result = - this->search->graph_cost(graph, - {source_node, source_view}, - {sink_node, sink_view}, - resources, - include_sink_compute_time); - - optimal_views.insert(result.views.begin(), result.views.end()); -} - -/* void FFModel::deserialize_graph_optimal_view( */ -/* Legion::Deserializer &dez, */ -/* Graph *graph, */ -/* std::unordered_map &optimal_views) { */ -/* // Deserializer dez(serialized.data, serialized.total_bytes); */ -/* std::unordered_map guid_to_nodes; */ -/* size_t num_nodes; */ -/* dez.deserialize(num_nodes); */ -/* // best_graph = new Graph(this); */ -/* for (size_t node_idx = 0; node_idx < num_nodes; node_idx++) { */ -/* Edge inedges[MAX_NUM_INPUTS]; */ -/* ParallelTensor inputs[MAX_NUM_INPUTS]; */ -/* size_t num_inputs; */ -/* dez.deserialize(num_inputs); */ -/* for (size_t j = 0; j < num_inputs; j++) { */ -/* size_t src_guid; */ -/* int src_idx, dst_idx; */ -/* dez.deserialize(src_guid); */ -/* assert(guid_to_nodes.find(src_guid) != guid_to_nodes.end()); */ -/* dez.deserialize(src_idx); */ -/* dez.deserialize(dst_idx); */ -/* assert(dst_idx < (int)num_inputs); */ -/* inedges[dst_idx].srcOp = guid_to_nodes[src_guid]; */ -/* inedges[dst_idx].srcIdx = src_idx; */ -/* inedges[dst_idx].dstIdx = dst_idx; */ -/* inputs[dst_idx] = inedges[dst_idx].srcOp.ptr->outputs[src_idx]; */ -/* } */ -/* { */ -/* size_t safecode; */ -/* dez.deserialize(safecode); */ -/* assert(safecode == 10101010); */ -/* } */ -/* Node node = Node::INVALID_NODE; */ -/* size_t guid; */ -/* OperatorType op_type; */ -/* 
dez.deserialize(guid); */ -/* dez.deserialize(op_type); */ -/* switch (op_type) { */ -/* case OP_INPUT: { */ -/* assert(num_inputs == 0); */ -/* int num_dims; */ -/* ParallelDim dims[MAX_TENSOR_DIM]; */ -/* OperatorType op_type; */ -/* dez.deserialize(op_type); */ -/* size_t input_tensor_guid; */ -/* dez.deserialize(input_tensor_guid); */ -/* DataType data_type; */ -/* dez.deserialize(data_type); */ -/* dez.deserialize(num_dims); */ -/* for (int i = 0; i < num_dims; i++) { */ -/* dez.deserialize(dims[i]); */ -/* } */ -/* ParallelTensor t = */ -/* create_parallel_tensor_legion_ordering(num_dims, */ -/* dims, */ -/* data_type, */ -/* nullptr, */ -/* 0, */ -/* true create_grad, */ -/* input_tensor_guid); */ -/* node.ptr = t->owner_op; */ -/* node.guid = node_global_guid++; */ -/* break; */ -/* } */ -/* case OP_NOOP: { */ -/* assert(num_inputs == 1); */ -/* node = get_or_create_noop_node(inputs[0]); */ -/* break; */ -/* } */ -/* case OP_BATCHMATMUL: { */ -/* node = BatchMatmul::deserialize(*this, dez, inputs, num_inputs); */ -/* break; */ -/* } */ -/* case OP_CAST: { */ -/* node = Cast::deserialize(*this, dez, inputs, num_inputs); */ -/* break; */ -/* } */ -/* case OP_CONCAT: { */ -/* int legion_axis; */ -/* dez.deserialize(legion_axis); */ -/* node = get_or_create_node( */ -/* {std::begin(inputs), std::begin(inputs) + num_inputs}, */ -/* {legion_axis}); */ -/* break; */ -/* } */ -/* case OP_SPLIT: { */ -/* int legion_axis; */ -/* dez.deserialize(legion_axis); */ -/* int num_outputs; */ -/* dez.deserialize(num_outputs); */ -/* std::vector splits; */ -/* for (int i = 0; i < num_outputs; i++) { */ -/* int dim_size; */ -/* dez.deserialize(dim_size); */ -/* splits.push_back(dim_size); */ -/* } */ -/* node = get_or_create_node(inputs[0], {splits, legion_axis}); - */ -/* break; */ -/* } */ -/* case OP_EMBEDDING: { */ -/* assert(num_inputs == 1); */ -/* AggrMode aggr; */ -/* int num_entries, out_channels; */ -/* size_t id; */ -/* DataType data_type; */ -/* 
dez.deserialize(id); */ -/* LayerID layer_guid(id); */ -/* dez.deserialize(num_entries); */ -/* dez.deserialize(out_channels); */ -/* dez.deserialize(aggr); */ -/* dez.deserialize(data_type); */ - -/* EmbeddingParams params; */ -/* params.aggr = aggr; */ -/* params.num_entries = num_entries; */ -/* params.out_channels = out_channels; */ -/* params.layer_guid = layer_guid; */ -/* params.data_type = data_type; */ -/* node = get_or_create_node(inputs[0], params); */ -/* break; */ -/* } */ -/* case OP_EW_ADD: */ -/* case OP_EW_SUB: */ -/* case OP_EW_MUL: */ -/* case OP_EW_MAX: */ -/* case OP_EW_MIN: { */ -/* assert(num_inputs == 2); */ -/* OperatorType op_type; */ -/* dez.deserialize(op_type); */ -/* node = get_or_create_node({inputs[0], inputs[1]}, */ -/* {op_type}); */ -/* break; */ -/* } */ -/* case OP_CONV2D: { */ -/* node = Conv2D::deserialize(*this, dez, inputs, num_inputs); */ -/* break; */ -/* } */ -/* case OP_DROPOUT: { */ -/* node = Dropout::deserialize(*this, dez, inputs, num_inputs); */ -/* break; */ -/* } */ -/* case OP_EXP: */ -/* case OP_SIN: */ -/* case OP_COS: */ -/* case OP_SCALAR_MULTIPLY: */ -/* case OP_SCALAR_FLOOR_DIV: */ -/* case OP_SCALAR_TRUE_DIV: */ -/* case OP_SCALAR_ADD: */ -/* case OP_SCALAR_SUB: */ -/* case OP_RELU: */ -/* case OP_SIGMOID: */ -/* case OP_TANH: */ -/* case OP_POW: */ -/* case OP_IDENTITY: */ -/* case OP_GELU: */ -/* case OP_ELU: { */ -/* node = ElementUnary::deserialize(*this, dez, inputs, num_inputs); */ -/* break; */ -/* } */ -/* case OP_FLAT: { */ -/* node = Flat::deserialize(*this, dez, inputs, num_inputs); */ -/* break; */ -/* } */ -/* case OP_GATHER: { */ -/* node = Gather::deserialize(*this, dez, inputs, num_inputs); */ -/* break; */ -/* } */ -/* case OP_LAYERNORM: { */ -/* node = LayerNorm::deserialize(*this, dez, inputs, num_inputs); */ -/* break; */ -/* } */ -/* case OP_LINEAR: { */ -/* node = Linear::deserialize(*this, dez, inputs, num_inputs); */ -/* break; */ -/* } */ -/* case OP_MULTIHEAD_ATTENTION: { */ -/* 
assert(num_inputs == 3); */ -/* int embed_dim, num_heads, k_dim, v_dim; */ -/* float dropout; */ -/* bool bias, add_bias_kv, add_zero_attn; */ -/* size_t id; */ -/* dez.deserialize(id); */ -/* LayerID layer_guid(id); */ -/* dez.deserialize(embed_dim); */ -/* dez.deserialize(num_heads); */ -/* dez.deserialize(k_dim); */ -/* dez.deserialize(v_dim); */ -/* dez.deserialize(dropout); */ -/* dez.deserialize(bias); */ -/* dez.deserialize(add_bias_kv); */ -/* dez.deserialize(add_zero_attn); */ - -/* MultiHeadAttentionParams params; */ -/* params.embed_dim = embed_dim; */ -/* params.num_heads = num_heads; */ -/* params.kdim = k_dim; */ -/* params.vdim = v_dim; */ -/* params.dropout = dropout; */ -/* params.bias = bias; */ -/* params.add_bias_kv = add_bias_kv; */ -/* params.add_zero_attn = add_zero_attn; */ -/* params.layer_guid = layer_guid; */ -/* node = get_or_create_node( */ -/* {inputs[0], inputs[1], inputs[2]}, params); */ -/* break; */ -/* } */ -/* case OP_TOPK: { */ -/* node = TopK::deserialize(*this, dez, inputs, num_inputs); */ -/* break; */ -/* } */ -/* case OP_GROUP_BY: { */ -/* node = Group_by::deserialize(*this, dez, inputs, num_inputs); */ -/* break; */ -/* } */ -/* case OP_AGGREGATE: { */ -/* // node = Aggregate::deserialize(*this, dez, inputs, num_inputs); */ -/* int n; */ -/* float lambda_bal; */ -/* dez.deserialize(n); */ -/* dez.deserialize(lambda_bal); */ -/* assert(num_inputs == n + 4); */ -/* AggregateParams params; */ -/* params.n = n; */ -/* params.lambda_bal = lambda_bal; */ -/* node = get_or_create_node( */ -/* {std::begin(inputs), std::begin(inputs) + num_inputs}, params); - */ -/* break; */ -/* } */ -/* case OP_POOL2D: { */ -/* node = Pool2D::deserialize(*this, dez, inputs, num_inputs); */ -/* break; */ -/* } */ -/* case OP_REDUCE_SUM: { */ -/* node = Reduce::deserialize(*this, dez, inputs, num_inputs); */ -/* break; */ -/* } */ -/* case OP_RESHAPE: { */ -/* node = Reshape::deserialize(*this, dez, inputs, num_inputs); */ -/* break; */ -/* } */ 
-/* case OP_SOFTMAX: { */ -/* assert(num_inputs == 1); */ -/* int softmax_dim; */ -/* dez.deserialize(softmax_dim); */ -/* node = get_or_create_node(inputs[0], {softmax_dim}); */ -/* break; */ -/* } */ -/* case OP_TRANSPOSE: { */ -/* node = Transpose::deserialize(*this, dez, inputs, num_inputs); */ -/* break; */ -/* } */ -/* case OP_COMBINE: { */ -/* assert(num_inputs == 1); */ -/* int combine_dim, combine_degree; */ -/* dez.deserialize(combine_dim); */ -/* dez.deserialize(combine_degree); */ -/* node = get_or_create_node(inputs[0], */ -/* {combine_dim, combine_degree}); */ -/* break; */ -/* } */ -/* case OP_REPARTITION: { */ -/* assert(num_inputs == 1); */ -/* int repartition_dim, repartition_degree; */ -/* dez.deserialize(repartition_dim); */ -/* dez.deserialize(repartition_degree); */ -/* node = get_or_create_node( */ -/* inputs[0], {repartition_dim, repartition_degree}); */ -/* break; */ -/* } */ -/* case OP_REPLICATE: { */ -/* assert(num_inputs == 1); */ -/* int replicate_dim, replicate_degree; */ -/* dez.deserialize(replicate_dim); */ -/* dez.deserialize(replicate_degree); */ -/* node = get_or_create_node(inputs[0], */ -/* {replicate_dim, - * replicate_degree}); */ -/* break; */ -/* } */ -/* case OP_REDUCTION: { */ -/* assert(num_inputs == 1); */ -/* int reduction_dim, reduction_degree; */ -/* dez.deserialize(reduction_dim); */ -/* dez.deserialize(reduction_degree); */ -/* node = get_or_create_node(inputs[0], */ -/* {reduction_dim, - * reduction_degree}); */ -/* break; */ -/* } */ -/* case OP_FUSED_PARALLEL: { */ -/* assert(num_inputs == 1); */ -/* std::vector parallel_ops; */ -/* int num_parallel_ops; */ -/* dez.deserialize(num_parallel_ops); */ -/* for (int i = 0; i < num_parallel_ops; i++) { */ -/* ParallelOpInfo info; */ -/* dez.deserialize(info); */ -/* parallel_ops.push_back(info); */ -/* } */ -/* node = get_or_create_node(inputs[0], - * {parallel_ops}); */ -/* break; */ -/* } */ -/* default: { */ -/* fprintf(stderr, */ -/* "The following operator type 
is currently not supported" */ -/* " for graph deserialization: %s\n" */ -/* "Report the issue to the FlexFlow developers\n", */ -/* get_operator_type_name(op_type).c_str()); */ -/* assert(false && "Unsupported operator type"); */ -/* } */ -/* } */ -/* { */ -/* size_t safecode; */ -/* dez.deserialize(safecode); */ -/* assert(safecode == 12345678); */ -/* } */ -/* assert(node.ptr != nullptr); */ -/* guid_to_nodes[guid] = node; */ -/* for (size_t i = 0; i < num_inputs; i++) { */ -/* inedges[i].dstOp = node; */ -/* graph->add_edge(inedges[i]); */ -/* } */ -/* } */ -/* // Second, deserialize optimal machine view */ -/* size_t num_views; */ -/* dez.deserialize(num_views); */ -/* printf("views.size() = %zu\n", num_views); */ -/* for (size_t i = 0; i < num_views; i++) { */ -/* size_t safecode, guid; */ -/* MachineView view; */ -/* dez.deserialize(safecode); */ -/* assert(safecode == 98765432); */ -/* dez.deserialize(guid); */ -/* assert(guid_to_nodes.find(guid) != guid_to_nodes.end()); */ -/* dez.deserialize(view); */ -/* optimal_views[guid_to_nodes[guid]] = view; */ -/* } */ -/* assert(dez.get_remaining_bytes() == 0); */ -/* printf("Deserialized Views...\n"); */ -/* for (auto const &it : optimal_views) { */ -/* printf("node[%zu]: type(%s) view(%d %d %d) ", */ -/* it.first.guid, */ -/* it.first.to_string().c_str(), */ -/* it.second.ndims, */ -/* it.second.dim[0], */ -/* it.second.start_device_id); */ -/* auto const &list = graph->inEdges.at(it.first); */ -/* for (auto const &it2 : list) { */ -/* Edge e = it2; */ -/* printf(" inEdge(node(%zu) idx(%d))", e.srcOp.guid, e.srcIdx); */ -/* } */ -/* printf("\n"); */ -/* } */ -/* } */ - -} // namespace FlexFlow diff --git a/lib/compiler/src/old/graph.h b/lib/compiler/src/old/graph.h deleted file mode 100644 index db313b080d..0000000000 --- a/lib/compiler/src/old/graph.h +++ /dev/null @@ -1,248 +0,0 @@ -/* Copyright 2021 CMU, Facebook, LANL, MIT, and Stanford (alphabetical) - * - * Licensed under the Apache License, Version 2.0 
(the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef _FLEXFLOW_GRAPH_H_ -#define _FLEXFLOW_GRAPH_H_ -#include "basic_graph.h" -/* #include "node.h" */ -#include "graph_structures.h" -#include "op-attrs/op-attrs.h" -#include "pcg/machine_view.h" -#include "utils/bidict.h" -#include "utils/dot_file.h" -#include "utils/graph.h" -#include "utils/graph/serialparallel.h" -#include "utils/recursive_logger.h" -#include -#include - -// extern LegionRuntime::Logger::Category log_dp; - -/* namespace FlexFlow { */ -/* namespace ffc { */ - -/* class SearchHelper; */ - -/* struct GraphOptimalViewSerialized { */ -/* #ifdef LEGION_MAX_RETURN_SIZE */ -/* static const size_t buffer_size = LEGION_MAX_RETURN_SIZE - 8; */ -/* #else */ -/* static const size_t buffer_size = 1024 * 1024 - 8; */ -/* #endif */ -/* size_t total_bytes; */ -/* char data[buffer_size]; */ -/* }; */ - -/* class Graph { */ -/* public: */ -/* Graph() = default; */ -/* Graph(std::string const &logger_name); */ -/* Graph(std::shared_ptr const &logger); */ - -/* void add_edge(utils::Node const &srcOp, utils::Node const &dstOp, int - * srcIdx, int dstIdx); */ -/* utils::Node add_node(opmeta::OperatorParameters const &); */ -/* void add_edge(utils::MultiDiEdge const &e); */ -/* void remove_node(utils::Node const &, bool purge_edges = false); */ -/* void remove_edge(utils::MultiDiEdge const &e, bool remove_node_if_unused = - * true); */ -/* bool has_edge(utils::MultiDiEdge const &e) const; */ -/* void replace_subgraph(std::unordered_set const - * ¤tNodes, 
*/ -/* Graph const &replaceWith); */ -/* Graph subgraph(std::unordered_set const &nodes) const; */ -/* void contract_out_node(opmeta::OperatorParameters const &); */ -/* float optimal_cost() const; */ -/* std::unordered_map optimal_views() - * const; */ -/* void remove_input_nodes(); */ -/* void duplicate_input_node(opmeta::OperatorParameters const &); */ -/* void duplicate_input_nodes(); */ -/* opmeta::OperatorParameters clone_node(opmeta::OperatorParameters const &); - */ -/* std::pair> */ -/* deduplicate_input_node(opmeta::OperatorParameters const &); */ -/* std::unordered_map - * deduplicate_input_nodes(); */ -/* opmeta::OperatorParameters declone_node(opmeta::OperatorParameters const - * &); */ - -/* size_t hash(void) const; */ -/* void print(void) const; */ -/* void print_dot() const; */ -/* void print_dot(std::ostream &) const; */ - -/* bool check_correctness(void); */ -/* bool has_loop(void); */ -/* //bool map_operators_to_layers(std::vector &layers) const; */ -/* //static GraphOptimalViewSerialized */ -/* // graph_optimize_task(Legion::Task const *task, */ -/* // std::vector const - * ®ions, */ -/* // Legion::Context ctx, */ -/* // Legion::Runtime *runtime); */ -/* /1* opmeta::OperatorParameters - * find_bottleneck_node(opmeta::OperatorParameters const &sink_node, *1/ */ -/* /1* opmeta::OperatorParameters const - * &source_node) const; *1/ */ -/* void print_strategy_computation_graph( */ -/* std::unordered_map const - * &strategy) const; */ -/* void export_strategy_computation_graph( */ -/* std::unordered_map const - * &strategy, */ -/* std::string const &out_filename) const; */ -/* void export_strategy_computation_graph( */ -/* std::unordered_map const - * &strategy, */ -/* DotFile &dot) const; */ - -/* /1* std::pair, std::unique_ptr> *1/ */ -/* /1* split_at_node(opmeta::OperatorParameters const &bottleneck) const; - * *1/ */ -/* /1* std::pair, std::unique_ptr> *1/ */ -/* /1* split_horizontal(opmeta::OperatorParameters const &source_node, - * 
opmeta::OperatorParameters const &sink_node) const; *1/ */ - -/* Graph reduced() const; */ - -/* opmeta::OperatorParameters find_sink_node() const; */ -/* opmeta::OperatorParameters find_source_node() const; */ -/* void reshape_output_tensor(opmeta::ParallelTensorShape const &shape); */ -/* std::unique_ptr */ -/* with_output_tensor_reshaped_to(opmeta::ParallelTensorShape const - * &shape) const; */ - -/* static Graph singleton(opmeta::OperatorParameters const &); */ -/* bool empty() const; */ - -/* template */ -/* T generic_optimal_cost() const; */ - -/* private: */ -/* void remove_inverse_parallel_ops(); */ -/* void replace_subgraph_with_nonempty( */ -/* std::unordered_set const ¤tNodes, - * Graph const &replaceWith); */ -/* private: */ -/* Graph(utils::AdjacencyMultiDiGraph const &, utils::bidict const &, std::shared_ptr const - * &); */ - -/* utils::AdjacencyMultiDiGraph g; */ -/* utils::bidict nodeMap; */ -/* std::shared_ptr logger; */ -/* }; */ - -/* struct GraphOptimizeResult { */ -/* tl::optional graph; */ -/* float cost; */ -/* std::unordered_map views; */ - -/* friend std::ostream &operator<<(std::ostream &, GraphOptimizeResult const - * &); */ -/* }; */ - -/* /1* namespace Utils { *1/ */ -/* /1* template <> *1/ */ -/* /1* struct GraphStructure { *1/ */ -/* /1* using G = FlexFlow::PCG::Graph; *1/ */ -/* /1* using graph_type = FlexFlow::PCG::Graph; *1/ */ -/* /1* using vertex_type = FlexFlow::PCG::Node; *1/ */ -/* /1* using edge_type = FlexFlow::PCG::Edge; *1/ */ - -/* /1* std::unordered_set get_nodes(G const &g) const { *1/ */ -/* /1* std::unordered_set nodes; *1/ */ -/* /1* for (auto const &kv : g.inEdges) { *1/ */ -/* /1* nodes.insert(kv.first); *1/ */ -/* /1* } *1/ */ -/* /1* for (auto const &kv : g.outEdges) { *1/ */ -/* /1* nodes.insert(kv.first); *1/ */ -/* /1* } *1/ */ - -/* /1* return nodes; *1/ */ -/* /1* } *1/ */ - -/* /1* std::unordered_set get_incoming_edges(G const &g, *1/ */ -/* /1* vertex_type const &n) - * const { *1/ */ -/* /1* if 
(g.inEdges.find(n) == g.inEdges.end()) { *1/ */ -/* /1* return {}; *1/ */ -/* /1* } else { *1/ */ -/* /1* return {g.inEdges.at(n).begin(), g.inEdges.at(n).end()}; *1/ */ -/* /1* } *1/ */ -/* /1* } *1/ */ - -/* /1* std::unordered_set get_outgoing_edges(G const &g, *1/ */ -/* /1* vertex_type const &n) - * const { *1/ */ -/* /1* if (g.outEdges.find(n) == g.outEdges.end()) { *1/ */ -/* /1* return {}; *1/ */ -/* /1* } else { *1/ */ -/* /1* return {g.outEdges.at(n).begin(), g.outEdges.at(n).end()}; *1/ */ -/* /1* } *1/ */ -/* /1* } *1/ */ - -/* /1* vertex_type get_src(G const &g, edge_type const &e) const { *1/ */ -/* /1* return e.srcOp; *1/ */ -/* /1* } *1/ */ - -/* /1* vertex_type get_dst(G const &g, edge_type const &e) const { *1/ */ -/* /1* return e.dstOp; *1/ */ -/* /1* } *1/ */ - -/* /1* void set_src(G const &g, edge_type &e, vertex_type const &n) const { - * *1/ */ -/* /1* e.srcOp = n; *1/ */ -/* /1* } *1/ */ - -/* /1* void set_dst(G const &g, edge_type &e, vertex_type const &n) const { - * *1/ */ -/* /1* e.dstOp = n; *1/ */ -/* /1* } *1/ */ -/* /1* }; *1/ */ - -/* size_t dp_state_hash(Graph const *graph, */ -/* opmeta::OperatorParameters const &sink_node, */ -/* MachineView const &sink_view, */ -/* opmeta::OperatorParameters const &source_node, */ -/* MachineView const &source_view, */ -/* MachineResource const &resource); */ - -/* // template <> */ -/* // struct invalid_node> { */ -/* // using G = Graph; */ -/* // using Structure = GraphStructure; */ -/* // using vertex_type = typename Structure::vertex_type; */ -/* // */ -/* // vertex_type operator()() const { */ -/* // return vertex_type::INVALID_NODE; */ -/* // } */ -/* // }; */ -/* // */ -/* // template <> */ -/* // struct invalid_node, GraphStructure>> { - */ -/* // Node operator()() const { */ -/* // return Node::INVALID_NODE; */ -/* // } */ -/* // }; */ - -/* /1* } // namespace Utils *1/ */ -/* } // namespace ffc */ -/* } // namespace FlexFlow */ - -#endif diff --git 
a/lib/compiler/src/old/graph_structures.h b/lib/compiler/src/old/graph_structures.h deleted file mode 100644 index 8b921794e1..0000000000 --- a/lib/compiler/src/old/graph_structures.h +++ /dev/null @@ -1,269 +0,0 @@ -#ifndef _GRAPH_STRUCTURES_H -#define _GRAPH_STRUCTURES_H - -#include "basic_graph.h" - -namespace FlexFlow { -namespace PCG { -namespace Utils { - -template -struct ReverseStructure { - using graph_type = typename BaseStructure::graph_type; - using G = graph_type; - using vertex_type = typename BaseStructure::vertex_type; - using edge_type = typename BaseStructure::edge_type; - - std::unordered_set get_nodes(G const &g) const { - return this->base.get_nodes(g); - } - - std::unordered_set get_incoming_edges(G const &g, - vertex_type const &n) const { - return this->base.get_outgoing_edges(g, n); - } - - std::unordered_set get_outgoing_edges(G const &g, - vertex_type const &n) const { - return this->base.get_incoming_edges(g, n); - } - - vertex_type get_src(G const &g, edge_type const &e) const { - return this->base.get_dst(g, e); - } - - vertex_type get_dst(G const &g, edge_type const &e) const { - return this->base.get_src(g, e); - } - - void set_src(G const &g, edge_type &e, vertex_type const &n) const { - this->base.set_dst(g, e, n); - } - - void set_dst(G const &g, edge_type &e, vertex_type const &n) const { - this->base.set_src(g, e, n); - } - - BaseStructure base; -}; - -template -struct UndirectedEdge { - union Edge { - NotReversed not_reversed; - Reversed reversed; - - Edge() {} - }; - - bool is_reversed; - Edge edge; - - UndirectedEdge() {} - - bool operator==(UndirectedEdge const &other) const { - if (other.is_reversed != this->is_reversed) { - return false; - } - if (this->is_reversed) { - return this->edge.reversed == other.edge.reversed; - } else { - return this->edge.not_reversed == other.edge.not_reversed; - } - } -}; - -template > -struct UndirectedStructure { - using graph_type = typename BaseStructure::graph_type; - using vertex_type = 
typename BaseStructure::vertex_type; - using not_reversed_edge_type = typename BaseStructure::edge_type; - using reversed_edge_type = - typename ReverseStructure::edge_type; - using edge_type = UndirectedEdge; - - std::unordered_set get_nodes(G const &g) const { - return this->base.get_nodes(g); - } - - std::unordered_set get_incoming_edges(G const &g, - vertex_type const &n) const { - std::unordered_set incoming; - auto base_edges = this->base.get_incoming_edges(g, n); - auto reversed_edges = this->reversed.get_incoming_edges(g, n); - - for (auto const &e : base_edges) { - edge_type lifted; - lifted.is_reversed = false; - lifted.edge.not_reversed = e; - incoming.insert(lifted); - } - - for (auto const &e : reversed_edges) { - edge_type lifted; - lifted.is_reversed = true; - lifted.edge.reversed = e; - incoming.insert(lifted); - } - - return incoming; - } - - std::unordered_set get_outgoing_edges(G const &g, - vertex_type const &n) const { - std::unordered_set outgoing; - auto base_edges = this->base.get_outgoing_edges(g, n); - auto reversed_edges = this->reversed.get_outgoing_edges(g, n); - - for (auto const &e : base_edges) { - edge_type lifted; - lifted.is_reversed = false; - lifted.edge.not_reversed = e; - outgoing.insert(lifted); - } - - for (auto const &e : reversed_edges) { - edge_type lifted; - lifted.is_reversed = true; - lifted.edge.reversed = e; - outgoing.insert(lifted); - } - - return outgoing; - } - - vertex_type get_src(G const &g, edge_type const &e) const { - if (e.is_reversed) { - return this->reversed.get_src(g, e.edge.reversed); - } else { - return this->base.get_src(g, e.edge.not_reversed); - } - } - - vertex_type get_dst(G const &g, edge_type const &e) const { - if (e.is_reversed) { - return this->reversed.get_dst(g, e.edge.reversed); - } else { - return this->base.get_dst(g, e.edge.not_reversed); - } - } - - void set_src(G const &g, edge_type &e, vertex_type const &n) const { - if (e.is_reversed) { - this->reversed.set_src(g, e.edge.reversed, 
n); - } else { - this->base.set_src(g, e.edge.not_reversed, n); - } - } - - void set_dst(G const &g, edge_type &e, vertex_type const &n) const { - if (e.is_reversed) { - this->reversed.set_src(g, e.edge.reversed, n); - } else { - this->base.set_src(g, e.edge.not_reversed, n); - } - } - - BaseStructure base; - ReverseStructure reversed; -}; - -template > -struct invalid_node; - -template , - typename Invalid = invalid_node> -struct MultisourceGraphStructure { - using graph_type = typename BaseStructure::graph_type; - using vertex_type = typename BaseStructure::vertex_type; - using edge_type = typename BaseStructure::edge_type; - - std::unordered_set get_nodes(G const &g) const { - Invalid invalid; - - std::unordered_set nodes = this->base.get_nodes(g); - nodes.insert(invalid()); - return nodes; - } - - std::unordered_set get_incoming_edges(G const &g, - vertex_type const &n) const { - Invalid invalid; - - if (n == invalid()) { - return {}; - } - - std::unordered_set edges = this->base.get_incoming_edges(g, n); - if (edges.empty()) { - edge_type e; - this->base.set_src(g, e, invalid()); - this->base.set_dst(g, e, n); - return {e}; - } - - return edges; - } - - std::unordered_set get_outgoing_edges(G const &g, - vertex_type const &n) const { - Invalid invalid; - - if (n == invalid()) { - std::unordered_set edges; - for (auto const &node : this->base.get_nodes(g)) { - if (this->base.get_incoming_edges(g, node).empty()) { - edge_type e; - this->base.set_src(g, e, invalid()); - this->base.set_dst(g, e, node); - edges.insert(e); - } - } - return edges; - } - - return this->base.get_outgoing_edges(g, n); - } - - vertex_type get_src(G const &g, edge_type const &e) const { - return this->base.get_src(g, e); - } - - vertex_type get_dst(G const &g, edge_type const &e) const { - return this->base.get_dst(g, e); - } - - void set_src(G const &g, edge_type &e, vertex_type const &n) const { - this->base.set_src(g, e, n); - } - - void set_dst(G const &g, edge_type &e, vertex_type 
const &n) const { - this->base.set_dst(g, e, n); - } - - BaseStructure base; -}; -} // namespace Utils -} // namespace PCG -} // namespace FlexFlow - -namespace std { -using FlexFlow::PCG::Utils::UndirectedEdge; - -template -struct hash> { - size_t operator()(UndirectedEdge const &e) const { - size_t result; - result = std::hash()(e.is_reversed); - if (e.is_reversed) { - hash_combine(result, e.edge.reversed); - } else { - hash_combine(result, e.edge.not_reversed); - } - return result; - } -}; -} // namespace std - -#endif // _GRAPH_STRUCTURES_H diff --git a/lib/compiler/src/old/node.h b/lib/compiler/src/old/node.h deleted file mode 100644 index eb33a39ae7..0000000000 --- a/lib/compiler/src/old/node.h +++ /dev/null @@ -1,47 +0,0 @@ -#ifndef _FLEXFLOW_FFC_NODE_H -#define _FLEXFLOW_FFC_NODE_H - -#include - -#include "op-attrs/op-attrs.h" -#include "tl/optional.hpp" - -namespace FlexFlow { -namespace ffc { - -struct Node { - Node() = delete; - Node(size_t guid, PCGOperatorAttrs const &op_params); - - std::string to_string(void) const; - - using AsTuple = - std::tuple &>; - using AsConstTuple = std::tuple const &>; - - AsTuple as_tuple(); - AsConstTuple as_tuple() const; - -public: - size_t guid; - PCGOperatorAttrs op_params; - tl::optional original_guid = tl::nullopt; -}; - -bool operator==(Node const &, Node const &); -bool operator!=(Node const &, Node const &); -bool operator<(Node const &, Node const &); - -} // namespace ffc -} // namespace FlexFlow - -namespace std { -template <> -struct hash<::FlexFlow::ffc::Node> { - size_t operator()(::FlexFlow::ffc::Node const &n) const; -}; -} // namespace std - -#endif diff --git a/lib/compiler/src/old/parallel_dim_mapping_record.h b/lib/compiler/src/old/parallel_dim_mapping_record.h deleted file mode 100644 index 8e2c265489..0000000000 --- a/lib/compiler/src/old/parallel_dim_mapping_record.h +++ /dev/null @@ -1,4 +0,0 @@ -#ifndef _FLEXFLOW_FFC_PARALLEL_DIM_MAPPING_RECORD_H -#define 
_FLEXFLOW_FFC_PARALLEL_DIM_MAPPING_RECORD_H - -#endif diff --git a/lib/compiler/src/old/search_helper.cc b/lib/compiler/src/old/search_helper.cc deleted file mode 100644 index 2e7eafa5fd..0000000000 --- a/lib/compiler/src/old/search_helper.cc +++ /dev/null @@ -1,525 +0,0 @@ -#include "search_helper.h" - -namespace FlexFlow { -namespace PCG { - -SearchHelper::SearchHelper() { - this->logger = std::unique_ptr(new RecursiveLogger("DP")); -} - -template -T SearchHelper::execute_sequence_split(std::unique_ptr const &pre_graph, - std::unique_ptr const &post_graph, - NodeAssignment const &source, - NodeAssignment const &sink, - MachineResource const &resources, - SequenceSplit const &bn) const { - return sequence_cost( - this->graph_cost(pre_graph.get(), source, bn, resources, true), - this->graph_cost(post_graph.get(), bn, sink, resources, false)); -} - -template -T SearchHelper::find_optimal_sequence_graph_time( - Graph const *g, - Node const &bn_node, - NodeAssignment const &source, - NodeAssignment const &sink, - MachineResource const &resources) const { - std::unique_ptr pre_graph; - std::unique_ptr post_graph; - std::tie(pre_graph, post_graph) = g->split_at_node(bn_node); - - T optimal = this->infinity(); - - std::vector valid_views = - this->get_valid_machine_views(bn_node.op_params, resources); - // A Corner Case: - // If bn_node is a parallel_op and an input to sink_node, - // Add sink_node's view to the list, since sink_node's view - // may not be a valid view for resources, but UniFlow support - // this case since parallel_op does not trigger computation - if (is_parallel_op(bn_node.op_params)) { - bool found = false; - auto const &inList = g->inEdges.find(sink.node)->second; - for (auto const &e : inList) { - if (e.srcOp == bn_node) { - found = true; - break; - } - } - if (found) { - for (int j = 0; j < bn_node.ptr->numOutputs; j++) { - if (!bn_node.ptr->outputs[j]->is_valid_machine_view(sink.view)) { - found = false; - } - } - } - if (found) { - 
valid_views.push_back(sink.view); - } - } - - if (valid_views.empty()) { - return optimal; - } - - float optimal_cost = std::numeric_limits::infinity(); - MachineView best_view; - - for (MachineView const &bn_view : valid_views) { - float cost = this->execute_sequence_split( - pre_graph, post_graph, source, sink, resources, {bn_node, bn_view}); - - if (cost < optimal_cost) { - best_view = bn_view; - optimal_cost = cost; - } - } - - if (optimal_cost != std::numeric_limits::infinity()) { - optimal = this->execute_sequence_split( - pre_graph, post_graph, source, sink, resources, {bn_node, best_view}); - } - - check_matches_graph(g, optimal, sink.node); - - return optimal; -} - -template -T SearchHelper::execute_nonsequence_split( - std::unique_ptr const &first_graph, - std::unique_ptr const &second_graph, - NodeAssignment const &source, - NodeAssignment const &sink, - MachineResource const &resources, - NonsequenceSplit const &split) const { - Graph const *first = first_graph.get(); - Graph const *second = second_graph.get(); - if (split.flip_graphs) { - std::swap(first, second); - } - switch (split.type) { - case SplitType::SEQUENTIAL: - this->logger->debug() << "Exploring sequential nonsequence split"; - return sequence_cost( - this->graph_cost(first, source, sink, resources, false), - this->graph_cost(second, source, sink, resources, false)); - case SplitType::VERTICAL: { - this->logger->debug() << "Exploring vertical nonsequence split (" - << split.param << ", " << split.flip_graphs << ")"; - MachineResource firstRes = resources, secondRes = resources; - firstRes.num_nodes = split.param; - secondRes.num_nodes = resources.num_nodes - split.param; - secondRes.start_gpu_id = - resources.start_gpu_id + resources.all_gpus_per_node * split.param; - - return parallel_cost( - this->graph_cost(first, source, sink, firstRes, false), - this->graph_cost(second, source, sink, secondRes, false)); - } - case SplitType::HORIZONTAL: { - this->logger->debug() << "Exploring 
horizontal nonsequence split (" - << split.param << ", " << split.flip_graphs << ")"; - MachineResource firstRes = resources, secondRes = resources; - firstRes.available_gpus_per_node = split.param; - secondRes.available_gpus_per_node = - resources.available_gpus_per_node - split.param; - secondRes.start_gpu_id = resources.start_gpu_id + split.param; - - return parallel_cost( - this->graph_cost(first, source, sink, firstRes, false), - this->graph_cost(second, source, sink, secondRes, false)); - } - default: - assert(false); - } -} - -template -T SearchHelper::find_optimal_nonsequence_graph_time( - Graph const *g, - NodeAssignment const &source, - NodeAssignment const &sink, - MachineResource const &resources) const { - std::unique_ptr first_graph; - std::unique_ptr second_graph; - std::tie(first_graph, second_graph) = - g->split_horizontal(source.node, sink.node); - - std::vector potential_splits; - - for (int i = 1; i < resources.num_nodes; i++) { - potential_splits.push_back(NonsequenceSplit::vertical(i, false)); - potential_splits.push_back(NonsequenceSplit::vertical(i, true)); - } - for (int i = 1; i < resources.available_gpus_per_node; i++) { - potential_splits.push_back(NonsequenceSplit::horizontal(i, false)); - potential_splits.push_back(NonsequenceSplit::horizontal(i, true)); - } - - NonsequenceSplit best_split = NonsequenceSplit::sequential(); - float best_cost = this->execute_nonsequence_split( - first_graph, second_graph, source, sink, resources, best_split); - for (NonsequenceSplit const &split : potential_splits) { - float cost = this->execute_nonsequence_split( - first_graph, second_graph, source, sink, resources, split); - this->logger->debug() << "Found cost: " << cost; - - if (cost < best_cost) { - best_cost = cost; - best_split = split; - } - } - - switch (best_split.type) { - case SplitType::SEQUENTIAL: - this->logger->debug() << "Best split: SEQUENTIAL"; - break; - case SplitType::VERTICAL: - this->logger->debug() << "Best split: VERTICAL(" << 
best_split.param - << ", " << best_split.flip_graphs << ")"; - break; - case SplitType::HORIZONTAL: - this->logger->debug() << "Best split: HORIZONTAL(" << best_split.param - << ", " << best_split.flip_graphs << ")"; - break; - } - T optimal = this->execute_nonsequence_split( - first_graph, second_graph, source, sink, resources, best_split); - - check_matches_graph(g, optimal, sink.node); - - return optimal; -} - -std::vector SearchHelper::get_valid_machine_views( - Node const &node, MachineResource const &resource, bool log) const { - this->logger->info() << "Getting valid machine views for " - << node.to_string(); - return this->get_valid_machine_views(node.ptr, resource, log); -} - -std::vector SearchHelper::get_valid_machine_views( - Op const *op, MachineResource const &resource, bool log) const { - std::vector const *cached_op_views = NULL; - std::vector valid_views; - - auto const &iter = cached_operator_valid_views.find(op->op_guid); - if (iter != cached_operator_valid_views.end()) { - cached_op_views = iter->second.get(); - } else { - auto to_cache = std::unique_ptr>( - new std::vector()); - if (log) { - this->logger->info() << "Considering a total of " - << this->model->all_valid_views.size() - << " potential valid views"; - } - for (size_t i = 0; i < this->model->all_valid_views.size(); i++) { - bool valid = true; - for (int j = 0; j < op->numOutputs; j++) { - if (!op->outputs[j]->is_valid_machine_view( - this->model->all_valid_views[i])) { - valid = false; - { - MachineView const &view = this->model->all_valid_views[i]; - std::ostringstream oss; - oss << "[" << view.ndims << "]("; - for (int i = 0; i < view.ndims; i++) { - oss << view.dim[i] << "/" << view.stride[i]; - if (i != view.ndims - 1) { - oss << " "; - } - } - oss << ")"; - if (log) { - this->logger->info() << "Rejecting machine view: " << oss.str(); - } - } - break; - } - } - if (valid) { - { - MachineView const &view = this->model->all_valid_views[i]; - std::ostringstream oss; - oss << "[" << 
view.ndims << "]("; - for (int i = 0; i < view.ndims; i++) { - oss << view.dim[i] << "/" << view.stride[i]; - if (i != view.ndims - 1) { - oss << " "; - } - } - oss << ")"; - if (log) { - this->logger->info() << "Accepting machine view: " << oss.str(); - } - } - to_cache->push_back(this->model->all_valid_views[i]); - } - } - cached_operator_valid_views[op->op_guid] = std::move(to_cache); - cached_op_views = cached_operator_valid_views.at(op->op_guid).get(); - } - if (log) { - this->logger->info() << "Found " << cached_op_views->size() - << " cached op views"; - } - for (size_t i = 0; i < cached_op_views->size(); i++) { - MachineView view = (*cached_op_views)[i]; - if (view.device_type == MachineView::GPU) { - view.start_device_id = resource.start_gpu_id; - } else if (view.device_type == MachineView::CPU) { - view.start_device_id = resource.start_cpu_id; - } else { - assert(false); - } - if (resource.is_valid_machine_view(view)) { - valid_views.push_back(view); - } - } - return valid_views; -} - -template <> -bool SearchHelper::is_invalid(float const &cost) const { - return cost == std::numeric_limits::infinity(); -} - -template <> -bool SearchHelper::is_invalid( - GraphCostResult const &cost) const { - return cost.cost == std::numeric_limits::infinity(); -} - -/** - * @brief Asserts that the results of graph optimization are valid for the graph - * - * @param g the graph to check against - * @param r the results to check - * @param sink the sink node of the graph g - * @param include_sink whether or not to include the sink node - */ -template <> -void SearchHelper::check_matches_graph( - Graph const *g, GraphCostResult const &r, Node const &sink) const { - using FlexFlow::PCG::Utils::nodes; - - if (this->is_invalid(r)) { - return; - } - - std::unordered_set g_nodes = nodes(*g); - g_nodes.erase(sink); - - std::unordered_set r_nodes; - for (auto const &kv : r.views) { - r_nodes.insert(kv.first); - } - - assert(g_nodes == r_nodes); -} - -template <> -void 
SearchHelper::check_matches_graph(Graph const *g, - float const &r, - Node const &sink) const {} - -template <> -std::pair - SearchHelper::try_get_cost_from_cache(size_t hash) const { - if (this->cached_graph_costs.find(hash) == this->cached_graph_costs.end()) { - return {false, std::numeric_limits::infinity()}; - } else { - return {true, this->cached_graph_costs.at(hash)}; - } -} - -template <> -std::pair - SearchHelper::try_get_cost_from_cache(size_t hash) const { - return {false, GraphCostResult::invalid()}; -} - -template <> -void SearchHelper::try_cache_result(size_t hash, - float const &value) const { - this->logger->debug() << "cached_graph_costs[" << hash << "] = " << value; - this->cached_graph_costs[hash] = value; -} - -template <> -void SearchHelper::try_cache_result( - size_t hash, GraphCostResult const &value) const { - this->logger->debug() << "cached_graph_costs[" << hash << "=" << value.cost - << "]"; - this->cached_graph_costs[hash] = value.cost; -} - -template <> -float SearchHelper::infinity() const { - return std::numeric_limits::infinity(); -} - -template <> -GraphCostResult SearchHelper::infinity() const { - return {std::numeric_limits::infinity(), {}}; -} - -template <> -float SearchHelper::empty() const { - return 0.0f; -} - -template <> -GraphCostResult SearchHelper::empty() const { - return {0.0f, {}}; -} - -template -T SearchHelper::estimate_xfer_cost(Graph const *graph, - NodeAssignment const &source, - NodeAssignment const &sink) const { - T result = this->empty(); - - if (source.node != Node::INVALID_NODE) { - auto const &inList = graph->inEdges.find(sink.node)->second; - float op_cost = 0.0f; - for (auto const &it2 : inList) { - assert(it2.srcOp == source.node); - assert(sink.node.ptr->inputs[it2.dstIdx]->is_valid_machine_view( - source.view)); - - float estimated_xfer_cost = this->model->simulator->estimate_xfer_cost( - sink.node.ptr, it2.dstIdx, source.view, sink.view); - // printf("Estimated xfer cost from %s to %s: %fms\n", - // 
source.node.ptr->name, sink.node.ptr->name, estimated_xfer_cost); - op_cost += estimated_xfer_cost; - } - this->add_operator_cost(source, op_cost, &result); - } else { - Node real_source = graph->find_source_node(); - assert(real_source.ptr->op_type == OP_INPUT); - this->add_operator_cost({real_source, MachineView::NO_VIEW}, 0.0f, &result); - } - - return result; -} - -template <> -void SearchHelper::add_operator_cost(NodeAssignment const &node, - float node_cost, - float *cost) const { - *cost += node_cost; -} - -template <> -void SearchHelper::add_operator_cost( - NodeAssignment const &node, float node_cost, GraphCostResult *cost) const { - cost->cost += node_cost; - cost->views[node.node] = node.view; -} - -template <> -float SearchHelper::get_cost(float const &f) const { - return f; -} - -template <> -float SearchHelper::get_cost( - GraphCostResult const &gcr) const { - return gcr.cost; -} - -template -T SearchHelper::graph_cost(Graph const *graph, - NodeAssignment const &source, - NodeAssignment const &sink, - MachineResource const &resources, - bool include_sink_compute_time) const { - TAG_ENTER(this->logger); - this->logger->debug() << "sink(" << sink.node.guid << ") " - << "sink.view(" << sink.view.ndims << " " - << sink.view.start_device_id << " " << sink.view.dim[0] - << ") " - << "source(" << source.node.guid << ") " - << "source.view(" << source.view.ndims << " " - << source.view.start_device_id << " " - << source.view.dim[0] << ") " - << "resources(" << resources.num_nodes << " " - << resources.start_gpu_id << " " - << resources.available_gpus_per_node << ")"; - if (this->model->config.profiling) { - graph->print_dot(); - } - - assert(graph->inEdges.find(sink.node) != graph->inEdges.end()); - if (source.node != Node::INVALID_NODE) { - assert(graph->outEdges.find(source.node) != graph->outEdges.end()); - } - - size_t hash = dp_state_hash( - graph, sink.node, sink.view, source.node, source.view, resources); - this->logger->spew() << "hash = " << hash; - 
- T result; - - std::pair from_cache = this->try_get_cost_from_cache(hash); - if (from_cache.first) { - // cached_graph_costs does not include sink_compute_time - result = from_cache.second; - } else { - if (graph->inEdges.size() <= 2) { - result = this->estimate_xfer_cost(graph, source, sink); - this->logger->debug() - << "Estimated xfer cost is " << this->get_cost(result); - } else { - Node bn_node = graph->find_bottleneck_node(sink.node, source.node); - if (bn_node != Node::INVALID_NODE) { - // We found a bottleneck node - this->logger->debug() << "Found bn_node = " << bn_node.guid; - - result = this->find_optimal_sequence_graph_time( - graph, - bn_node, - {source.node, source.view}, - {sink.node, sink.view}, - resources); - } else { - // sink node must have multiple branches - // otherwise we should not be here - assert(graph->inEdges.find(sink.node)->second.size() > 1); - - result = this->find_optimal_nonsequence_graph_time( - graph, - {source.node, source.view}, - {sink.node, sink.view}, - resources); - } - } - - this->try_cache_result(hash, result); - } - - check_matches_graph(graph, result, sink.node); - - if (include_sink_compute_time) { - CostMetrics metrics = - this->model->simulator->measure_operator_cost(sink.node.ptr, sink.view); - this->logger->debug() << "Sink node cost: " - << "forward(" << metrics.forward_time << ") " - << "backward(" << metrics.backward_time << ") " - << "sync(" << metrics.sync_time << ")"; - this->add_operator_cost(sink, - metrics.forward_time + metrics.backward_time + - metrics.sync_time, - &result); - } - - return result; -} - -} // namespace PCG -} // namespace FlexFlow diff --git a/lib/compiler/src/old/search_helper.h b/lib/compiler/src/old/search_helper.h deleted file mode 100644 index 95350ce6af..0000000000 --- a/lib/compiler/src/old/search_helper.h +++ /dev/null @@ -1,122 +0,0 @@ -#ifndef _FLEXFLOW_FFC_SRC_SEARCH_HELPER_H -#define _FLEXFLOW_FFC_SRC_SEARCH_HELPER_H - -#include "graph.h" -#include "split_types.h" - 
-namespace FlexFlow { - -struct GraphCostResult { - float cost; - std::unordered_map views; - - static GraphCostResult invalid(); - - bool operator<(GraphCostResult const &other) const; - - friend std::ostream &operator<<(std::ostream &, GraphCostResult const &); -}; - -template -T sequence_cost(T const &first, T const &second); - -template -T parallel_cost(T const &first, T const &second); - -class SearchHelper { -public: - SearchHelper(); - - template - T graph_cost(Graph const *graph, - NodeAssignment const &source, - NodeAssignment const &sink, - MachineResource const &resources, - bool include_sink_compute_time) const; - template - T find_optimal_sequence_graph_time(Graph const *g, - Node const &bottleneck_node, - NodeAssignment const &source, - NodeAssignment const &sink, - MachineResource const &resources) const; - template - T find_optimal_nonsequence_graph_time(Graph const *g, - NodeAssignment const &source, - NodeAssignment const &sink, - MachineResource const &resources) const; - /* void find_optimal_nonsequence_graph_views(Graph const *g, */ - /* NodeAssignment const &source, */ - /* NodeAssignment const &sink, */ - /* MachineResource const &resources, - */ - /* float optimal_cost, */ - /* std::unordered_map& optimal_views) const; */ - std::vector - get_valid_machine_views(Node const &node, - MachineResource const &resource, - bool log = false) const; - std::vector - get_valid_machine_views(PCGOperatorAttrs const &op, - MachineResource const &resource, - bool log = false) const; - - template - std::pair try_get_cost_from_cache(size_t hash) const; - - template - void try_cache_result(size_t hash, T const &value) const; - - template - T infinity() const; - - template - T empty() const; - - template - bool is_invalid(T const &) const; - - template - T estimate_xfer_cost(Graph const *g, - NodeAssignment const &source, - NodeAssignment const &sink) const; - - template - void add_operator_cost(NodeAssignment const &, float, T *) const; - - template - float 
get_cost(T const &) const; - - template - void check_matches_graph(Graph const *, T const &, Node const &) const; - -public: - mutable std::unique_ptr logger; - -private: - template - T execute_nonsequence_split(std::unique_ptr const &first_graph, - std::unique_ptr const &second_graph, - NodeAssignment const &source, - NodeAssignment const &sink, - MachineResource const &resources, - NonsequenceSplit const &split) const; - - template - T execute_sequence_split(std::unique_ptr const &first_graph, - std::unique_ptr const &second_graph, - NodeAssignment const &source, - NodeAssignment const &sink, - MachineResource const &resources, - SequenceSplit const &split) const; - -private: - mutable std::unordered_map cached_graph_costs; - mutable std::unordered_map>> - cached_operator_valid_views; -}; - -} // namespace FlexFlow - -#endif diff --git a/lib/compiler/src/old/simplification.cc b/lib/compiler/src/old/simplification.cc deleted file mode 100644 index 18fc2fb71a..0000000000 --- a/lib/compiler/src/old/simplification.cc +++ /dev/null @@ -1,189 +0,0 @@ -#include "simplification.h" -#include "spdlog/spdlog.h" -#include - -namespace FlexFlow { -namespace PCG { - -Simplifier::Simplifier(std::string const &logger_name) - : logger(spdlog::get(logger_name)) {} - -void Simplifier::simplify_parallel_ops() { - logger->debug("Trying to simplify parallel ops"); - - /* using FlexFlow::PCG::Utils::nodes; */ - /* using FlexFlow::PCG::Utils::predecessor; */ - /* using FlexFlow::PCG::Utils::predecessors; */ - /* using FlexFlow::PCG::Utils::successor; */ - - std::queue work_queue; - for (Node const &node : nodes(*this)) { - if (node.ptr->is_parallel_op()) { - work_queue.push(node); - } - } - - while (!work_queue.empty()) { - Node node = work_queue.front(); - log_simplify.debug() << "Trying to simplify starting from " - << node.to_string(); - work_queue.pop(); - - auto opt_succ = successor(*this, node); - if (!opt_succ.has_value()) { - log_simplify.debug() << "Skipping because does not 
have single successor"; - continue; - } - Node succ = opt_succ.value(); - if (!succ.ptr->is_parallel_op()) { - log_simplify.debug() << "Skipping because successor is not a parallel op"; - continue; - } - - std::vector node_parallel_op_info, - successor_parallel_op_info; - ((ParallelOp *)node.ptr)->append_parallel_op_info(node_parallel_op_info); - ((ParallelOp *)succ.ptr) - ->append_parallel_op_info(successor_parallel_op_info); - ParallelOpJoinResult result = try_join_parallel_ops( - node_parallel_op_info.front(), successor_parallel_op_info.front()); - - if (!result.join_did_succeed) { - log_simplify.debug() << "Skipping because join did not succeed"; - continue; - } - log_simplify.debug() << "Did join nodes"; - log_simplify.debug() << " " << node.to_string(); - log_simplify.debug() << " " << succ.to_string(); - - for (Node const &p : predecessors(*this, node)) { - if (p.ptr->is_parallel_op()) { - work_queue.push(p); - } - } - - Graph new_g(this->model); - if (result.op.has_value()) { - Node new_op = this->model->get_or_create_parallel_op_node( - node.ptr->inputs[0], result.op.value()); - work_queue.push(new_op); - new_g.add_node(new_op); - } - this->replace_subgraph({node, succ}, new_g); - } - log_simplify.debug() << "Finished simplifying parallel ops"; -} - -void Graph::simplify(SimplificationSettings const &settings) { - // Simplify the graph by eliminating reverse parallel ops - // and fusing multiple parallel ops - // old graph: e1->n1->e2->n2->en - // new graph: e1->new_node->en - // TODO: temporarily disabled graph simplification - if (settings.simplify_parallel_ops) { - this->simplify_parallel_ops(); - } - if (settings.fuse_parallel_ops) { - bool simplify = true; - while (simplify) { - simplify = false; - for (auto const &it : this->inEdges) { - if (it.first.ptr == NULL) { - continue; - } - if (it.first.ptr->is_parallel_op()) { - Node n2 = it.first; - assert(it.second.size() == 1); - Edge e2 = *it.second.begin(); - Node n1 = e2.srcOp; - // Check that n1 is a 
parallel op - // Check that n1 must have a single out edge - if (n1.ptr->is_parallel_op() && - this->outEdges.find(n1)->second.size() == 1) { - // merge n1 and n2 - std::vector parallel_ops; - ((ParallelOp *)n1.ptr)->append_parallel_op_info(parallel_ops); - ((ParallelOp *)n2.ptr)->append_parallel_op_info(parallel_ops); - Node new_node = model->get_or_create_fused_parallel_node( - n1.ptr->inputs[0], parallel_ops); - auto const &inList = this->inEdges.find(n1)->second; - assert(inList.size() == 1); - Edge e1 = *inList.begin(); - // Update graph by adding edges - this->add_edge(e1.srcOp, new_node, e1.srcIdx, 0); - this->remove_edge(e1); - this->remove_edge(e2); - // make a copy of outList - if (this->outEdges.find(n2) != this->outEdges.end()) { - auto const outList = this->outEdges.find(n2)->second; - for (auto const &e : outList) { - this->add_edge(new_node, e.dstOp, 0, e.dstIdx); - this->remove_edge(e); - } - } - simplify = true; - } - } - if (simplify) { - break; - } - } - } - } - - if (settings.remove_trailing_parallel_ops) { - // Remove final parallel ops - std::vector candidates; - for (auto const &it : this->outEdges) { - if (it.second.size() == 0 && it.first.ptr->op_type != OP_REDUCTION && - it.first.ptr->op_type != OP_FUSED_PARALLEL && - it.first.ptr->is_parallel_op()) { - candidates.push_back(it.first); - } - } - size_t index = 0; - while (index < candidates.size()) { - Node parallel_op = candidates[index++]; - auto const &inList = this->inEdges.find(parallel_op)->second; - assert(inList.size() == 1); - Edge e = *inList.begin(); - this->remove_edge(e); - if (this->outEdges.find(e.srcOp)->second.size() == 0 && - e.srcOp.ptr->is_parallel_op()) { - candidates.push_back(e.srcOp); - } - } - } - - if (settings.remove_noops) { - // Remove NoOps - std::vector noop_nodes; - for (auto const &it : this->inEdges) { - if (it.first.ptr == NULL) { - continue; - } - if (it.first.ptr->op_type == OP_NOOP) { - noop_nodes.push_back(it.first); - } - } - size_t index = 0; - while 
(index < noop_nodes.size()) { - Node noop = noop_nodes[index++]; - auto const &inList = this->inEdges.find(noop)->second; - assert(inList.size() == 1); - Edge in_edge = *inList.begin(); - // make a copy of outList - if (this->outEdges.find(noop) != this->outEdges.end()) { - auto const outList = this->outEdges.find(noop)->second; - for (auto const &e : outList) { - this->add_edge(in_edge.srcOp, e.dstOp, in_edge.srcIdx, e.dstIdx); - this->remove_edge(e); - } - } - this->remove_edge(in_edge); - } - } -} - -} // namespace PCG -} // namespace FlexFlow diff --git a/lib/compiler/src/old/simplification.h b/lib/compiler/src/old/simplification.h deleted file mode 100644 index d83c16eb91..0000000000 --- a/lib/compiler/src/old/simplification.h +++ /dev/null @@ -1,34 +0,0 @@ -#ifndef _FLEXFLOW_FFC_SIMPLIFICATION_H -#define _FLEXFLOW_FFC_SIMPLIFICATION_H - -#include "graph.h" -#include "spdlog/spdlog.h" -#include - -namespace FlexFlow { -namespace PCG { - -struct SimplificationSettings { - bool simplify_parallel_ops = false; - bool fuse_parallel_ops = false; - bool remove_trailing_parallel_ops = false; - bool remove_noops = false; -}; - -class Simplifier { -public: - Simplifier(std::string const &logger_name); - - Graph const &simplify(SimplificationSettings const &, Graph const &); - -private: - void simplify_parallel_ops(); - -private: - std::shared_ptr logger; -}; - -} // namespace PCG -} // namespace FlexFlow - -#endif diff --git a/lib/compiler/src/old/split_types.cc b/lib/compiler/src/old/split_types.cc deleted file mode 100644 index e9648344d4..0000000000 --- a/lib/compiler/src/old/split_types.cc +++ /dev/null @@ -1,36 +0,0 @@ -#include "split_types.h" - -namespace FlexFlow { -namespace PCG { - -/*static*/ -NonsequenceSplit NonsequenceSplit::sequential() { - NonsequenceSplit s; - s.type = SplitType::SEQUENTIAL; - s.flip_graphs = false; - - return s; -} - -/*static*/ -NonsequenceSplit NonsequenceSplit::vertical(int param, bool flip_graphs) { - NonsequenceSplit s; - s.type = 
SplitType::VERTICAL; - s.param = param; - s.flip_graphs = flip_graphs; - - return s; -} - -/*static*/ -NonsequenceSplit NonsequenceSplit::horizontal(int param, bool flip_graphs) { - NonsequenceSplit s; - s.type = SplitType::HORIZONTAL; - s.param = param; - s.flip_graphs = flip_graphs; - - return s; -} - -} // namespace PCG -} // namespace FlexFlow diff --git a/lib/compiler/src/old/split_types.h b/lib/compiler/src/old/split_types.h deleted file mode 100644 index 3c49ad5b7a..0000000000 --- a/lib/compiler/src/old/split_types.h +++ /dev/null @@ -1,32 +0,0 @@ -#ifndef _FLEXFLOW_FFC_SPLIT_TYPES_H -#define _FLEXFLOW_FFC_SPLIT_TYPES_H - -#include "node.h" -#include "pcg/machine_view.h" - -namespace FlexFlow { -namespace PCG { - -enum class SplitType { SEQUENTIAL, VERTICAL, HORIZONTAL }; - -struct NonsequenceSplit { - SplitType type; - int param; - bool flip_graphs; - - static NonsequenceSplit sequential(); - static NonsequenceSplit vertical(int param, bool flip_graphs); - static NonsequenceSplit horizontal(int param, bool flip_graphs); -}; - -struct NodeAssignment { - Node node; - MachineView view; -}; - -using SequenceSplit = NodeAssignment; - -} // namespace PCG -} // namespace FlexFlow - -#endif diff --git a/lib/compiler/src/old/substitution.cc b/lib/compiler/src/old/substitution.cc deleted file mode 100644 index 9f8381093c..0000000000 --- a/lib/compiler/src/old/substitution.cc +++ /dev/null @@ -1,3733 +0,0 @@ -/* Copyright 2023 CMU, Facebook, LANL, MIT, NVIDIA, and Stanford (alphabetical) - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
- * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "substitution.h" -#include "graph.h" -#include "graph_structures.h" -#include "op-meta/op-meta.h" -#include "parallel_ops/combine.h" -#include "parallel_ops/fused_parallel_op.h" -#include "parallel_ops/partition.h" -#include "parallel_ops/reduction.h" -#include "parallel_ops/replicate.h" -#include "utils/dot/dot_file.h" -#include -#include - -using namespace ::FlexFlow::substitutions; - -namespace FlexFlow { -namespace ffc { - -const TensorX TensorX::NO_TX = TensorX(); - -bool TensorX::operator==(TensorX const &other) const { - return this->op == other.op && this->idx == other.idx; -} - -bool TensorX::operator!=(TensorX const &other) const { - return !this->operator==(other); -} - -Rule create_combine_inception(int num_convs, int num_dims, int num_parts); -Rule create_combine_concat(int num_inputs, int num_dims, int num_parts); -Rule create_replicate_linear_combine(int num_dims, - int num_parts, - ActiMode activation, - bool use_bias); -Rule create_partition_linear_combine(int num_dims, - int num_parts, - ActiMode activation, - bool use_bias); -Rule create_partition_conv2d_combine(int num_dims, int num_parts); -Rule create_partition_attention_combine(int num_heads, int num_parts); -Rule create_replicate_attention_reduce(int num_heads, int num_parts); -Rule create_partition_add_combine(int parallel_dim, int num_parts); -Rule create_partition_relu_combine(int parallel_dim, int num_parts); -Rule create_partition_concat_combine(int num_inputs, - int concat_dim, - int parallel_dim, - int num_parts); -Rule create_partition_softmax_combine(int softmax_dim, - int part_dim, - int num_parts); -Rule leading_relu_branch_combine(int parallel_dim, - int num_parts, - int num_combines); -Rule leading_relu_branch_partition(int parallel_dim, - int num_parts, - int num_partitions); -Rule create_linear_relu_merge(int num_dims, bool use_bias); - 
-PMConstraint::PMConstraint(Compare c, PMParameter p, int v) - : comp(c), para(p), value(v) {} - -TNConstraint::TNConstraint(Compare c, TNParameter p, DIMParameter d, int v) - : singlePara(true), comp(c), para1(p), dim1(d), value(v) {} - -TNConstraint::TNConstraint( - Compare c, TNParameter p1, DIMParameter d1, TNParameter p2, DIMParameter d2) - : singlePara(false), comp(c), para1(p1), para2(p2), dim1(d1), dim2(d2) {} - -tl::optional TensorX::to_tensor(GraphXfer const *xfer) const { - if (op != NULL) { - assert(op->mapOp.ptr != NULL); - return op->mapOp.ptr->outputs[idx]; - } else { - auto const &it = xfer->mappedInputs.find(idx); - if (it == xfer->mappedInputs.end()) { - return tl::nullopt; - } - assert(it != xfer->mappedInputs.end()); - Node op = it->second.first; - int outIdx = it->second.second; - return op.ptr->outputs[outIdx]; - } -} - -OpX::OpX(const OperatorType _type, - int num_inputs, - int num_outputs, - TensorX const &input0, - TensorX const &input1, - TensorX const &input2, - TensorX const &input3) - : type(_type), mapOp(Node::INVALID_NODE), matchOpX(NULL) { - TensorX all_inputs[MAX_NUM_INPUTS]; - all_inputs[0] = input0; - all_inputs[1] = input1; - all_inputs[2] = input2; - all_inputs[3] = input3; - for (int i = 0; i < num_inputs; i++) { - inputs.push_back(all_inputs[i]); - } - for (int i = 0; i < num_outputs; i++) { - TensorX out(this, i); - outputs.push_back(out); - } -} - -OpX::OpX(const OperatorType _type, - int num_inputs, - int num_outputs, - TensorX const *input_array) - : type(_type), mapOp(Node::INVALID_NODE), matchOpX(NULL) { - for (int i = 0; i < num_inputs; i++) { - inputs.push_back(input_array[i]); - } - for (int i = 0; i < num_outputs; i++) { - TensorX out(this, i); - outputs.push_back(out); - } -} - -bool OpX::add_pm_constraint(Compare comp, PMParameter para, int value) { - PMConstraint pmc(comp, para, value); - pmConstraints.push_back(pmc); - return true; -} - -bool OpX::add_input_constraint(Compare comp, - TNParameter para, - 
DIMParameter dim, - int value) { - TNConstraint tnc(comp, para, dim, value); - tnConstraints.push_back(tnc); - return true; -} - -bool OpX::add_input_constraint(Compare comp, - TNParameter para1, - DIMParameter dim1, - TNParameter para2, - DIMParameter dim2) { - TNConstraint tnc(comp, para1, dim1, para2, dim2); - tnConstraints.push_back(tnc); - return true; -} - -bool OpX::get_pm_constraint(PMParameter para, int &value) const { - for (size_t i = 0; i < pmConstraints.size(); i++) { - if ((pmConstraints[i].comp == COMPARE_EQ) && - (pmConstraints[i].para == para)) { - value = pmConstraints[i].value; - return true; - } - } - return false; -} - -GraphXfer::GraphXfer(FFModel *_model) : model(_model), tensorId(10) {} - -TensorX GraphXfer::new_tensor(void) { - TensorX t; - t.op = NULL; - t.idx = tensorId++; - return t; -} - -bool GraphXfer::map_output(TensorX const &src, TensorX const &dst) { - mappedOutputs[src] = dst; - return true; -} - -bool GraphXfer::can_match(OpX *srcOp, Node const &op, Graph const *graph) { - if (srcOp->type != op.ptr->op_type) { - return false; - } - // check num input tensors - if ((int)srcOp->inputs.size() != op.ptr->numInputs) { - return false; - } - // check pmConstraints - for (size_t i = 0; i < srcOp->pmConstraints.size(); i++) { - PMConstraint pmc = srcOp->pmConstraints[i]; - int actValue = 0; - assert(op.ptr->get_int_parameter(pmc.para, &actValue)); - // printf("pmc[%d] para(%d) comp(%d) value(%d) actValue(%d)\n", - // i, pmc.para, pmc.comp, pmc.value, actValue); - switch (pmc.comp) { - case COMPARE_EQ: { - if (actValue != pmc.value) { - return false; - } - break; - } - case COMPARE_NE: { - if (actValue == pmc.value) { - return false; - } - break; - } - case COMPARE_LT: { - if (actValue >= pmc.value) { - return false; - } - break; - } - case COMPARE_LE: { - if (actValue > pmc.value) { - return false; - } - break; - } - case COMPARE_GT: { - if (actValue <= pmc.value) { - return false; - } - break; - } - case COMPARE_GE: { - if (actValue < 
pmc.value) { - return false; - } - break; - } - default: - assert(false); - } - } - // check inputs - std::map> newMapInputs; - for (size_t i = 0; i < srcOp->inputs.size(); i++) { - TensorX in = srcOp->inputs[i]; - if (in.op == NULL) { - // input tensor - std::multimap>::const_iterator it; - it = mappedInputs.find(in.idx); - if (it != mappedInputs.end()) { - Node mappedOp = it->second.first; - int mappedIdx = it->second.second; - if (!(graph->has_edge(mappedOp, op, mappedIdx, i))) { - return false; - } - } else { - std::map>::const_iterator newit; - newit = newMapInputs.find(in.idx); - if (newit != newMapInputs.end()) { - Node mappedOp = newit->second.first; - int mappedIdx = newit->second.second; - if (!(graph->has_edge(mappedOp, op, mappedIdx, i))) { - return false; - } - } else { - auto const &list = graph->inEdges.find(op)->second; - for (auto const &e : list) { - if (e.dstIdx == (int)i) { - newMapInputs.insert( - std::make_pair(in.idx, std::make_pair(e.srcOp, e.srcIdx))); - } - } - } - // Do nothing when we check the match - /* mapped in.idx to an op - std::set list = graph->inEdges.find(op)->second; - std::set::const_iterator it2; - for (it2 = list.begin(); it2 != list.end(); it2++) { - Edge e = *it2; - if (e.dstIdx == i) - mappedInputs[in.idx] = std::make_pair(e.srcOp, e.srcIdx); - }*/ - } - } else { - // intermediate tensor - assert(in.op->mapOp != Node::INVALID_NODE); - if (!(graph->has_edge(in.op->mapOp, op, in.idx, i))) { - return false; - } - } - } - // check tnConstraints - for (size_t i = 0; i < srcOp->tnConstraints.size(); i++) { - TNConstraint tnc = srcOp->tnConstraints[i]; - int actValue = 0, expValue = 0; - if (tnc.singlePara) { - assert(op.ptr->get_tensor_parameter(tnc.para1, tnc.dim1, &actValue)); - expValue = tnc.value; - } else { - assert(op.ptr->get_tensor_parameter(tnc.para1, tnc.dim1, &actValue)); - assert(op.ptr->get_tensor_parameter(tnc.para2, tnc.dim2, &expValue)); - } - switch (tnc.comp) { - case COMPARE_EQ: { - if (actValue != 
expValue) { - return false; - } - break; - } - case COMPARE_NE: { - if (actValue == expValue) { - return false; - } - break; - } - case COMPARE_LT: { - if (actValue >= expValue) { - return false; - } - break; - } - case COMPARE_LE: { - if (actValue > expValue) { - return false; - } - break; - } - case COMPARE_GT: { - if (actValue <= expValue) { - return false; - } - break; - } - case COMPARE_GE: { - if (actValue < expValue) { - return false; - } - break; - } - default: - assert(false); - } - } - return true; -} - -void GraphXfer::match(OpX *srcOp, Node const &op, Graph const *graph) { - for (size_t i = 0; i < srcOp->inputs.size(); i++) { - TensorX in = srcOp->inputs[i]; - if (in.op == NULL) { - // Update mappedInputs - auto const &list = graph->inEdges.find(op)->second; - for (auto const &e : list) { - if (e.dstIdx == (int)i) { - mappedInputs.insert( - std::make_pair(in.idx, std::make_pair(e.srcOp, e.srcIdx))); - } - } - } - } - // Map srcOp to Op - srcOp->mapOp = op; - mappedOps[op] = srcOp; -} - -void GraphXfer::unmatch(OpX *srcOp, Node const &op, Graph const *graph) { - for (size_t i = 0; i < srcOp->inputs.size(); i++) { - log_xfer_matches.spew() << "umatch iteration " << i; - TensorX in = srcOp->inputs[i]; - if (in.op == NULL) { - // Update mappedInputsa - std::multimap>::iterator it; - log_xfer_matches.spew() << "Starting find"; - it = mappedInputs.find(in.idx); - log_xfer_matches.spew() << "Finished find"; - if (it != mappedInputs.end()) { - mappedInputs.erase(it); - } - } - } - log_xfer_matches.spew() << "Finished the unmatch loop"; - // Unmap op - mappedOps.erase(op); - srcOp->mapOp.guid = 0; - srcOp->mapOp.ptr = NULL; - log_xfer_matches.spew() << "Returning from unmatch"; -} - -GraphXferMatch::GraphXferMatch(GraphXfer const *xfer) : xfer(xfer) {} - -void GraphXferMatch::add_mapping(Node const &node, OpX *opx) { - this->nodeToOpX[node] = opx; - this->opXToNode[opx] = node; -} - -void GraphXferMatch::add_mapping(OpX *opx, Node const &node) { - 
this->add_mapping(node, opx); -} - -void GraphXferMatch::add_output_mapping(TensorX const &src, - TensorX const &dst) { - this->mappedOutputs[src] = dst; -} - -OpX *GraphXferMatch::at(Node const &n) const { - return this->nodeToOpX.at(n); -} - -Node GraphXferMatch::at(OpX *opx) const { - return this->opXToNode.at(opx); -} - -void GraphXferMatch::set_graph(Graph const *g) { - this->graph_hash = g->hash(); -} - -bool GraphXferMatch::containsNode(Graph const *g, Node const &n) const { - assert(g->hash() == this->graph_hash); - - return this->nodeToOpX.find(n) != this->nodeToOpX.end(); -} - -bool GraphXferMatch::containsEdge(Graph const *g, Edge const &e) const { - assert(g->hash() == this->graph_hash); - - bool contains_src = this->containsNode(g, e.srcOp); - bool contains_dst = this->containsNode(g, e.dstOp); - - return contains_src && contains_dst; -} - -GraphXfer const *GraphXferMatch::get_xfer() const { - return this->xfer; -} - -std::unordered_set GraphXferMatch::get_nodes() const { - std::unordered_set nodes; - for (auto const &kv : nodeToOpX) { - nodes.insert(kv.first); - } - - return nodes; -} - -GraphXferMatch GraphXfer::get_match_record(Graph const *g) const { - GraphXferMatch match(this); - - for (auto const &kv : this->mappedOps) { - match.add_mapping(kv.first, kv.second); - } - - for (auto const &kv : this->mappedOutputs) { - match.add_output_mapping(kv.first, kv.second); - } - - match.set_graph(g); - - return match; -} - -void GraphXfer::find_matches(Graph const *graph, - std::vector &matches) { - this->find_matches(0, graph, matches); -} - -void GraphXfer::find_matches(int depth, - Graph const *graph, - std::vector &matches) { - log_xfer_matches.spew() << "find_matches at depth: " << depth; - if (depth >= (int)srcOps.size()) { - log_xfer_matches.spew() << "Achieved adequate depth"; - // Create dst operators - bool pass = true; - for (OpX *dstOp : this->dstOps) { - pass &= create_new_operator(dstOp, dstOp->mapOp); - if (!pass) { - break; - } - } - 
log_xfer_matches.spew() << "Completed create dst operators"; - if (!pass) { - log_xfer_matches.spew() << "Did not pass. Returning."; - return; - } - log_xfer_matches.spew() << "Checking external edges"; - // Check that output tensors with external edges are mapped - for (auto const &opIt : mappedOps) { - auto const &list = graph->outEdges.at(opIt.first); - for (auto const &e : list) { - if (mappedOps.find(e.dstOp) == mappedOps.end()) { - // dstOp is external, (srcOp, srcIdx) must be in mappedOutputs - TensorX srcTen; - srcTen.op = opIt.second; - srcTen.idx = e.srcIdx; - if (mappedOutputs.find(srcTen) == mappedOutputs.end()) { - pass = false; - return; - } - } - } - } - log_xfer_matches.spew() << "Completed checking external edges"; - // Generate a new graph by applying xfer rule - log_xfer_matches.spew() << "Creating new graph"; - SimplificationSettings - settings; // leave everything disabeld since we don't care about cost - Graph *newGraph = this->create_new_graph(graph, settings); - log_xfer_matches.spew() << "Completed creating new graph"; - - // Check that the new graph should not have any loop - log_xfer_matches.spew() << "Checking for loop"; - if (newGraph->has_loop()) { - printf("Found a new graph with LOOP!!!!\n"); - newGraph->print(); - delete newGraph; - return; - } - log_xfer_matches.spew() << "Finished checking for loop"; - // TODO: remove me for better performance - log_xfer_matches.spew() << "Checking correctness"; - assert(newGraph->check_correctness()); - log_xfer_matches.spew() << "Finished checking correctness"; - log_xfer_matches.spew() << "Getting match record"; - GraphXferMatch match_record = this->get_match_record(graph); - log_xfer_matches.spew() << "Finished getting match record"; - matches.push_back(match_record); - } else { - OpX *srcOp = srcOps[depth]; - for (auto const &it : graph->inEdges) { - log_xfer_matches.spew() << "Exploring node " << it.first.to_string(); - // printf("can_match(%d)\n", can_match(srcOp, it->first, graph)); - if 
(can_match(srcOp, it.first, graph) && - (mappedOps.find(it.first) == mappedOps.end())) { - Node op = it.first; - // Check mapOutput - this->match(srcOp, op, graph); - this->find_matches(depth + 1, graph, matches); - log_xfer_matches.spew() << "Completed find matches. Unmatching"; - this->unmatch(srcOp, op, graph); - log_xfer_matches.spew() << "Finished unmatching"; - } - } - } -} - -template -void GraphXfer::run( - int depth, - Graph *graph, - std::priority_queue, GraphComparator> - &candidates, - std::unordered_set &hashmap, - float threshold, - int maxNumOps, - SimplificationSettings const &simplification_settings, - int &num_matches_found, - int &num_matches_rejected) { - // printf("run: depth(%d) srcOps.size(%zu) graph.size(%zu) candidates(%zu)\n", - // depth, srcOps.size(), graph->inEdges.size(), candidates.size()); - if (depth >= (int)srcOps.size()) { - // Create dst operators - bool pass = true; - for (OpX *dstOp : this->dstOps) { - if (pass) { - pass &= create_new_operator(dstOp, dstOp->mapOp); - } - } - if (!pass) { - return; - } - // Check that output tensors with external edges are mapped - for (auto const &opIt : mappedOps) { - auto const &list = graph->outEdges[opIt.first]; - for (auto const &e : list) { - if (mappedOps.find(e.dstOp) == mappedOps.end()) { - // dstOp is external, (srcOp, srcIdx) must be in mappedOutputs - TensorX srcTen; - srcTen.op = opIt.second; - srcTen.idx = e.srcIdx; - if (mappedOutputs.find(srcTen) == mappedOutputs.end()) { - pass = false; - return; - } - } - } - } - // Generate a new graph by applying xfer rule - log_xfers.spew() << "Found a match for xfer: " << this->get_name(); - num_matches_found++; - Graph *newGraph = this->create_new_graph(graph, simplification_settings); - // Check that the new graph should not have any loop - if (newGraph->has_loop()) { - printf("Found a new graph with LOOP!!!!\n"); - newGraph->print(); - delete newGraph; - return; - } - // TODO: remove me for better performance - 
assert(newGraph->check_correctness()); - if (newGraph->optimal_cost() < threshold && - (int)newGraph->inEdges.size() < maxNumOps) { - if (hashmap.find(newGraph->hash()) == hashmap.end()) { - hashmap.insert(newGraph->hash()); - log_xfers.spew() << "Found new candidate"; - // newGraph->print_dot(); - candidates.push(newGraph); - } - } else { - num_matches_rejected++; - delete newGraph; - } - } else { - OpX *srcOp = srcOps[depth]; - for (auto const &it : graph->inEdges) { - // printf("can_match(%d)\n", can_match(srcOp, it->first, graph)); - if (can_match(srcOp, it.first, graph) && - (mappedOps.find(it.first) == mappedOps.end())) { - Node op = it.first; - // Check mapOutput - match(srcOp, op, graph); - run(depth + 1, - graph, - candidates, - hashmap, - threshold, - maxNumOps, - simplification_settings, - num_matches_found, - num_matches_rejected); - unmatch(srcOp, op, graph); - } - } - } -} - -void Graph::reshape_output_tensor(ParallelTensorShape const &desired_shape) { - Node output_node = this->find_sink_node(); - - assert(output_node.ptr->numOutputs == 1); - ParallelTensor output_tensor = output_node.ptr->outputs[0]; - - assert(output_tensor->num_dims == desired_shape.num_dims); - - for (int i = 0; i < output_tensor->num_dims; i++) { - int current_size = output_tensor->dims[i].size; - int current_degree = output_tensor->dims[i].degree; - - int desired_size = desired_shape.dims[i].size; - int desired_degree = desired_shape.dims[i].degree; - - assert(current_size == desired_size); - - if (current_degree < desired_degree) { - // we need to partition - assert(desired_degree % current_degree == 0); - int partition_factor = desired_degree / current_degree; - - Node partition_node = model->get_or_create_node( - output_tensor, {i /*legion_dim*/, partition_factor}); - this->add_edge(output_node, partition_node, 0, 0); - - output_node = partition_node; - output_tensor = partition_node.ptr->outputs[0]; - current_degree *= partition_factor; - - } else if (current_degree > 
desired_degree) { - // we need to combine - assert(current_degree % desired_degree == 0); - int combine_factor = current_degree / desired_degree; - - Node combine_node = model->get_or_create_node( - output_tensor, {i /*legion_dim*/, combine_factor}); - this->add_edge(output_node, combine_node, 0, 0); - - output_node = combine_node; - output_tensor = combine_node.ptr->outputs[0]; - current_degree /= combine_factor; - } - - assert(current_degree == desired_degree); - } - - assert(output_tensor == output_node.ptr->outputs[0]); - assert(output_tensor->num_dims == desired_shape.num_dims); - for (int i = 0; i < desired_shape.num_dims; i++) { - assert(output_tensor->dims[i].size == desired_shape.dims[i].size); - assert(output_tensor->dims[i].degree == desired_shape.dims[i].degree); - } -} - -std::unique_ptr Graph::with_output_tensor_reshaped_to( - ParallelTensorShape const &shape) const { - auto g = std::unique_ptr(new Graph(*this)); - g->reshape_output_tensor(shape); - return g; -} - -/* Graph::Graph(Graph const &graph) */ -/* : Graph(&graph) */ -/* { } */ - -/* Graph::Graph(Graph const *graph) */ -/* : Graph(graph->model) */ -/* { */ -/* for (auto const &kv : graph->inEdges) { */ -/* Node const &node = kv.first; */ -/* std::unordered_set const &edge_set = kv.second; */ - -/* for (auto const &edge : edge_set) { */ -/* this->add_edge(edge.srcOp, edge.dstOp, edge.srcIdx) */ -/* } */ -/* } */ -/* } */ - -Graph *GraphXfer::create_new_graph( - Graph const *graph, SimplificationSettings const &simplification_settings) { - Graph *newGraph = new Graph(model); - // Step 1: map dst ops - std::vector::const_iterator dstIt; - // Step 2: add edges to the graph - for (auto const &opIt : graph->inEdges) { - if (mappedOps.find(opIt.first) == mappedOps.end()) { - // Unmapped ops - auto const &list = opIt.second; - for (auto const &it : list) { - if (mappedOps.find(it.srcOp) != mappedOps.end()) { - // mapped src -> unmapped dst - TensorX srcTen; - srcTen.op = mappedOps[it.srcOp]; - 
srcTen.idx = it.srcIdx; - assert(mappedOutputs.find(srcTen) != mappedOutputs.end()); - TensorX dstTen = mappedOutputs[srcTen]; - newGraph->add_edge(dstTen.op->mapOp, it.dstOp, dstTen.idx, it.dstIdx); - } else { - // unmapped src -> unmmaped dst - newGraph->add_edge(it.srcOp, it.dstOp, it.srcIdx, it.dstIdx); - } - } - } - } - // Step 3: add edges for mapped ops - for (dstIt = dstOps.begin(); dstIt != dstOps.end(); dstIt++) { - OpX *dstOp = *dstIt; - for (size_t i = 0; i < dstOp->inputs.size(); i++) { - if (dstOp->inputs[i].op == NULL) { - // unmapped src -> mapped dst - std::multimap>::const_iterator it = - mappedInputs.find(dstOp->inputs[i].idx); - assert(it != mappedInputs.end()); - std::pair const &srcEdge = it->second; - newGraph->add_edge(srcEdge.first, dstOp->mapOp, srcEdge.second, i); - } else { - // mapped src -> mapped dst - OpX *srcOp = dstOp->inputs[i].op; - int srcIdx = dstOp->inputs[i].idx; - newGraph->add_edge(srcOp->mapOp, dstOp->mapOp, srcIdx, i); - } - } - } - newGraph->simplify(simplification_settings); - - return newGraph; -} - -bool GraphXfer::create_new_operator(OpX const *opx, Node &op) { - ParallelTensor inputs[MAX_NUM_INPUTS]; - for (size_t i = 0; i < opx->inputs.size(); i++) { - tl::optional mapped = opx->inputs[i].to_tensor(this); - if (!mapped.has_value()) { - return false; - } - inputs[i] = mapped.value(); - } - // Check that the total degree of inputs[0] does not exceed available - // resources - if (opx->inputs.size() > 0) { - int degree = 1; - for (int i = 0; i < inputs[0]->num_dims; i++) { - degree *= inputs[0]->dims[i].degree; - } - if (degree > model->config.workersPerNode * model->config.numNodes && - (degree > model->config.cpusPerNode * model->config.numNodes)) { - return false; - } - } - int num_inputs; - if (opx->get_pm_constraint(PM_NUM_INPUTS, num_inputs) && - opx->inputs.size() != num_inputs) { - return false; - } - int num_outputs; - if (opx->get_pm_constraint(PM_NUM_OUTPUTS, num_outputs) && - opx->outputs.size() != 
num_outputs) { - return false; - } - switch (opx->type) { - case OP_NOOP: { - op = model->get_or_create_noop_node(inputs[0]); - break; - } - case OP_CONCAT: { - int axis; - assert(opx->get_pm_constraint(PM_AXIS, axis)); - op = model->get_or_create_node( - {std::begin(inputs), std::end(inputs)}, {axis}); - break; - } - case OP_SPLIT: { - int axis; - assert(opx->get_pm_constraint(PM_AXIS, axis)); - int num_outputs = opx->outputs.size(); - int input_size = inputs[0]->dims[axis].size; - - if (input_size % num_outputs != 0) { - op = Node::INVALID_NODE; - } else { - int split_size = input_size / num_outputs; - std::vector split_sizes(num_outputs, split_size); - assert(split_sizes.size() == num_outputs); - op = model->get_or_create_node(inputs[0], {split_sizes, axis}); - } - break; - } - case OP_EW_ADD: - case OP_EW_SUB: - case OP_EW_MUL: - case OP_EW_MAX: - case OP_EW_MIN: { - op = model->get_or_create_node({inputs[0], inputs[1]}, - {opx->type}); - break; - } - case OP_RELU: { - ElementUnaryParams params; - params.op_type = opx->type; - params.inplace = false; - params.scalar = 0.0f; - op = model->get_or_create_node(inputs[0], params); - break; - } - case OP_CONV2D: { - Conv2D *conv = (Conv2D *)opx->matchOpX->mapOp.ptr; - Conv2DParams params = conv->get_params(); - op = model->get_or_create_node(inputs[0], params); - break; - } - case OP_POOL2D: { - Pool2D *pool = (Pool2D *)opx->matchOpX->mapOp.ptr; - Pool2DParams params = pool->get_params(); - op = model->get_or_create_node(inputs[0], params); - break; - } - case OP_FLAT: { - Flat *flat = (Flat *)opx->matchOpX->mapOp.ptr; - op = model->get_or_create_node(inputs[0], {}); - break; - } - case OP_LINEAR: { - int activation; - assert(opx->matchOpX != NULL); - assert(opx->matchOpX->mapOp.ptr != NULL); - Linear *linear = (Linear *)opx->matchOpX->mapOp.ptr; - // assert(opx->get_pm_constraint(PM_OUTPUT_CHANNELS, output_channels)); - assert(opx->get_pm_constraint(PM_ACTI, activation)); - LinearParams params = 
linear->get_params(); - op = model->get_or_create_node(inputs[0], params); - break; - } - case OP_MULTIHEAD_ATTENTION: { - int num_heads; - assert(opx->matchOpX != NULL); - assert(opx->matchOpX->mapOp.ptr != NULL); - MultiHeadAttention *attn = (MultiHeadAttention *)opx->matchOpX->mapOp.ptr; - assert(opx->get_pm_constraint(PM_NUM_HEADS, num_heads)); - MultiHeadAttentionParams params = attn->get_params(); - op = model->get_or_create_node( - {inputs[0], inputs[1], inputs[2]}, params); - break; - } - case OP_SOFTMAX: { - int softmax_dim; - assert(opx->get_pm_constraint(PM_SOFTMAX_DIM, softmax_dim)); - op = model->get_or_create_node(inputs[0], {softmax_dim}); - break; - } - case OP_REPARTITION: { - int repartition_dim, repartition_degree; - assert(opx->get_pm_constraint(PM_REPARTITION_DIM, repartition_dim)); - assert(opx->get_pm_constraint(PM_REPARTITION_DEGREE, repartition_degree)); - - int degree = inputs[0]->get_total_num_parts() * repartition_degree; - if (degree > model->config.workersPerNode * model->config.numNodes && - (degree > model->config.cpusPerNode * model->config.numNodes)) { - op = Node::INVALID_NODE; - } else { - op = model->get_or_create_node( - inputs[0], {repartition_dim, repartition_degree}); - } - break; - } - case OP_REPLICATE: { - int replicate_dim, replicate_degree; - assert(opx->get_pm_constraint(PM_REPLICATE_DIM, replicate_dim)); - assert(opx->get_pm_constraint(PM_REPLICATE_DEGREE, replicate_degree)); - - if (inputs[0]->dims[replicate_dim].degree * replicate_degree > - model->config.workersPerNode) { - op = Node::INVALID_NODE; - } else { - int degree = inputs[0]->get_total_num_parts() * replicate_degree; - if (degree > model->config.workersPerNode * model->config.numNodes && - (degree > model->config.cpusPerNode * model->config.numNodes)) { - op = Node::INVALID_NODE; - } else { - op = model->get_or_create_node( - inputs[0], {replicate_dim, replicate_degree}); - } - } - break; - } - case OP_REDUCTION: { - int reduction_dim, reduction_degree; - 
assert(opx->get_pm_constraint(PM_REDUCTION_DIM, reduction_dim)); - assert(opx->get_pm_constraint(PM_REDUCTION_DEGREE, reduction_degree)); - op = model->get_or_create_node( - inputs[0], {reduction_dim, reduction_degree}); - break; - } - case OP_COMBINE: { - int combine_dim, combine_degree; - assert(opx->get_pm_constraint(PM_COMBINE_DIM, combine_dim)); - assert(opx->get_pm_constraint(PM_COMBINE_DEGREE, combine_degree)); - op = model->get_or_create_node(inputs[0], - {combine_dim, combine_degree}); - break; - } - default: { - std::cout << "opx->type = " << get_operator_type_name(opx->type) - << std::endl; - assert(false); - } - } - // Check operator validness - if (op == Node::INVALID_NODE) { - return false; - } - // Check tnConstraints - for (size_t i = 0; i < opx->tnConstraints.size(); i++) { - TNConstraint tnc = opx->tnConstraints[i]; - int actValue = 0, expValue = 0; - if (tnc.singlePara) { - assert(op.ptr->get_tensor_parameter(tnc.para1, tnc.dim1, &actValue)); - expValue = tnc.value; - } else { - assert(op.ptr->get_tensor_parameter(tnc.para1, tnc.dim1, &actValue)); - assert(op.ptr->get_tensor_parameter(tnc.para2, tnc.dim2, &expValue)); - } - switch (tnc.comp) { - case COMPARE_EQ: - if (actValue != expValue) { - return false; - } - break; - case COMPARE_NE: - if (actValue == expValue) { - return false; - } - break; - case COMPARE_LT: - if (actValue >= expValue) { - return false; - } - break; - case COMPARE_LE: - if (actValue > expValue) { - return false; - } - break; - case COMPARE_GT: - if (actValue <= expValue) { - return false; - } - break; - case COMPARE_GE: - if (actValue < expValue) { - return false; - } - break; - default: - assert(false); - } - } - return true; -} - -OpX *GraphXfer::create_noop(TensorX const &input) { - OpX *noop = new OpX(OP_NOOP, 1, 1, input); - return noop; -} - -OpX *GraphXfer::create_concat(TensorX const *inputs, - int num_inputs, - OpX const *_matchOpX, - int concat_dim) { - OpX *concat = new OpX(OP_CONCAT, num_inputs, 1 /*outputs*/, 
inputs); - concat->matchOpX = _matchOpX; - concat->add_pm_constraint(COMPARE_EQ, PM_AXIS, concat_dim); - return concat; -} - -OpX *GraphXfer::create_element_unary(TensorX const &input, - OperatorType op_type) { - OpX *eu = new OpX(op_type, 1 /*numInputs*/, 1, input); - return eu; -} - -OpX *GraphXfer::create_relu(TensorX const &input) { - return this->create_element_unary(input, OP_RELU); -} - -OpX *GraphXfer::create_element_binary(TensorX const &input1, - TensorX const &input2, - OperatorType op_type) { - OpX *eb = new OpX(op_type, 2 /*numInputs*/, 1, input1, input2); - return eb; -} - -OpX *GraphXfer::create_linear(TensorX const &input, - OpX const *_matchOpX, - int num_dims, - ActiMode acti_mode, - bool use_bias) { - // TODO FIXME @lockshaw @zhihao use_bias is completely unused - OpX *li = new OpX(OP_LINEAR, 1, 1, input); - li->matchOpX = _matchOpX; - // li->add_pm_constraint(COMPARE_EQ, PM_OUTPUT_CHANNELS, out_channels); - li->add_pm_constraint(COMPARE_EQ, PM_ACTI, acti_mode); - li->add_input_constraint(COMPARE_EQ, INPUT_0, DIM_ND, num_dims); - return li; -} - -OpX *GraphXfer::create_conv2d(TensorX const &input, OpX const *matchOpX) { - OpX *conv = new OpX(OP_CONV2D, 1, 1, input); - conv->matchOpX = matchOpX; - return conv; -} - -OpX *GraphXfer::create_pool2d(TensorX const &input, OpX const *matchOpX) { - OpX *pool = new OpX(OP_POOL2D, 1, 1, input); - pool->matchOpX = matchOpX; - return pool; -} - -OpX *GraphXfer::create_attention(TensorX const &query, - TensorX const &key, - TensorX const &value, - OpX const *_matchOpX, - int num_heads) { - OpX *attn = new OpX(OP_MULTIHEAD_ATTENTION, 3, 1, query, key, value); - attn->matchOpX = _matchOpX; - attn->add_pm_constraint(COMPARE_EQ, PM_NUM_HEADS, num_heads); - attn->add_input_constraint(COMPARE_EQ, INPUT_0, DIM_ND, 4); - attn->add_input_constraint(COMPARE_EQ, INPUT_1, DIM_ND, 4); - attn->add_input_constraint(COMPARE_EQ, INPUT_2, DIM_ND, 4); - return attn; -} - -OpX *GraphXfer::create_softmax(TensorX const &input, int 
softmax_dim) { - OpX *softmax = new OpX(OP_SOFTMAX, 1, 1, input); - softmax->add_pm_constraint(COMPARE_EQ, PM_SOFTMAX_DIM, softmax_dim); - return softmax; -} - -OpX *GraphXfer::create_repartition(TensorX const &input, - int repartition_dim, - int num_parts) { - OpX *part = new OpX(OP_REPARTITION, 1, 1, input); - part->add_pm_constraint(COMPARE_EQ, PM_REPARTITION_DIM, repartition_dim); - part->add_pm_constraint(COMPARE_EQ, PM_REPARTITION_DEGREE, num_parts); - return part; -} - -OpX *GraphXfer::create_replicate(TensorX const &input, - int replicate_dim, - int num_parts) { - OpX *replicate = new OpX(OP_REPLICATE, 1, 1, input); - replicate->add_pm_constraint(COMPARE_EQ, PM_REPLICATE_DIM, replicate_dim); - replicate->add_pm_constraint(COMPARE_EQ, PM_REPLICATE_DEGREE, num_parts); - return replicate; -} - -OpX *GraphXfer::create_reduction(TensorX const &input, - int reduction_dim, - int num_parts) { - OpX *reduction = new OpX(OP_REDUCTION, 1, 1, input); - reduction->add_pm_constraint(COMPARE_EQ, PM_REDUCTION_DIM, reduction_dim); - reduction->add_pm_constraint(COMPARE_EQ, PM_REDUCTION_DEGREE, num_parts); - return reduction; -} - -OpX *GraphXfer::create_combine(TensorX const &input, - int combine_dim, - int num_parts) { - OpX *part = new OpX(OP_COMBINE, 1, 1, input); - part->add_pm_constraint(COMPARE_EQ, PM_COMBINE_DIM, combine_dim); - part->add_pm_constraint(COMPARE_EQ, PM_COMBINE_DEGREE, num_parts); - return part; -} - -void Graph::print_strategy_computation_graph( - std::unordered_map const &strategy) const { - DotFile dot(std::cout); - this->export_strategy_computation_graph(strategy, dot); -} - -void Graph::export_strategy_computation_graph( - std::unordered_map const &strategy, - std::string const &out_filename) const { - DotFile dot(out_filename); - - this->export_strategy_computation_graph(strategy, dot); -} - -void Graph::export_strategy_computation_graph( - std::unordered_map const &strategy, - DotFile &dot) const { - using FlexFlow::PCG::Utils::GraphStructure; - 
- GraphStructure s; - - for (auto const &node : s.get_nodes(*this)) { - // Add node - if (strategy.find(node) == strategy.end()) { - // Check FusedParallel node here and print out the detailed information - if (node.ptr->op_type == OperatorType::OP_FUSED_PARALLEL) { - RecordFormatter rf; - std::vector rows{}; - - FusedParallelOp *fused_op = (FusedParallelOp *)node.ptr; - for (int i = 0; i < fused_op->num_parallel_ops; i++) { - RecordFormatter row{}; - ParallelOpInfo op_info = fused_op->parallel_ops[i]; - std::string op_type_str = get_operator_type_name(op_info.op_type); - row << op_type_str << "dim: " + std::to_string(op_info.parallel_dim) - << "degree: " + std::to_string(op_info.parallel_degree); - rows.emplace_back(row); - } - rf << node.to_string(); - for (auto &r : rows) { - rf << r; - } - dot.add_record_node(node, rf); - } else { - dot.add_node(node, {{"label", node.to_string()}}); - } - } else { - RecordFormatter rf, meta_row, machine_view_row, runtime_code, memory_code, - runtime_cost_row, memory_cost_row; - MachineView mv = strategy.at(node); - std::ostringstream oss; - CostMetrics op_cost = - this->model->simulator->measure_operator_cost(node.ptr, mv); - switch (node.ptr->op_type) { - case OP_REPARTITION: { - Repartition *rp = (Repartition *)node.ptr; - meta_row << std::to_string(rp->repartition_dim) - << std::to_string(rp->repartition_degree); - break; - } - case OP_COMBINE: { - Combine *c = (Combine *)node.ptr; - meta_row << std::to_string(c->combine_dim) - << std::to_string(c->combine_degree); - break; - } - case OP_REPLICATE: { - Replicate *r = (Replicate *)node.ptr; - meta_row << std::to_string(r->replicate_dim) - << std::to_string(r->replicate_degree); - break; - } - case OP_REDUCTION: { - Reduction *r = (Reduction *)node.ptr; - meta_row << std::to_string(r->reduction_dim) - << std::to_string(r->reduction_degree); - break; - } - default: { - if (mv.ndims == 0) { - meta_row << "N/A"; - } else { - for (int i = 0; i < mv.ndims; i++) { - meta_row << 
std::to_string(mv.dim[i]); - } - } - } - } - - // Fetch machine view information - for (int device_id : mv.device_ids()) { - machine_view_row << std::to_string(device_id); - } - rf << node.to_string() << std::to_string(node.guid) << meta_row - << machine_view_row; - - // get memory cost - if (this->model->config.include_costs_dot_graph) { - float input_mem = (float)op_cost.inputs_memory; - if (node.ptr->numInputs > 0) { - input_mem /= (*node.ptr->inputs)->get_total_num_parts(); - } - float output_mem = (float)op_cost.outputs_memory; - if (node.ptr->numOutputs > 0) { - output_mem /= (*node.ptr->outputs)->get_total_num_parts(); - } - float weight_mem = (float)op_cost.weights_memory; - if (node.ptr->numWeights > 0) { - weight_mem /= (*node.ptr->weights)->get_total_num_parts(); - } - - runtime_code << "fwd" - << "bwd" - << "sync" - << "secs"; - runtime_cost_row << op_cost.forward_time << op_cost.backward_time - << op_cost.sync_time; - memory_code << "in" - << "out" - << "weight" - << "bytes"; - memory_cost_row << input_mem << output_mem << weight_mem; - rf << runtime_code << runtime_cost_row << memory_code - << memory_cost_row; - } - - dot.add_record_node(node, rf); - } - - // Add edges - for (auto const &edge : s.get_incoming_edges(*this, node)) { - dot.add_edge(s.get_src(*this, edge), s.get_dst(*this, edge)); - } - } - - dot.close(); -} - -template -void create_mapping_xfers( - FFModel *model, - int degree, - std::vector &xfers, - tl::optional> dims = tl::nullopt) { - std::vector records; - T::construct_output_mappings(records); - std::unordered_map output_mappings; - - std::unordered_set all_dims; - for (ParallelDimMappingRecord const &record : records) { - assert(record.input_idx == 0); - assert(record.get_type() == MappingRecordType::INPUT_OUTPUT); - assert(record.output_idx == 0); - assert(record.operation.has_value()); - - all_dims.insert(record.input_dim); - output_mappings.insert({record.input_dim, record}); - } - - if (dims.has_value()) { - all_dims = 
dims.value(); - } - - for (int const input_dim : all_dims) { - int output_dim = output_mappings.at(input_dim).output_dim; - GraphXfer *subst = new GraphXfer(model); - TensorX input = subst->new_tensor(); - - OpX *original_op = subst->create_opx(input, NULL /*matchOpX*/); - subst->srcOps.push_back(original_op); - - OpX *pre; - std::string pre_name; - switch (output_mappings.at(input_dim).operation.value()) { - case MappingOperation::PARTITION: - pre = subst->create_repartition(input, input_dim, degree); - pre_name = "partition"; - break; - case MappingOperation::REPLICATE: - pre = subst->create_replicate(input, input_dim, degree); - pre_name = "replicate"; - break; - } - subst->dstOps.push_back(pre); - - OpX *new_op = - subst->create_opx(pre->outputs[0], original_op /*matchOpX*/); - subst->dstOps.push_back(new_op); - - OpX *post; - std::string post_name; - switch (output_mappings.at(input_dim).operation.value()) { - case MappingOperation::PARTITION: - post = subst->create_combine(new_op->outputs[0], output_dim, degree); - post_name = "combine"; - break; - case MappingOperation::REPLICATE: - post = subst->create_reduction(new_op->outputs[0], output_dim, degree); - post_name = "reduce"; - break; - } - subst->dstOps.push_back(post); - - subst->map_output(original_op->outputs[0], post->outputs[0]); - - std::ostringstream oss; - std::string op_type_name = get_operator_type_name(new_op->type); - std::transform(op_type_name.begin(), - op_type_name.end(), - op_type_name.begin(), - [](unsigned char c) { return std::tolower(c); }); - oss << "mapping::" << pre_name << "_" << op_type_name << "_" << post_name - << "[" - << "input_dim=" << input_dim << ",degree=" << degree << "]"; - subst->name = oss.str(); - - xfers.push_back(subst); - } -} - -std::string GraphXfer::get_name() const { - if (this->name.has_value()) { - return this->name.value(); - } else { - std::ostringstream oss; - oss << "unknown_xfer(" << this << ")"; - return oss.str(); - } -} - -/* int 
get_num_outputs(sl::Operator const &op) { */ -/* switch (op.op_type) { */ -/* case OP_SPLIT: */ -/* return op.at(PM_NUM_OUTPUTS).value(); */ -/* default: */ -/* return 1; */ -/* } */ -/* } */ - -/* int get_num_inputs(sl::Operator const &op) { */ -/* switch (op.op_type) { */ -/* case OP_EW_ADD: // binary ops */ -/* case OP_EW_SUB: */ -/* case OP_EW_MUL: */ -/* case OP_EW_DIV: */ -/* case OP_EW_EQUAL: */ -/* case OP_EW_GREATER: */ -/* case OP_EW_LESS: */ -/* case OP_EW_MAX: */ -/* case OP_EW_MIN: */ -/* return 2; */ -/* case OP_SPLIT: */ -/* return 1; */ -/* case OP_LINEAR: */ -/* return 1; */ -/* case OP_CONV2D: */ -/* return 1; */ -/* case OP_RELU: */ -/* case OP_IDENTITY: */ -/* case OP_SIGMOID: */ -/* case OP_TANH: */ -/* case OP_ELU: */ -/* return 1; */ -/* case OP_CONCAT: */ -/* return op.at(PM_NUM_INPUTS).value(); */ -/* case OP_INPUT: */ -/* return 0; */ -/* case OP_REPARTITION: */ -/* case OP_COMBINE: */ -/* case OP_REPLICATE: */ -/* case OP_REDUCTION: */ -/* case OP_PIPELINE: */ -/* return 1; */ -/* default: */ -/* throw std::runtime_error("Unknown num_inputs for operator " + */ -/* get_operator_type_name(op.op_type)); */ -/* } */ -/* } */ - -OpX *create_opx(sl::Operator const &op, - int parallel_degree, - TensorX const &input1, - TensorX const &input2, - TensorX const &input3, - TensorX const &input4) { - int num_inputs = get_num_inputs(op); - int num_outputs = get_num_outputs(op); - - OpX *opx = new OpX( - op.op_type, num_inputs, num_outputs, input1, input2, input3, input4); - for (sl::Parameter const &p : op.para) { - if (p.key == PM_PARALLEL_DEGREE) { - tl::optional degree_key = tl::nullopt; - switch (op.op_type) { - case OP_REPARTITION: - degree_key = PM_REPARTITION_DEGREE; - break; - case OP_COMBINE: - degree_key = PM_COMBINE_DEGREE; - break; - case OP_REDUCTION: - degree_key = PM_REDUCTION_DEGREE; - break; - case OP_REPLICATE: - degree_key = PM_REPLICATE_DEGREE; - break; - } - - if (degree_key.has_value()) { - // Assume the generator only consider a 
parallel degree of 2 - assert(p.value == 2); - opx->add_pm_constraint(COMPARE_EQ, degree_key.value(), parallel_degree); - } - } else if (p.key == PM_PARALLEL_DIM) { - tl::optional dim_key = tl::nullopt; - switch (op.op_type) { - case OP_REPARTITION: - dim_key = PM_REPARTITION_DIM; - break; - case OP_COMBINE: - dim_key = PM_COMBINE_DIM; - break; - case OP_REDUCTION: - dim_key = PM_REDUCTION_DIM; - break; - case OP_REPLICATE: - dim_key = PM_REPLICATE_DIM; - break; - } - - if (dim_key.has_value()) { - opx->add_pm_constraint(COMPARE_EQ, dim_key.value(), p.value); - } - } else if (p.key == PM_PAD) { - opx->add_pm_constraint(COMPARE_EQ, PM_PADDING_H, p.value); - opx->add_pm_constraint(COMPARE_EQ, PM_PADDING_W, p.value); - } else { - opx->add_pm_constraint(COMPARE_EQ, p.key, p.value); - } - } - - return opx; -} - -OpX *find_opx_with_type(std::vector const &src_ops, - OperatorType op_type) { - OpX *matchOpX = nullptr; - for (size_t k = 0; k < src_ops.size(); k++) { - if (src_ops[k]->type == op_type) { - assert(matchOpX == nullptr); - matchOpX = src_ops[k]; - } - } - assert(matchOpX != nullptr); - return matchOpX; -} - -std::vector - create_rule_graph(GraphXfer &xfer, - std::vector const &ops, - std::function const &get_input_tensor, - std::vector *const src_ops, - int parallel_degree) { - std::vector rule_graph; - - for (int i = 0; i < ops.size(); i++) { - sl::Operator const &op = ops[i]; - std::array inputs; - std::fill(inputs.begin(), inputs.end(), TensorX::NO_TX); - - for (int j = 0; j < op.input.size(); j++) { - int opId = op.input[j].opId; - int tsId = op.input[j].tsId; - if (opId < 0) { - inputs[j] = get_input_tensor(opId, tsId); - } else { - inputs[j] = rule_graph[opId]->outputs[tsId]; - } - } - - // We need the matched OpX for constructing conv2d/pool2d/linear - OpX *opx = nullptr; - switch (ops[i].op_type) { - case OP_CONV2D: { - OpX *matchOpX = src_ops == nullptr - ? 
nullptr - : find_opx_with_type(*src_ops, ops[i].op_type); - opx = xfer.create_conv2d(inputs[0], matchOpX); - break; - } - case OP_POOL2D: { - OpX *matchOpX = src_ops == nullptr - ? nullptr - : find_opx_with_type(*src_ops, ops[i].op_type); - opx = xfer.create_pool2d(inputs[0], matchOpX); - break; - } - default: - opx = create_opx(ops[i], - parallel_degree, - inputs[0], - inputs[1], - inputs[2], - inputs[3]); - } - rule_graph.push_back(opx); - } - - return rule_graph; -} - -void create_xfer(GraphXfer &xfer, sl::Rule const &r, int parallel_degree) { - std::unordered_map, TensorX> input_tensors; - std::function get_input_tensor = - [&xfer, &input_tensors](int opId, int tsId) -> TensorX { - if (input_tensors.find({opId, tsId}) == input_tensors.end()) { - input_tensors[{opId, tsId}] = xfer.new_tensor(); - } - return input_tensors.at({opId, tsId}); - }; - - xfer.srcOps = create_rule_graph( - xfer, r.srcOp, get_input_tensor, nullptr, parallel_degree); - xfer.dstOps = create_rule_graph( - xfer, r.dstOp, get_input_tensor, &xfer.srcOps, parallel_degree); - xfer.name = r.name; - if (xfer.srcOps.size() == 1) { - printf("Here!\n"); - } - - for (sl::MapOutput const &m : r.mappedOutput) { - TensorX srcTensorX = xfer.srcOps[m.srcOpId]->outputs[m.srcTsId]; - TensorX dstTensorX = xfer.dstOps[m.dstOpId]->outputs[m.dstTsId]; - xfer.map_output(srcTensorX, dstTensorX); - } -} - -bool check_opxes_have_same_type_and_constraints(OpX const &src_opx, - OpX const &dst_opx) { - if (src_opx.type != dst_opx.type) { - return false; - } - if (src_opx.pmConstraints.size() != dst_opx.pmConstraints.size()) { - return false; - } - if (src_opx.tnConstraints.size() != dst_opx.tnConstraints.size()) { - return false; - } - for (auto const &c1 : src_opx.pmConstraints) { - bool found_same = false; - for (auto const &c2 : dst_opx.pmConstraints) { - if (c1.comp == c2.comp && c1.para == c2.para && c1.value == c2.value) { - found_same = true; - } - } - if (!found_same) { - return false; - } - } - for (auto const 
&c1 : src_opx.tnConstraints) { - bool found_same = false; - for (auto const &c2 : dst_opx.tnConstraints) { - if (c1.singlePara && c2.singlePara) { - if (c1.comp == c2.comp && c1.para1 == c2.para1 && c1.dim1 == c2.dim1 && - c1.value == c2.value) { - found_same = true; - } - } else if ((!c1.singlePara) && (!c2.singlePara)) { - if (c1.comp == c2.comp && c1.para1 == c2.para1 && - c1.para2 == c2.para2 && c1.dim1 == c2.dim1 && c1.dim2 == c2.dim2) { - found_same = true; - } - } - } - if (!found_same) { - return false; - } - } - - return true; -} - -std::vector create_xfers(FFModel *model, - sl::RuleCollection const &rules, - int parallel_degree) { - std::vector xfers; - for (sl::Rule const &r : rules.rules) { - GraphXfer *xfer = new GraphXfer(model); - create_xfer(*xfer, r, parallel_degree); - if (xfer->srcOps.size() == 1 && xfer->dstOps.size() == 1) { - delete xfer; - continue; - } - // Pruning redundant xfer - bool found_same_xfer = false; - for (auto const &old_xfer : xfers) { - bool same = true; - if (old_xfer->srcOps.size() != xfer->srcOps.size()) { - same = false; - continue; - } - for (size_t i = 0; i < old_xfer->srcOps.size(); i++) { - if (!check_opxes_have_same_type_and_constraints(*old_xfer->srcOps[i], - *xfer->srcOps[i])) { - same = false; - } - } - if (!same) { - continue; - } - if (old_xfer->dstOps.size() != xfer->dstOps.size()) { - same = false; - continue; - } - for (size_t i = 0; i < old_xfer->dstOps.size(); i++) { - if (!check_opxes_have_same_type_and_constraints(*old_xfer->dstOps[i], - *xfer->dstOps[i])) { - same = false; - } - } - if (same) { - found_same_xfer = true; - break; - } - } - if (!found_same_xfer && xfer->srcOps.size() == 1) { - xfers.push_back(xfer); - } else { - delete (xfer); - } - } - return xfers; -} - -GraphSearchHelper::GraphSearchHelper(FFModel *model) - : model(model), config(model->config), mem_config(1.0) { - this->logger = std::unique_ptr(new RecursiveLogger("gs")); - generate_all_pcg_xfers(); -} - -void 
GraphSearchHelper::clear_cache() { - cached_optimized_graphs.clear(); -} - -void GraphSearchHelper::load_graph_substitutions( - std::vector &xfers) const { - xfers = all_pcg_xfers; -} - -void GraphSearchHelper::generate_all_pcg_xfers() { - std::vector all_parallel_degrees, single_node_parallel_degrees; - auto const &config = this->model->config; - int workersPerNode = - config.search_num_workers.value_or(config.workersPerNode); - int numNodes = config.search_num_nodes.value_or(config.numNodes); - log_xfers.debug() << "Generating parallel degrees for workersPerNode " - << workersPerNode << " and numNodes " << numNodes; - for (int i = 2; i <= workersPerNode; i++) { - if (workersPerNode % i == 0) { - single_node_parallel_degrees.push_back(i); - all_parallel_degrees.push_back(i); - } - } - for (int i = 2; i <= numNodes; i++) { - if (numNodes % i == 0) { - all_parallel_degrees.push_back(i * workersPerNode); - } - } - { - std::ostringstream oss; - oss << "Generating all_pcg_xfers for all parallel degrees: "; - for (int parallel_degree : all_parallel_degrees) { - oss << parallel_degree << " "; - } - - log_xfers.debug() << oss.str(); - } - - for (auto const &it : single_node_parallel_degrees) { - all_pcg_xfers.push_back(create_replicate_linear_combine( - this->model, 3, it, AC_MODE_RELU, false)); - all_pcg_xfers.push_back(create_replicate_linear_combine( - this->model, 3, it, AC_MODE_SIGMOID, false)); - all_pcg_xfers.push_back(create_replicate_linear_combine( - this->model, 3, it, AC_MODE_NONE, false)); - if (16 % it == 0) { - all_pcg_xfers.push_back( - create_replicate_attention_reduce(this->model, 16 /*num_heads*/, it)); - } - } - for (auto const &it : all_parallel_degrees) { - all_pcg_xfers.push_back( - create_partition_attention_combine(this->model, 16 /*num_heads*/, it)); - } - - if (config.substitution_json_path.has_value()) { - // Currently only consider a subset of all_parallel_degrees - std::vector considered_parallel_degrees; - 
considered_parallel_degrees.push_back(workersPerNode); - if (numNodes > 1) { - considered_parallel_degrees.push_back(numNodes * workersPerNode); - } - sl::RuleCollection rule_collection = sl::load_rule_collection_from_path( - config.substitution_json_path.value()); - for (int degree : considered_parallel_degrees) { - std::vector xfers = - create_xfers(this->model, rule_collection, degree); - all_pcg_xfers.insert(all_pcg_xfers.end(), xfers.begin(), xfers.end()); - } - } else { - // Manual substitutions - for (int num_dims = 3; num_dims <= 4; num_dims++) { - all_pcg_xfers.push_back( - create_linear_relu_merge(this->model, num_dims, true)); - all_pcg_xfers.push_back( - create_linear_relu_merge(this->model, num_dims, false)); - } - for (int const degree : all_parallel_degrees) { - create_mapping_xfers(this->model, degree, all_pcg_xfers); - create_mapping_xfers(this->model, degree, all_pcg_xfers); - create_mapping_xfers(this->model, degree, all_pcg_xfers); - } - for (auto const &it : all_parallel_degrees) { - // rewrites for the inception model - for (int i = 3; i <= 6; i++) { - all_pcg_xfers.push_back(create_combine_inception( - this->model, i - 1 /*num_convs*/, 5 /*num_dims*/, it)); - all_pcg_xfers.push_back(create_combine_concat( - this->model, i /*num_inputs*/, 5 /*num_dims*/, it)); - } - // all_pcg_xfers.push_back(create_partition_conv2d_combine(this->model, - // 5/*num_dims*/, it)); - all_pcg_xfers.push_back(create_partition_linear_combine( - this->model, 3 /*num_dims*/, it, AC_MODE_RELU, false)); - all_pcg_xfers.push_back(create_partition_linear_combine( - this->model, 3 /*num_dims*/, it, AC_MODE_SIGMOID, false)); - all_pcg_xfers.push_back(create_partition_linear_combine( - this->model, 3 /*num_dims*/, it, AC_MODE_NONE, false)); - all_pcg_xfers.push_back(create_partition_linear_combine( - this->model, 4 /*num_dims*/, it, AC_MODE_RELU, false)); - all_pcg_xfers.push_back(create_partition_linear_combine( - this->model, 4 /*num_dims*/, it, AC_MODE_SIGMOID, false)); - 
all_pcg_xfers.push_back(create_partition_linear_combine( - this->model, 4 /*num_dims*/, it, AC_MODE_NONE, false)); - all_pcg_xfers.push_back(create_partition_add_combine( - this->model, 1 /*parallel_dims*/, it /*num_parts*/)); - all_pcg_xfers.push_back(create_partition_add_combine( - this->model, 2 /*parallel_dims*/, it /*num_parts*/)); - all_pcg_xfers.push_back(create_partition_add_combine( - this->model, 3 /*parallel_dims*/, it /*num_parts*/)); - all_pcg_xfers.push_back(create_partition_add_combine( - this->model, 4 /*parallel_dims*/, it /*num_parts*/)); - all_pcg_xfers.push_back(create_partition_relu_combine( - this->model, 3 /*parallel_dims*/, it /*num_parts*/)); - all_pcg_xfers.push_back(create_partition_relu_combine( - this->model, 4 /*parallel_dims*/, it /*num_parts*/)); - all_pcg_xfers.push_back( - create_partition_softmax_combine(this->model, - 0 /*softmax_dim*/, - 1 /*parallel_dims*/, - it /*num_parts*/)); - for (int num_combines = 1; num_combines < 5; num_combines++) { - all_pcg_xfers.push_back(leading_relu_branch_combine( - this->model, 3 /*parallel_dim*/, it /*num_parts*/, num_combines)); - all_pcg_xfers.push_back(leading_relu_branch_partition( - this->model, 3 /*parallel_dim*/, it /*num_parts*/, num_combines)); - } - { - std::unordered_set concat_num_inputs; - for (size_t i = 0; i < this->model->operators.size(); i++) { - if (this->model->operators[i]->op_type == OP_CONCAT) { - concat_num_inputs.insert(this->model->operators[i]->numInputs); - } - } - for (auto const &it2 : concat_num_inputs) { - all_pcg_xfers.push_back( - create_partition_concat_combine(this->model, - it2 /*num_inputs*/, - 0 /*concat_dim*/, - 1 /*parallel_dims*/, - it /*num_parts*/)); - all_pcg_xfers.push_back( - create_partition_concat_combine(this->model, - it2 /*num_inputs*/, - 2 /*concat_dim*/, - 3 /*parallel_dims*/, - it /*num_parts*/)); - } - } - } - } -} - -Graph *GraphSearchHelper::construct_graph() { - Graph *graph = new Graph(this->model); - std::unordered_map 
op_to_node_map; - for (FlexFlow::Op const *dstOp : this->model->operators) { - Node dstNode; - dstNode.ptr = dstOp; - dstNode.guid = this->model->node_global_guid++; - op_to_node_map[dstOp] = dstNode; - for (int j = 0; j < dstOp->numInputs; j++) { - FlexFlow::Op const *srcOp = dstOp->inputs[j]->owner_op; - assert(op_to_node_map.find(srcOp) != op_to_node_map.end()); - Node srcNode = op_to_node_map[srcOp]; - graph->add_edge(srcNode, dstNode, dstOp->inputs[j]->owner_idx, j); - } - } - - return graph; -} - -/** - * @brief Unity search algorithm main entrance. - * - * @param[in] budget Not used - * @param[in] only_data_parallel Not used - * @param[out] best_graph The best possible PCG after optimization - * @param[out] optimal_views The corresponding device placement views of the - * best graph - */ -void GraphSearchHelper::graph_optimize( - size_t budget, - bool only_data_parallel, - std::unique_ptr &best_graph, - std::unordered_map &optimal_views) { - // Construct graph structure - this->logger->debug() << "Starting graph optimization"; - - Graph *graph = this->construct_graph(); - graph->duplicate_input_nodes(); - std::unordered_map empty_strategy; - if (!this->config.export_strategy_computation_graph_file.empty()) { - graph->export_strategy_computation_graph( - empty_strategy, this->config.export_strategy_computation_graph_file); - } - - Node sink_node = graph->find_sink_node(); - GraphOptimizeResult optimal = - this->generic_sequence_optimize( - graph, - sink_node, - tl::nullopt /*output_shape*/, - tl::nullopt /*input_shape*/); - this->logger->debug() << "Total cache size: " - << this->cached_optimized_graphs.size(); - std::cout << "Optimal cost: " << optimal.cost << std::endl; - SimplificationSettings settings; - settings.fuse_parallel_ops = true; - settings.remove_noops = true; - settings.remove_trailing_parallel_ops = true; - settings.simplify_parallel_ops = true; - best_graph = std::unique_ptr(new Graph(optimal.graph.value())); - best_graph->simplify(settings); 
- std::unordered_map duplicated_optimal_views = - best_graph->optimal_views(); - std::unordered_map deduplication_map = - best_graph->deduplicate_input_nodes(); - std::unordered_map real_optimal_views; - for (auto const &kv : duplicated_optimal_views) { - if (deduplication_map.find(kv.first) != deduplication_map.end()) { - real_optimal_views[deduplication_map.at(kv.first)] = kv.second; - } else { - real_optimal_views[kv.first] = kv.second; - } - } - best_graph->print_strategy_computation_graph(optimal.views); - optimal_views = real_optimal_views; -} - -/** - * @brief Experimental DP algorithm to optimize PCG with the consideration of - * memory usage. This is to avoid polluting the current Unity search algorithm - * above. And this should be merged to GraphSearchHelper::graph_optimize - * eventually. - * - * @param[in] budget Not used - * @param[in] only_data_parallel Not used - * @param[out] best_graph The best possible PCG after optimization - * @param[out] optimal_views The corresponding device placement views of the - * best graph - * @param[out] search_result The performance result of the search - */ -void GraphSearchHelper::graph_optimize_with_memory( - size_t budget, - bool only_data_parallel, - std::unique_ptr &best_graph, - std::unordered_map &optimal_views, - MemorySearchResult &search_result) { - this->logger->debug() - << "Starting graph optimization with memory consideration"; - - // Construct graph structure - Graph *graph = this->construct_graph(); - - // The input nodes may need to be duplicated because the PCG was constructed - // to have one input node for one input, but the actual execution graph should - // have the distributed version of inputs (i.e. multiple nodes). - graph->duplicate_input_nodes(); - - // Export an empty schedule if needed. 
- std::unordered_map empty_strategy; - if (!this->config.export_strategy_computation_graph_file.empty()) { - graph->export_strategy_computation_graph( - empty_strategy, this->config.export_strategy_computation_graph_file); - } - - Node sink_node = graph->find_sink_node(); - - auto const start = std::chrono::system_clock::now(); - GraphOptimizeResultWithMemory optimal = - this->generic_sequence_optimize_with_memory< - GraphOptimizeResultWithMemory>( - graph, sink_node, tl::nullopt, tl::nullopt); - auto const end = std::chrono::system_clock::now(); - - this->logger->debug() << "Total cache size: " - << this->cached_optimized_graphs.size(); - std::cout << "Optimal run time cost: " << optimal.cost - << ", Memory usage: " << optimal.mem_cost - << " | run_time_cost_factor: " - << this->mem_config.run_time_cost_factor << std::endl; - - // Save the search performance results to the output argument - search_result.run_time_cost = optimal.cost; - search_result.memory_cost = optimal.mem_cost.num; - search_result.search_time = - std::chrono::duration_cast(end - start) - .count(); - - // Further simplify the "optimal" graph/schedule to have a more efficient - // graph and more accurate cost. - best_graph = std::unique_ptr(new Graph(optimal.graph.value())); - SimplificationSettings settings; - // Simplify to consider parallel op fusion - settings.fuse_parallel_ops = true; - settings.remove_noops = true; - settings.remove_trailing_parallel_ops = true; - settings.simplify_parallel_ops = true; - best_graph->simplify(settings); - - // Get the real optimal machine views. 
- std::unordered_map duplicated_optimal_views = - best_graph->optimal_views(); - std::unordered_map deduplication_map = - best_graph->deduplicate_input_nodes(); - std::unordered_map real_optimal_views; - for (auto const &kv : duplicated_optimal_views) { - if (deduplication_map.find(kv.first) != deduplication_map.end()) { - real_optimal_views[deduplication_map.at(kv.first)] = kv.second; - } else { - real_optimal_views[kv.first] = kv.second; - } - } - std::cout << "Dot graph of searched strategy:" << std::endl; - best_graph->print_strategy_computation_graph(optimal.views); - std::cout << std::endl; - - optimal_views = real_optimal_views; -} - -void GraphSearchHelper::graph_optimize_no_split( - size_t budget, - bool only_data_parallel, - std::unique_ptr &best_graph, - std::unordered_map &optimal_views) { - // Construct graph structure - this->logger->debug() << "Starting graph optimization without split"; - - Graph *graph = this->construct_graph(); - std::unordered_map empty_strategy; - if (!this->config.export_strategy_computation_graph_file.empty()) { - graph->export_strategy_computation_graph( - empty_strategy, this->config.export_strategy_computation_graph_file); - } - - SimplificationSettings settings; - settings.simplify_parallel_ops = true; - best_graph = this->base_optimize(graph, settings); - optimal_views = best_graph->optimal_views(); - - this->logger->debug() << "Total cache size: " - << this->cached_optimized_graphs.size(); - std::cout << "Optimal cost: " << best_graph->optimal_cost() << std::endl; -} - -static void graph_log_representation(Graph const *graph, - RecursiveLogger &logger) { - using FlexFlow::PCG::Utils::topo_sort; - - std::vector topo_sorted; - topo_sort(*graph, &topo_sorted); - std::ostringstream oss; - for (Node const &n : topo_sorted) { - logger.spew() << n.to_string(); - } -} - -void GraphSearchHelper::update_mem_optim_config( - MemoryOptimConfig const &new_config) { - mem_config = new_config; -} - -void 
GraphSearchHelper::find_rewrite_matches( - Graph const *graph, std::vector &matches) const { - std::vector xfers; - this->load_graph_substitutions(xfers); - - for (GraphXfer *xfer : xfers) { - log_xfer_matches.debug() - << "Finding matches for xfer: " << xfer->get_name(); - xfer->find_matches(graph, matches); - } - log_xfer_matches.debug() << "Finished finding xfer matches"; -} - -tl::optional - GraphSearchHelper::find_split_node(Graph const *graph, - int base_optimize_threshold) const { - using FlexFlow::PCG::Utils::get_edges; - using FlexFlow::PCG::Utils::MultisourceGraphStructure; - using FlexFlow::PCG::Utils::nodes; - using FlexFlow::PCG::Utils::post_dominators; - using FlexFlow::PCG::Utils::roots; - - TAG_ENTER(this->logger); - - int graph_size = nodes(*graph).size(); - this->logger->debug() << "Finding split node for graph (size " << graph_size - << ") with threshold " << base_optimize_threshold; - - if (graph_size <= base_optimize_threshold) { - this->logger->debug() - << "Graph size underneath threshold. 
Returning nullopt"; - return tl::nullopt; - } - - std::vector edges = get_edges(*graph); - std::unordered_map edge_scores; - - for (Edge const &e : edges) { - edge_scores[e] = 0; - } - - std::vector matches; - this->find_rewrite_matches(graph, matches); - this->logger->debug() << "Found " << matches.size() << " rewrite matches"; - { - TAG_ENTER(this->logger); - for (GraphXferMatch const &match : matches) { - auto msg = this->logger->spew(); - msg << match.get_xfer()->get_name() << " : "; - std::unordered_set nodes = match.get_nodes(); - for (Node const &node : nodes) { - msg << node.to_string() << " "; - } - } - } - - for (GraphXferMatch const &match : matches) { - for (Edge const &e : edges) { - if (match.containsEdge(graph, e)) { - edge_scores[e]++; - } - } - } - - this->logger->debug() << "Edge weights: "; - - { - TAG_ENTER(this->logger); - for (Edge const &e : edges) { - this->logger->debug() << e.srcOp.to_string() << "/" << e.srcIdx << " -> " - << e.dstOp.to_string() << "/" << e.dstIdx << " : " - << edge_scores.at(e); - } - } - - std::unordered_map> post_dominator_map = - post_dominators>(*graph); - Node source_node; - { - std::unordered_set source_nodes = roots(*graph); - if (source_nodes.size() != 1) { - source_nodes = roots>(*graph); - } - assert(source_nodes.size() == 1); - source_node = *source_nodes.begin(); - } - std::unordered_set possible_bottlenecks = - post_dominator_map.at(source_node); - Node sink_node = graph->find_sink_node(); - - int best_weight = 0; - tl::optional best = tl::nullopt; - int best_size = graph_size; - { - TAG_ENTER(this->logger); - - for (Node const &possible_bottleneck : possible_bottlenecks) { - if (possible_bottleneck == sink_node || - possible_bottleneck == source_node) { - continue; - } - - int weight = 0; - for (Edge const &e : graph->outEdges.at(possible_bottleneck)) { - weight += edge_scores.at(e); - } - this->logger->debug() - << "Potential bottleneck node " << possible_bottleneck.to_string() - << " has weight " << 
weight; - if (weight < best_weight) { - best_weight = weight; - best = possible_bottleneck; - } else if (weight == best_weight) { - // break ties by trying to choosing the split that produces the - // pre_graph with size closest to the threshold, favoring everything - // with smaller size over everything with larger size - std::unique_ptr pre_graph, post_graph; - std::tie(pre_graph, post_graph) = - graph->split_at_node(possible_bottleneck); - int current_size = nodes(*pre_graph).size(); - - bool best_is_under = best_size <= base_optimize_threshold; - bool current_is_under = current_size <= base_optimize_threshold; - - bool condition1 = current_is_under && !best_is_under; - bool condition2 = - current_is_under && best_is_under && current_size > best_size; - bool condition3 = - !current_is_under && !best_is_under && current_size < best_size; - - if (condition1 || condition2 || condition3) { - best_weight = weight; - best = possible_bottleneck; - best_size = current_size; - } - } - } - } - - return best; -} - -/** - * @brief Base case of Unity's DP search algorithm. 
- * - * @param r_graph Graph to be optimized - * @param simplification_settings Settings to simplify the PCG - * @return std::unique_ptr Optimized PCG - */ -std::unique_ptr GraphSearchHelper::base_optimize( - Graph const *r_graph, - SimplificationSettings const &simplification_settings) { - // Construct graph substitutions - TAG_ENTER(this->logger); - - this->logger->debug() << "Optimizing base graph: "; - { - TAG_ENTER(this->logger); - /* graph_log_representation(r_graph, *this->logger); */ - // r_graph->print_dot(); - } - this->logger->debug() << "Starting cost: " << r_graph->optimal_cost(); - - std::vector xfers; - this->load_graph_substitutions(xfers); - - Graph *graph = new Graph(*r_graph); - - std::priority_queue, GraphCompare> candidates; - std::unordered_set hashmap; - candidates.push(graph); - hashmap.insert(graph->hash()); - Graph *best_graph = new Graph(*graph); - float best_cost = best_graph->optimal_cost(); - int counter = 0; - float const alpha = this->model->config.search_alpha; - - int budget = model->config.search_budget; - if (budget == 0) { - log_xfers.warning() - << "Base search budget is set to 0. 
This is probably not what you want " - "(use the --budget flag to set the base search budget)"; - } - for (int iter = 0; iter < budget || budget == -1; iter++) { - log_xfers.spew() << "Considering " << candidates.size() << " candidates"; - if (candidates.empty()) { - break; - } - - Graph *cur_graph = candidates.top(); - candidates.pop(); - if (cur_graph->optimal_cost() < best_graph->optimal_cost()) { - delete best_graph; - best_graph = cur_graph; - best_cost = cur_graph->optimal_cost(); - } else if (cur_graph->optimal_cost() > best_cost * alpha) { - continue; - } - - log_xfers.info("[%d] cur_cost(%.4lf) best_cost(%.4lf) candidates.size(%zu)", - counter, - cur_graph->optimal_cost(), - best_cost, - candidates.size()); - - log_xfers.debug() << "Considering " << xfers.size() << " possible xfers"; - for (size_t i = 0; i < xfers.size(); i++) { - int num_matches_found = 0, num_matches_rejected = 0; - log_xfers.debug() << "Considering xfer: " << xfers[i]->get_name(); - xfers[i]->run(0, - cur_graph, - candidates, - hashmap, - best_cost * alpha, - 1000, - simplification_settings, - num_matches_found, - num_matches_rejected); - log_xfers.debug() << "Rejected [ " << num_matches_rejected << " / " - << num_matches_found << " ] matches"; - /* std::cout << "." << std::flush; */ - } - /* std::cout << std::endl; */ - if (best_graph != cur_graph) { - delete cur_graph; - } - } - - this->logger->debug() << "Optimized cost: " << best_graph->optimal_cost(); - // best_graph->print_dot(); - return std::unique_ptr(best_graph); -} - -/** - * @brief Experimental. Base case of Unity's DP search algorithm with - * memory consideration. 
- * - * @param r_graph Graph to be optimized - * @param simplification_settings Settings to simplify the resulting PCG - * @return std::unique_ptr Optimized PCG - */ -std::unique_ptr GraphSearchHelper::base_optimize_with_memory( - Graph const *r_graph, - SimplificationSettings const &simplification_settings) { - TAG_ENTER(this->logger); - this->logger->debug() << "Optimizing base graph with memory: "; - { - TAG_ENTER(this->logger); - /* graph_log_representation(r_graph, *this->logger); */ - // r_graph->print_dot(); - } - this->logger->debug() << "Starting cost: " - << r_graph->optimal_cost_with_memory( - mem_config.run_time_cost_factor); - - // Construct graph substitutions - std::vector xfers; - this->load_graph_substitutions(xfers); - - // Prepare for the search - std::priority_queue, GraphCompareWithMemory> - candidates(GraphCompareWithMemory{mem_config.run_time_cost_factor}); - std::unordered_set hashmap; - - Graph *graph = new Graph(*r_graph); - candidates.push(graph); - hashmap.insert(graph->hash()); - - Graph *best_graph = new Graph(*graph); - float best_cost = - best_graph->optimal_cost_with_memory(mem_config.run_time_cost_factor); - - int counter = 0; - float const alpha = this->model->config.search_alpha; - int budget = model->config.search_budget; - if (budget == 0) { - log_xfers.warning() - << "Base search budget is set to 0. 
This is probably not what you want " - "(use the --budget flag to set the base search budget)"; - } - - // Actual exploration - for (int iter = 0; iter < budget || budget == -1; iter++) { - log_xfers.spew() << "Considering " << candidates.size() - << " candidates in base_optimize_with_memory"; - if (candidates.empty()) { - break; - } - - Graph *cur_graph = candidates.top(); - candidates.pop(); - if (cur_graph->optimal_cost_with_memory(mem_config.run_time_cost_factor) < - best_graph->optimal_cost_with_memory(mem_config.run_time_cost_factor)) { - delete best_graph; - best_graph = cur_graph; - best_cost = - cur_graph->optimal_cost_with_memory(mem_config.run_time_cost_factor); - } else if (cur_graph->optimal_cost_with_memory( - mem_config.run_time_cost_factor) > best_cost * alpha) { - continue; - } - - log_xfers.info( - "[%d] cur_cost(%.4lf) best_cost(%.4lf) candidates.size(%zu)", - counter, - cur_graph->optimal_cost_with_memory(mem_config.run_time_cost_factor), - best_cost, - candidates.size()); - - log_xfers.debug() << "Considering " << xfers.size() - << " possible xfers in base_optimize_with_memory"; - for (size_t i = 0; i < xfers.size(); i++) { - int num_matches_found = 0, num_matches_rejected = 0; - log_xfers.debug() << "Considering xfer: " << xfers[i]->get_name(); - xfers[i]->run(0, - cur_graph, - candidates, - hashmap, - best_cost * alpha, - 1000, - simplification_settings, - num_matches_found, - num_matches_rejected); - log_xfers.debug() << "Rejected [ " << num_matches_rejected << " / " - << num_matches_found << " ] matches"; - } - - if (best_graph != cur_graph) { - delete cur_graph; - } - } - - this->logger->debug() - << "Optimized cost at the end of base_optimize_with_memory: " - << best_graph->optimal_cost_with_memory(mem_config.run_time_cost_factor); - - return std::unique_ptr(best_graph); -} - -size_t gs_dp_state_hash(Graph const *graph, - Node const &sink_node, - tl::optional const &output_shape, - tl::optional const &input_shape) { - size_t key = 
graph->hash(); - hash_combine(key, sink_node.ptr); - hash_combine(key, output_shape); - hash_combine(key, input_shape); - return key; -} - -float GraphSearchHelper::sequence_optimize( - Graph const *graph, - Node const &sink_node, - tl::optional const &output_shape, - tl::optional const &input_shape) { - return this->generic_sequence_optimize( - graph, sink_node, output_shape, input_shape); -} - -template <> -tl::optional - GraphSearchHelper::try_get_cost_from_cache(size_t hash) const { - if (this->cached_optimized_graphs.find(hash) == - this->cached_optimized_graphs.end()) { - return tl::nullopt; - } else { - return this->cached_optimized_graphs.at(hash); - } -} - -template <> -float GraphSearchHelper::get_optimal_cost( - std::unique_ptr optimized) const { - return optimized->generic_optimal_cost(); -} - -template <> -GraphCostResult GraphSearchHelper::get_optimal_cost( - std::unique_ptr optimized) const { - return optimized->generic_optimal_cost(); -} - -template <> -GraphOptimizeResult GraphSearchHelper::get_optimal_cost( - std::unique_ptr optimized) const { - GraphOptimizeResult result; - result.graph = *optimized; - GraphCostResult gcr = optimized->generic_optimal_cost(); - result.cost = gcr.cost; - result.views = gcr.views; - return result; -} - -template <> -GraphOptimizeResultWithMemory - GraphSearchHelper::get_optimal_cost( - std::unique_ptr optimized) const { - GraphOptimizeResultWithMemory result; - result.graph = *optimized; - GraphCostResultWithMemory gcr = - optimized->generic_optimal_cost(); - result.cost = gcr.cost; - result.views = gcr.views; - result.mem_cost = gcr.mem_cost; - return result; -} - -template <> -tl::optional - GraphSearchHelper::try_get_cost_from_cache( - size_t hash) const { - return tl::nullopt; -} - -template <> -tl::optional - GraphSearchHelper::try_get_cost_from_cache( - size_t hash) const { - return tl::nullopt; -} - -template <> -tl::optional - GraphSearchHelper::try_get_cost_from_cache( - size_t hash) const { - return 
tl::nullopt; -} - -template <> -void GraphSearchHelper::try_cache_result(size_t hash, - float const &value) { - this->cached_optimized_graphs[hash] = value; -} - -template <> -void GraphSearchHelper::try_cache_result( - size_t hash, GraphCostResult const &value) {} - -template <> -void GraphSearchHelper::try_cache_result( - size_t hash, GraphOptimizeResult const &value) {} - -template <> -void GraphSearchHelper::try_cache_result( - size_t hash, GraphOptimizeResultWithMemory const &value) {} - -/** - * @brief Get the cost/result of PCG if sequentially split it. - * - * @details This function is to combine the search results from DP sub-problems. - * The sub-problems are solved by generic_sequence_optimize(). - */ -template -T GraphSearchHelper::execute_sequence_split( - std::unique_ptr const &pre_graph, - std::unique_ptr const &post_graph, - tl::optional const &output_shape, - tl::optional const &input_shape, - Node const &sink_node, - Node const &bottleneck, - ParallelTensorShape const &bottleneck_output_shape) { - return sequence_cost( - this->generic_sequence_optimize( - pre_graph.get(), bottleneck, bottleneck_output_shape, input_shape), - this->generic_sequence_optimize( - post_graph.get(), sink_node, output_shape, bottleneck_output_shape)); -} - -/** - * @brief Experimental. Consider memory usage when spliting the PCG during the - * DP search. This should be merged with execute_sequence_split(). 
- */ -template -T GraphSearchHelper::execute_sequence_split_with_memory( - std::unique_ptr const &pre_graph, - std::unique_ptr const &post_graph, - tl::optional const &output_shape, - tl::optional const &input_shape, - Node const &sink_node, - Node const &bottleneck, - ParallelTensorShape const &bottleneck_output_shape) { - return sequence_cost( - this->generic_sequence_optimize_with_memory( - pre_graph.get(), bottleneck, bottleneck_output_shape, input_shape), - this->generic_sequence_optimize_with_memory( - post_graph.get(), sink_node, output_shape, bottleneck_output_shape)); -} - -/** - * @brief Top level DP search procedure for Unity. - */ -template -T GraphSearchHelper::generic_sequence_optimize( - Graph const *graph, - Node const &sink_node, - tl::optional const &output_shape, - tl::optional const &input_shape) { - /* int starting_depth = this->logger->get_depth(); */ - - TAG_ENTER(this->logger); - - size_t hash = gs_dp_state_hash(graph, sink_node, output_shape, input_shape); - tl::optional cached = this->try_get_cost_from_cache(hash); - if (cached.has_value()) { - this->logger->spew() << "Optimizing graph with " << graph->inEdges.size() - << " nodes"; - { - TAG_ENTER(this->logger); - this->logger->spew() << "Nodes: "; - { - TAG_ENTER(this->logger); - graph_log_representation(graph, *this->logger); - } - this->logger->spew() << "Retrieved value from cache: " << cached.value(); - } - - /* this->logger->check_same_as(starting_depth); */ - return cached.value(); - } - - this->logger->debug() << "Optimizing graph with " << graph->inEdges.size() - << " nodes"; - T return_value; - { - TAG_ENTER(this->logger); - this->logger->spew() << "Nodes: "; - { - TAG_ENTER(this->logger); - graph_log_representation(graph, *this->logger); - } - this->logger->debug() << "Graph hash: " << std::setw(32) - << std::setfill('0') << graph->hash(); - if (input_shape.has_value()) { - this->logger->debug() << "Input shape: " << input_shape.value(); - } else { - this->logger->debug() << 
"Input shape: "; - } - if (output_shape.has_value()) { - this->logger->debug() << "Output shape: " << output_shape.value(); - } else { - this->logger->debug() << "Output shape: "; - } - - tl::optional bottleneck = - this->find_split_node(graph, this->config.base_optimize_threshold); - - if (!bottleneck.has_value()) { - this->logger->debug() << "Applying base case"; - Graph to_optimize(*graph); - if (input_shape.has_value()) { - Node input_node = - this->model->get_or_create_input_node(input_shape.value()); - Node noop_node = - this->model->get_or_create_noop_node(input_node.ptr->outputs[0]); - Graph input_graph(this->model); - Edge e(input_node, noop_node, 0, 0); - input_graph.add_edge(e); - - Node old_source_node = graph->find_source_node(); - ParallelTensorShape old_source_output_shape = - old_source_node.ptr->outputs[0]->get_shape(); - input_graph.reshape_output_tensor(old_source_output_shape); - - Node new_sink_node = input_graph.find_sink_node(); - assert(new_sink_node.ptr->numOutputs == 1); - assert(new_sink_node.ptr->outputs[0]->get_shape() == - old_source_output_shape); - - to_optimize.replace_subgraph({old_source_node}, input_graph); - } - SimplificationSettings settings; - if (output_shape.has_value()) { - to_optimize.reshape_output_tensor(output_shape.value()); - Node sink_node = to_optimize.find_sink_node(); - Node noop_node = - this->model->get_or_create_noop_node(sink_node.ptr->outputs[0]); - to_optimize.add_edge(sink_node, noop_node, 0, 0); - } else { - settings.remove_trailing_parallel_ops = true; - } - settings.simplify_parallel_ops = true; - std::unique_ptr optimized = - this->base_optimize(&to_optimize, settings); - return_value = get_optimal_cost( - std::move(optimized)); // optimized->generic_optimal_cost(); - } else { - this->logger->debug() << "Applying recursive case on bottleneck " - << bottleneck.value().guid; - std::unique_ptr pre_graph, post_graph; - std::tie(pre_graph, post_graph) = - graph->split_at_node(bottleneck.value()); - - 
MachineResource resources(this->model->config); - std::vector valid_machine_views = - this->model->search->get_valid_machine_views(bottleneck.value().ptr, - resources); - - float best_cost = std::numeric_limits::infinity(); - tl::optional best_shape = tl::nullopt; - { - TAG_ENTER(this->logger); - for (ParallelTensorShape const &bottleneck_output_shape : - this->possible_split_output_tensor_shapes(bottleneck.value())) { - this->logger->debug() - << "Considering boundary shape " << bottleneck_output_shape; - float current_cost; - { - TAG_ENTER(this->logger); - // TODO @lockshaw we really should create the merged graph here - // since it's possible though unlikely for there to be hidden - // transfer costs between modules due to device assignment changes - // across the boundaries - - // We wait to add the communication nodes between boundaries so we - // don't accidentally split on them and keep processing the pure - // computation graph The bottleneck node is kept in the postgraph - // purely as a placeholder and will be replaced with an Input/NoOp - // sequence before any rewrites are actually performed - // this->logger->debug() << "Finding cost of pre_graph (" << - // bottleneck_output_shape << ")"; float pre_cost = - // this->generic_sequence_optimize(pre_graph.get(), - // bottleneck.value(), bottleneck_output_shape, input_shape); - // this->logger->debug() << "Cost of pre_graph (" << - // bottleneck_output_shape << "): " << pre_cost; - // this->logger->debug() << "Finding cost of post_graph (" << - // bottleneck_output_shape << ")"; float post_cost = - // this->generic_sequence_optimize(post_graph.get(), - // sink_node, output_shape, bottleneck_output_shape); - // this->logger->debug() << "Cost of post_graph (" << - // bottleneck_output_shape << "): " << post_cost; float current_cost - // = pre_cost + post_cost; - current_cost = - this->execute_sequence_split(pre_graph, - post_graph, - output_shape, - input_shape, - sink_node, - bottleneck.value(), - 
bottleneck_output_shape); - - if (current_cost < best_cost) { - best_cost = current_cost; - best_shape = bottleneck_output_shape; - } - } - this->logger->debug() << "Boundary shape " << bottleneck_output_shape - << " has cost: " << current_cost; - } - } - - if (best_shape.has_value()) { - this->logger->debug() - << "Best intermediate shape found: " << best_shape.value(); - } else { - this->logger->debug() << "No valid intermediate shapes found"; - } - - if (best_cost != std::numeric_limits::infinity()) { - return_value = this->execute_sequence_split(pre_graph, - post_graph, - output_shape, - input_shape, - sink_node, - bottleneck.value(), - best_shape.value()); - } - } - - this->try_cache_result(hash, return_value); - } - return return_value; -} - -/** - * @brief Top level DP search procedure for Unity with the consideration of - * memory usage. - * - * @tparam T Returned type - * @param graph Pre-optimization PCG - * @param sink_node Sink node of the PCG - * @param output_shape ??? - * @param input_shape ??? - * @return T Optimal result - */ -template -T GraphSearchHelper::generic_sequence_optimize_with_memory( - Graph const *graph, - Node const &sink_node, - tl::optional const &output_shape, - tl::optional const &input_shape) { - TAG_ENTER(this->logger); - - // Try to find the result from cache first. But this will only get the cached - // result if the returned type is float. The float number means the best run - // time cost with only machine quantity (without distinguishing machine - // identities). 
- size_t hash = gs_dp_state_hash(graph, sink_node, output_shape, input_shape); - tl::optional cached = this->try_get_cost_from_cache(hash); - if (cached.has_value()) { - this->logger->spew() << "Optimizing graph with " << graph->inEdges.size() - << " nodes"; - { - TAG_ENTER(this->logger); - this->logger->spew() << "Nodes: "; - { - TAG_ENTER(this->logger); - graph_log_representation(graph, *this->logger); - } - this->logger->spew() << "Retrieved value from cache: " << cached.value(); - } - return cached.value(); - } - - // Couldn't find the result from cache. Try to optimize and get one. - this->logger->debug() << "Optimizing graph with " << graph->inEdges.size() - << " nodes"; - T return_value; - { - // Print out debug information - TAG_ENTER(this->logger); - this->logger->spew() << "Nodes: "; - { - TAG_ENTER(this->logger); - graph_log_representation(graph, *this->logger); - } - this->logger->debug() << "Graph hash: " << std::setw(32) - << std::setfill('0') << graph->hash(); - if (input_shape.has_value()) { - this->logger->debug() << "Input shape: " << input_shape.value(); - } else { - this->logger->debug() << "Input shape: "; - } - if (output_shape.has_value()) { - this->logger->debug() << "Output shape: " << output_shape.value(); - } else { - this->logger->debug() << "Output shape: "; - } - - // Find the node to sequentially split the PCG. - // Decide if the search reaches the base condition by this. - tl::optional bottleneck = - this->find_split_node(graph, this->config.base_optimize_threshold); - - if (!bottleneck.has_value()) { - this->logger->debug() << "Applying base case"; - - // Construct the PCG to optimize based on input_shape and output_shape - // information. 
- Graph to_optimize(*graph); - if (input_shape.has_value()) { - Node input_node = - this->model->get_or_create_input_node(input_shape.value()); - Node noop_node = - this->model->get_or_create_noop_node(input_node.ptr->outputs[0]); - Graph input_graph(this->model); - Edge e(input_node, noop_node, 0, 0); - input_graph.add_edge(e); - - Node old_source_node = graph->find_source_node(); - ParallelTensorShape old_source_output_shape = - old_source_node.ptr->outputs[0]->get_shape(); - input_graph.reshape_output_tensor(old_source_output_shape); - - Node new_sink_node = input_graph.find_sink_node(); - assert(new_sink_node.ptr->numOutputs == 1); - assert(new_sink_node.ptr->outputs[0]->get_shape() == - old_source_output_shape); - - to_optimize.replace_subgraph({old_source_node}, input_graph); - } - SimplificationSettings settings; - if (output_shape.has_value()) { - to_optimize.reshape_output_tensor(output_shape.value()); - Node sink_node = to_optimize.find_sink_node(); - Node noop_node = - this->model->get_or_create_noop_node(sink_node.ptr->outputs[0]); - to_optimize.add_edge(sink_node, noop_node, 0, 0); - } else { - settings.remove_trailing_parallel_ops = true; - } - settings.simplify_parallel_ops = true; - - // Call base optimization to perform graph substitution. - std::unique_ptr optimized = - this->base_optimize_with_memory(&to_optimize, settings); - return_value = get_optimal_cost(std::move(optimized)); - } else { - this->logger->debug() << "Applying recursive case on bottleneck " - << bottleneck.value().guid; - - std::unique_ptr pre_graph, post_graph; - std::tie(pre_graph, post_graph) = - graph->split_at_node(bottleneck.value()); - - MachineResource resources(this->model->config); - std::vector valid_machine_views = - this->model->search->get_valid_machine_views(bottleneck.value().ptr, - resources); - - // Try to find the best cost and corresponding best bottleneck shape. 
- // This search process is based on the float version of - // execute_sequence_split_with_memory(). - float best_cost = std::numeric_limits::infinity(); - tl::optional best_shape = tl::nullopt; - { - TAG_ENTER(this->logger); - for (auto const &bottleneck_output_shape : - this->possible_split_output_tensor_shapes(bottleneck.value())) { - this->logger->debug() - << "Considering boundary shape " << bottleneck_output_shape; - float current_cost; - { - TAG_ENTER(this->logger); - // Get the cost from execute_sequence_split_with_memory by - // only changing bottleneck_output_shape. - current_cost = this->execute_sequence_split_with_memory( - pre_graph, - post_graph, - output_shape, - input_shape, - sink_node, - bottleneck.value(), - bottleneck_output_shape); - - if (current_cost < best_cost) { - best_cost = current_cost; - best_shape = bottleneck_output_shape; - } - } - this->logger->debug() << "Boundary shape " << bottleneck_output_shape - << " has cost: " << current_cost; - } - } - - if (best_shape.has_value()) { - this->logger->debug() - << "Best intermediate shape found: " << best_shape.value(); - } else { - this->logger->debug() << "No valid intermediate shapes found"; - } - - // ? What if best_cost is infinity ? - if (best_cost != std::numeric_limits::infinity()) { - // Get the return value of correct type with previously found - // best_shape. 
- return_value = - this->execute_sequence_split_with_memory(pre_graph, - post_graph, - output_shape, - input_shape, - sink_node, - bottleneck.value(), - best_shape.value()); - } - } - // Try to cache the float result - this->try_cache_result(hash, return_value); - } - return return_value; -} - -std::vector - GraphSearchHelper::possible_split_output_tensor_shapes( - Node const &source_node) const { - TAG_ENTER(this->logger); - - this->logger->debug() << "Finding possible output tensor shapes for node " - << source_node.guid; - assert(source_node.ptr->numOutputs == 1); - ParallelTensor output_tensor = source_node.ptr->outputs[0]; - for (int i = 0; i < output_tensor->num_dims; i++) { - assert(output_tensor->dims[i].degree == 1); - } - - std::vector without_replicas; - - int num_devices = this->config.numNodes * this->config.workersPerNode; - int degrees[MAX_TENSOR_DIM]; - std::fill_n(degrees, MAX_TENSOR_DIM, 1); - - ParallelTensorShape base_shape; - base_shape.num_dims = output_tensor->num_dims; - for (int i = 0; i < output_tensor->num_dims; i++) { - base_shape.dims[i].degree = 1; - base_shape.dims[i].size = output_tensor->dims[i].size; - } - without_replicas.push_back(base_shape); - - { - TAG_ENTER(this->logger); - while (true) { - bool is_done = true; - for (int i = 0; i < output_tensor->num_dims; i++) { - degrees[i] *= 2; - if (degrees[i] > num_devices) { - degrees[i] = 1; - } else { - is_done = false; - break; - } - } - std::ostringstream oss; - for (int i = 0; i < output_tensor->num_dims; i++) { - oss << degrees[i] << " "; - } - this->logger->spew() << "Considering: " << oss.str(); - if (is_done) { - break; - } - - bool is_valid = true; - int total_degree = 1; - ParallelTensorShape shape; - shape.num_dims = output_tensor->num_dims; - for (int i = 0; i < output_tensor->num_dims; i++) { - total_degree *= degrees[i]; - shape.dims[i].degree = degrees[i]; - shape.dims[i].size = output_tensor->dims[i].size; - if (shape.dims[i].size % shape.dims[i].degree != 0) { - 
is_valid = false; - } - } - if (total_degree <= num_devices && is_valid) { - without_replicas.push_back(shape); - } - } - } - - this->logger->debug() << "Found " << without_replicas.size() - << " possible tensor output shapes without replicas"; - this->logger->debug() << "They are:"; - { - TAG_ENTER(this->logger); - for (auto const &shape : without_replicas) { - this->logger->debug() << shape; - } - } - return without_replicas; -} - -void GraphSearchHelper::subgraph_optimize(Graph *subgraph) {} - -template <> -OpX *GraphXfer::create_opx(TensorX const &input, OpX const *matchOpX) { - return this->create_conv2d(input, matchOpX); -} - -template <> -OpX *GraphXfer::create_opx(TensorX const &input, OpX const *matchOpX) { - OpX *pool = new OpX(OP_POOL2D, 1, 1, input); - pool->matchOpX = matchOpX; - return pool; -} - -template <> -OpX *GraphXfer::create_opx(TensorX const &input, OpX const *matchOpX) { - OpX *flat = new OpX(OP_FLAT, 1, 1, input); - flat->matchOpX = matchOpX; - return flat; -} - -GraphXfer *create_partition_linear_combine(FFModel *model, - int num_dims, - int num_parts, - ActiMode activation, - bool use_bias) { - GraphXfer *subst = new GraphXfer(model); - TensorX input = subst->new_tensor(); - OpX *linear1 = subst->create_linear( - input, NULL /*matchOpX*/, num_dims, activation, use_bias); - OpX *repartition = subst->create_repartition(input, num_dims - 2, num_parts); - OpX *linear2 = subst->create_linear(repartition->outputs[0], - linear1 /*matchOpX*/, - num_dims, - activation, - use_bias); - OpX *combine = - subst->create_combine(linear2->outputs[0], num_dims - 2, num_parts); - subst->map_output(linear1->outputs[0], combine->outputs[0]); - subst->srcOps.push_back(linear1); - subst->dstOps.push_back(repartition); - subst->dstOps.push_back(linear2); - subst->dstOps.push_back(combine); - - std::ostringstream oss; - oss << "partition_linear_combine[" - << "num_dims=" << num_dims << ",num_parts=" << num_parts - << ",activation=" << activation << ",use_bias=" 
<< use_bias << "]"; - subst->name = oss.str(); - - return subst; -} - -GraphXfer *create_partition_conv2d_combine(FFModel *model, - int num_dims, - int num_parts) { - assert(num_dims == 5); - GraphXfer *subst = new GraphXfer(model); - TensorX input = subst->new_tensor(); - OpX *conv1 = subst->create_conv2d(input, NULL /*matchOpX*/); - OpX *repartition = subst->create_repartition(input, num_dims - 2, num_parts); - OpX *conv2 = - subst->create_conv2d(repartition->outputs[0], conv1 /*matchOpX*/); - OpX *combine = - subst->create_combine(conv2->outputs[0], num_dims - 2, num_parts); - subst->map_output(conv1->outputs[0], combine->outputs[0]); - subst->srcOps.push_back(conv1); - subst->dstOps.push_back(repartition); - subst->dstOps.push_back(conv2); - subst->dstOps.push_back(combine); - - std::ostringstream oss; - oss << "partition_conv2d_combine[" - << "num_dims=" << num_dims << ",num_parts=" << num_parts << "]"; - subst->name = oss.str(); - - return subst; -} - -GraphXfer *create_combine_inception(FFModel *model, - int num_convs, - int num_dims, - int num_parts) { - // 3 convs and 1 pool2d - assert(num_dims == 5); - GraphXfer *subst = new GraphXfer(model); - TensorX input = subst->new_tensor(); - OpX *src_combine = subst->create_combine(input, num_dims - 2, num_parts); - subst->srcOps.push_back(src_combine); - std::vector src_convs; - for (int i = 0; i < num_convs; i++) { - OpX *conv = - subst->create_conv2d(src_combine->outputs[0], NULL /*matchOpX*/); - src_convs.push_back(conv); - subst->srcOps.push_back(conv); - } - OpX *src_pool = - subst->create_pool2d(src_combine->outputs[0], NULL /*matchOpX*/); - subst->srcOps.push_back(src_pool); - // dst ops - std::vector dst_convs; - for (int i = 0; i < num_convs; i++) { - OpX *conv = subst->create_conv2d(input, src_convs[i] /*matchOpX*/); - OpX *comb = - subst->create_combine(conv->outputs[0], num_dims - 2, num_parts); - subst->dstOps.push_back(conv); - subst->dstOps.push_back(comb); - 
subst->map_output(src_convs[i]->outputs[0], comb->outputs[0]); - } - OpX *dst_pool = subst->create_pool2d(input, src_pool /*matchOpX*/); - OpX *dst_comb = - subst->create_combine(dst_pool->outputs[0], num_dims - 2, num_parts); - subst->dstOps.push_back(dst_pool); - subst->dstOps.push_back(dst_comb); - subst->map_output(src_pool->outputs[0], dst_comb->outputs[0]); - subst->name = "create_combine_inceptionA"; - return subst; -} - -GraphXfer *create_combine_concat(FFModel *model, - int num_inputs, - int num_dims, - int num_parts) { - // assert 5D - assert(num_dims == 5); - GraphXfer *subst = new GraphXfer(model); - std::vector inputs, concat_inputs; - std::vector combines; - for (int i = 0; i < num_inputs; i++) { - inputs.push_back(subst->new_tensor()); - combines.push_back( - subst->create_combine(inputs[i], num_dims - 2, num_parts)); - concat_inputs.push_back(combines[i]->outputs[0]); - subst->srcOps.push_back(combines[i]); - } - OpX *concat1 = subst->create_concat( - concat_inputs.data(), num_inputs, NULL /*matchOpX*/, 2); - subst->srcOps.push_back(concat1); - OpX *concat2 = - subst->create_concat(inputs.data(), num_inputs, concat1 /*matchOpX*/, 2); - OpX *combine = - subst->create_combine(concat2->outputs[0], num_dims - 2, num_parts); - subst->dstOps.push_back(concat2); - subst->dstOps.push_back(combine); - subst->map_output(concat1->outputs[0], combine->outputs[0]); - subst->name = "create_combine_concat"; - return subst; -} - -GraphXfer *create_partition_attention_combine(FFModel *model, - int num_heads, - int num_parts) { - GraphXfer *subst = new GraphXfer(model); - TensorX input = subst->new_tensor(); - OpX *attn1 = subst->create_attention( - input, input, input, NULL /*matchOpX*/, num_heads); - OpX *repart = subst->create_repartition(input, 2, num_parts); - OpX *attn2 = subst->create_attention(repart->outputs[0], - repart->outputs[0], - repart->outputs[0], - attn1 /*matchOpX*/, - num_heads); - OpX *combine = subst->create_combine(attn2->outputs[0], 2, 
num_parts); - subst->map_output(attn1->outputs[0], combine->outputs[0]); - subst->srcOps.push_back(attn1); - subst->dstOps.push_back(repart); - subst->dstOps.push_back(attn2); - subst->dstOps.push_back(combine); - - std::ostringstream oss; - oss << "partition_attention_combine[" - << "num_heads=" << num_heads << ",num_parts=" << num_parts << "]"; - subst->name = oss.str(); - - return subst; -} - -GraphXfer *create_replicate_attention_reduce(FFModel *model, - int num_heads, - int num_parts) { - assert(num_heads % num_parts == 0); - GraphXfer *subst = new GraphXfer(model); - TensorX input = subst->new_tensor(); - OpX *attn1 = subst->create_attention( - input, input, input, NULL /*matchOpX*/, num_heads); - OpX *repl = subst->create_replicate(input, 3, num_parts); - OpX *attn2 = subst->create_attention(repl->outputs[0], - repl->outputs[0], - repl->outputs[0], - attn1 /*matchOpX*/, - num_heads / num_parts); - OpX *reduce = subst->create_reduction(attn2->outputs[0], 3, num_parts); - subst->map_output(attn1->outputs[0], reduce->outputs[0]); - subst->srcOps.push_back(attn1); - subst->dstOps.push_back(repl); - subst->dstOps.push_back(attn2); - subst->dstOps.push_back(reduce); - - std::ostringstream oss; - oss << "replicate_attention_reduce[" - << "num_heads=" << num_heads << ",num_parts=" << num_parts << "]"; - subst->name = oss.str(); - - return subst; -} - -GraphXfer *create_replicate_linear_combine(FFModel *model, - int num_dims, - int num_parts, - ActiMode activation, - bool use_bias) { - GraphXfer *subst = new GraphXfer(model); - TensorX input = subst->new_tensor(); - OpX *linear1 = subst->create_linear( - input, NULL /*matchOpX*/, num_dims, activation, use_bias); - OpX *replicate = subst->create_replicate(input, num_dims - 1, num_parts); - OpX *linear2 = subst->create_linear(replicate->outputs[0], - linear1 /*matchOpX*/, - num_dims, - activation, - use_bias); - OpX *combine = subst->create_combine(linear2->outputs[0], 0, num_parts); - 
subst->map_output(linear1->outputs[0], combine->outputs[0]); - subst->srcOps.push_back(linear1); - subst->dstOps.push_back(replicate); - subst->dstOps.push_back(linear2); - subst->dstOps.push_back(combine); - - std::ostringstream oss; - oss << "replicate_linear_combine[" - << "num_dims=" << num_dims << ",num_parts=" << num_parts - << ",activation=" << activation << ",use_bias=" << use_bias << "]"; - subst->name = oss.str(); - - return subst; -} - -GraphXfer *create_partition_add_combine(FFModel *model, - int parallel_dim, - int num_parts) { - GraphXfer *subst = new GraphXfer(model); - TensorX input1 = subst->new_tensor(); - TensorX input2 = subst->new_tensor(); - OpX *add1 = subst->create_element_binary(input1, input2, OP_EW_ADD); - OpX *repartition1 = - subst->create_repartition(input1, parallel_dim, num_parts); - OpX *repartition2 = - subst->create_repartition(input2, parallel_dim, num_parts); - OpX *add2 = subst->create_element_binary( - repartition1->outputs[0], repartition2->outputs[0], OP_EW_ADD); - OpX *combine = - subst->create_combine(add2->outputs[0], parallel_dim, num_parts); - subst->map_output(add1->outputs[0], combine->outputs[0]); - subst->srcOps.push_back(add1); - subst->dstOps.push_back(repartition1); - subst->dstOps.push_back(repartition2); - subst->dstOps.push_back(add2); - subst->dstOps.push_back(combine); - - std::ostringstream oss; - oss << "partition_add_combine[" - << "parallel_dim=" << parallel_dim << ",num_parts=" << num_parts << "]"; - subst->name = oss.str(); - - return subst; -} - -GraphXfer *create_combine_add_partition(FFModel *model, - int parallel_dim, - int num_parts) { - GraphXfer *subst = new GraphXfer(model); - TensorX input1 = subst->new_tensor(); - TensorX input2 = subst->new_tensor(); - OpX *add1 = subst->create_element_binary(input1, input2, OP_EW_ADD); - - OpX *combine1 = subst->create_combine(input1, parallel_dim, num_parts); - OpX *combine2 = subst->create_combine(input2, parallel_dim, num_parts); - OpX *add2 = 
subst->create_element_binary( - combine1->outputs[0], combine2->outputs[0], OP_EW_ADD); - OpX *repartition = - subst->create_repartition(add2->outputs[0], parallel_dim, num_parts); - subst->map_output(add1->outputs[0], repartition->outputs[0]); - subst->srcOps.push_back(add1); - subst->dstOps.push_back(combine1); - subst->dstOps.push_back(combine2); - subst->dstOps.push_back(add2); - subst->dstOps.push_back(repartition); - - std::ostringstream oss; - oss << "combine_add_partition[" - << "parallel_dim=" << parallel_dim << ",num_parts=" << num_parts << "]"; - subst->name = oss.str(); - - return subst; -} - -GraphXfer *create_partition_relu_combine(FFModel *model, - int parallel_dim, - int num_parts) { - GraphXfer *subst = new GraphXfer(model); - TensorX input = subst->new_tensor(); - OpX *relu1 = subst->create_element_unary(input, OP_RELU); - - OpX *partition = subst->create_repartition(input, parallel_dim, num_parts); - OpX *relu2 = subst->create_element_unary(partition->outputs[0], OP_RELU); - OpX *combine = - subst->create_combine(relu2->outputs[0], parallel_dim, num_parts); - - subst->map_output(relu1->outputs[0], combine->outputs[0]); - - subst->srcOps.push_back(relu1); - - subst->dstOps.push_back(partition); - subst->dstOps.push_back(relu2); - subst->dstOps.push_back(combine); - - std::ostringstream oss; - oss << "partition_relu_combine[" - << "parallel_dim=" << parallel_dim << ",num_parts=" << num_parts << "]"; - subst->name = oss.str(); - - return subst; -} - -GraphXfer *create_combine_relu_partition(FFModel *model, - int parallel_dim, - int num_parts) { - GraphXfer *subst = new GraphXfer(model); - TensorX input = subst->new_tensor(); - OpX *relu1 = subst->create_element_unary(input, OP_RELU); - - OpX *combine = subst->create_combine(input, parallel_dim, num_parts); - OpX *relu2 = subst->create_element_unary(combine->outputs[0], OP_RELU); - OpX *partition = - subst->create_repartition(relu2->outputs[0], parallel_dim, num_parts); - - 
subst->map_output(relu1->outputs[0], partition->outputs[0]); - - subst->srcOps.push_back(relu1); - - subst->dstOps.push_back(combine); - subst->dstOps.push_back(relu2); - subst->dstOps.push_back(partition); - - std::ostringstream oss; - oss << "combine_relu_partition[" - << "parallel_dim=" << parallel_dim << ",num_parts=" << num_parts << "]"; - subst->name = oss.str(); - - return subst; -} - -GraphXfer *create_partition_concat_combine(FFModel *model, - int num_inputs, - int concat_dim, - int parallel_dim, - int num_parts) { - GraphXfer *subst = new GraphXfer(model); - assert(num_inputs <= MAX_NUM_INPUTS); - TensorX inputs[MAX_NUM_INPUTS]; - for (int i = 0; i < num_inputs; i++) { - inputs[i] = subst->new_tensor(); - } - OpX *concat = - subst->create_concat(inputs, num_inputs, NULL /*matchOpX*/, concat_dim); - subst->srcOps.push_back(concat); - TensorX new_inputs[MAX_NUM_INPUTS]; - for (int i = 0; i < num_inputs; i++) { - OpX *repartition = - subst->create_repartition(inputs[i], parallel_dim, num_parts); - new_inputs[i] = repartition->outputs[0]; - subst->dstOps.push_back(repartition); - } - OpX *concat2 = subst->create_concat( - new_inputs, num_inputs, concat /*matchOpX*/, concat_dim); - subst->dstOps.push_back(concat2); - OpX *combine = - subst->create_combine(concat2->outputs[0], parallel_dim, num_parts); - subst->dstOps.push_back(combine); - subst->map_output(concat->outputs[0], combine->outputs[0]); - - std::ostringstream oss; - oss << "partition_concat_combine[" - << "num_inputs=" << num_inputs << ",concat_dim=" << concat_dim - << ",parallel_dim=" << parallel_dim << ",num_parts=" << num_parts << "]"; - subst->name = oss.str(); - - return subst; -} - -GraphXfer *create_partition_softmax_combine(FFModel *model, - int softmax_dim, - int parallel_dim, - int num_parts) { - assert(parallel_dim != softmax_dim); - GraphXfer *subst = new GraphXfer(model); - TensorX input = subst->new_tensor(); - OpX *softmax1 = subst->create_softmax(input, softmax_dim); - OpX 
*repartition = subst->create_repartition(input, parallel_dim, num_parts); - OpX *softmax2 = subst->create_softmax(repartition->outputs[0], softmax_dim); - OpX *combine = - subst->create_combine(softmax2->outputs[0], parallel_dim, num_parts); - subst->map_output(softmax1->outputs[0], combine->outputs[0]); - subst->srcOps.push_back(softmax1); - subst->dstOps.push_back(repartition); - subst->dstOps.push_back(softmax2); - subst->dstOps.push_back(combine); - - std::ostringstream oss; - oss << "partition_softmax_combine[" - << "softmax_dim=" << softmax_dim << ",parallel_dim=" << parallel_dim - << ",num_parts=" << num_parts << "]"; - subst->name = oss.str(); - - return subst; -} - -GraphXfer *create_combine_softmax_partition(FFModel *model, - int softmax_dim, - int parallel_dim, - int num_parts) { - assert(parallel_dim != softmax_dim); - GraphXfer *subst = new GraphXfer(model); - TensorX input = subst->new_tensor(); - OpX *softmax1 = subst->create_softmax(input, softmax_dim); - OpX *combine = subst->create_combine(input, parallel_dim, num_parts); - OpX *softmax2 = subst->create_softmax(combine->outputs[0], softmax_dim); - OpX *repartition = - subst->create_repartition(softmax2->outputs[0], parallel_dim, num_parts); - subst->map_output(softmax1->outputs[0], repartition->outputs[0]); - subst->srcOps.push_back(softmax1); - subst->dstOps.push_back(combine); - subst->dstOps.push_back(softmax2); - subst->dstOps.push_back(repartition); - - std::ostringstream oss; - oss << "combine_softmax_partition[" - << "softmax_dim=" << softmax_dim << ",parallel_dim=" << parallel_dim - << ",num_parts=" << num_parts << "]"; - subst->name = oss.str(); - - return subst; -} - -GraphXfer *leading_relu_branch_combine(FFModel *model, - int parallel_dim, - int num_parts, - int num_combines) { - GraphXfer *subst = new GraphXfer(model); - TensorX input = subst->new_tensor(); - OpX *old_partition = - subst->create_repartition(input, parallel_dim, num_parts); - std::vector old_combines; - for (int i = 0; 
i < num_combines; i++) { - old_combines.push_back( - subst->create_combine(input, parallel_dim, num_parts)); - } - - OpX *new_partition = - subst->create_repartition(input, parallel_dim, num_parts); - std::vector new_noops; - for (int i = 0; i < num_combines; i++) { - new_noops.push_back(subst->create_noop(input)); - } - - subst->map_output(old_partition->outputs[0], new_partition->outputs[0]); - for (int i = 0; i < num_combines; i++) { - subst->map_output(old_combines[i]->outputs[0], new_noops[i]->outputs[0]); - } - - subst->srcOps.push_back(old_partition); - subst->srcOps.insert( - subst->srcOps.end(), old_combines.begin(), old_combines.end()); - subst->dstOps.push_back(new_partition); - subst->dstOps.insert(subst->dstOps.end(), new_noops.begin(), new_noops.end()); - - std::ostringstream oss; - oss << "leading_relu_branch_combine[" - << "parallel_dim=" << parallel_dim << ",num_parts=" << num_parts - << ",num_combines=" << num_combines << "]"; - subst->name = oss.str(); - - return subst; -} - -GraphXfer *leading_relu_branch_partition(FFModel *model, - int parallel_dim, - int num_parts, - int num_partitions) { - GraphXfer *subst = new GraphXfer(model); - TensorX input = subst->new_tensor(); - OpX *old_combine = subst->create_combine(input, parallel_dim, num_parts); - std::vector old_partitions; - for (int i = 0; i < num_partitions; i++) { - old_partitions.push_back( - subst->create_repartition(input, parallel_dim, num_parts)); - } - - OpX *new_combine = subst->create_combine(input, parallel_dim, num_parts); - std::vector new_noops; - for (int i = 0; i < num_partitions; i++) { - new_noops.push_back(subst->create_noop(input)); - } - - subst->map_output(old_combine->outputs[0], new_combine->outputs[0]); - for (int i = 0; i < num_partitions; i++) { - subst->map_output(old_partitions[i]->outputs[0], new_noops[i]->outputs[0]); - } - - subst->srcOps.push_back(old_combine); - subst->srcOps.insert( - subst->srcOps.end(), old_partitions.begin(), old_partitions.end()); - 
subst->dstOps.push_back(new_combine); - subst->dstOps.insert(subst->dstOps.end(), new_noops.begin(), new_noops.end()); - - std::ostringstream oss; - oss << "leading_relu_branch_partition[" - << "parallel_dim=" << parallel_dim << ",num_parts=" << num_parts - << ",num_partitions=" << num_partitions << "]"; - subst->name = oss.str(); - - return subst; -} - -GraphXfer * - create_linear_relu_merge(FFModel *model, int num_dims, bool use_bias) { - GraphXfer *subst = new GraphXfer(model); - TensorX input = subst->new_tensor(); - OpX *old_linear = - subst->create_linear(input, nullptr, num_dims, AC_MODE_NONE, use_bias); - OpX *old_relu = subst->create_relu(old_linear->outputs[0]); - - OpX *new_linear = - subst->create_linear(input, old_linear, num_dims, AC_MODE_RELU, use_bias); - - subst->map_output(old_relu->outputs[0], new_linear->outputs[0]); - subst->srcOps.push_back(old_linear); - subst->srcOps.push_back(old_relu); - subst->dstOps.push_back(new_linear); - - std::ostringstream oss; - oss << "linear_relu_merge[" - << "num_dims=" << num_dims << ",use_bias=" << use_bias << "]"; - subst->name = oss.str(); - - return subst; -} - -} // namespace ffc - -using PCG::Edge; -using PCG::Graph; -using PCG::Node; - -/** - * @brief Optimize the graph stored in FFModel. 
- * - * @param[in] budget The search budget - * @param[in] only_data_parallel True if only doing data parallel training - * @param[out] best_graph The searched best graph - * @param[out] optimal_views The corresponding machine view of the best_graph - * @param[in] perform_memory_search True if we want to consider memory during - * the search - * @param[in] new_config Memory optimization config to use if this is a memory - * search - * @param[out] search_result The performance result of this search - */ -void FFModel::graph_optimize( - size_t budget, - bool only_data_parallel, - std::unique_ptr &best_graph, - std::unordered_map &optimal_views, - bool perform_memory_search, - MemoryOptimConfig new_config, - MemorySearchResult &search_result) { - if (perform_memory_search) { - this->graph_search->update_mem_optim_config(new_config); - this->graph_search->graph_optimize_with_memory( - budget, only_data_parallel, best_graph, optimal_views, search_result); - } else { - this->graph_search->graph_optimize( - budget, only_data_parallel, best_graph, optimal_views); - } -} - -bool FFModel::convert_graph_to_operators( - Graph const *graph, - std::unordered_map const &optimal_views) { - // Clear operators - operators.clear(); - std::unordered_map todos; - std::unordered_map node_to_op; - std::vector queue; - for (auto const &it : graph->inEdges) { - auto const &inList = it.second; - if (inList.size() == 0) { - queue.push_back(it.first); - } else { - todos[it.first] = (int)inList.size(); - } - } - size_t index = 0; - while (index < queue.size()) { - Node node = queue[index++]; - assert(node.ptr != NULL); - auto const &inList = graph->inEdges.find(node)->second; - ParallelTensor inputs[MAX_NUM_INPUTS]; - int num_inputs = 0; - for (auto const &e : inList) { - inputs[e.dstIdx] = node_to_op[e.srcOp]->outputs[e.srcIdx]; - assert(e.dstIdx < (int)inList.size()); - num_inputs++; - } - Op *new_op = NULL; - switch (node.ptr->op_type) { - case OP_INPUT: { - NoOp *noop = (NoOp *)node.ptr; - 
new_op = new NoOp( - *this, OP_INPUT, noop->input_tensor_guid, node.ptr->outputs[0]); - break; - } - case OP_CONCAT: { - Concat *concat = (Concat *)node.ptr; - new_op = new Concat( - *this, (int)inList.size(), inputs, concat->legion_axis, NULL); - break; - } - case OP_AGGREGATE: { - Aggregate *aggr = (Aggregate *)node.ptr; - new_op = new Aggregate(*this, inputs, aggr->n, aggr->lambda_bal, NULL); - break; - } - case OP_SPLIT: { - Split *split = (Split *)node.ptr; - std::vector splits; - for (int i = 0; i < split->numOutputs; i++) { - splits.push_back(split->outputs[i]->dims[split->legion_axis].size); - } - new_op = new Split(*this, inputs[0], splits, split->legion_axis, NULL); - break; - } - case OP_EMBEDDING: { - new_op = new Embedding(*this, *(Embedding *)node.ptr, inputs[0], true); - break; - } - case OP_EW_ADD: - case OP_EW_SUB: - case OP_EW_MUL: - case OP_EW_MAX: - case OP_EW_MIN: { - assert(inList.size() == 2); - ElementBinary *eb = (ElementBinary *)node.ptr; - new_op = new ElementBinary( - *this, eb->op_type, inputs[0], inputs[1], eb->inplace_a, NULL); - break; - } - case OP_POOL2D: { - new_op = new Pool2D(*this, *(Pool2D *)node.ptr, inputs[0]); - break; - } - case OP_CONV2D: { - new_op = new Conv2D(*this, *(Conv2D *)node.ptr, inputs[0], true); - break; - } - case OP_DROPOUT: { - new_op = new Dropout(*this, *(Dropout *)node.ptr, inputs[0]); - break; - } - case OP_LINEAR: { - new_op = new Linear(*this, *(Linear *)node.ptr, inputs[0], true); - break; - } - case OP_MULTIHEAD_ATTENTION: { - assert(inList.size() == 3); - MultiHeadAttention *attn = (MultiHeadAttention *)node.ptr; - new_op = new MultiHeadAttention( - *this, *attn, inputs[0], inputs[1], inputs[2], true); - break; - break; - } - case OP_SOFTMAX: { - assert(inList.size() == 1); - Softmax *softmax = (Softmax *)node.ptr; - new_op = new Softmax(*this, inputs[0], softmax->dim, NULL); - break; - } - case OP_COMBINE: { - assert(inList.size() == 1); - Combine *combine = (Combine *)node.ptr; - new_op = new 
Combine( - *this, inputs[0], combine->combine_dim, combine->combine_degree); - break; - } - case OP_REPARTITION: { - assert(inList.size() == 1); - Repartition *repart = (Repartition *)node.ptr; - new_op = new Repartition(*this, - inputs[0], - repart->repartition_dim, - repart->repartition_degree); - break; - } - case OP_REPLICATE: { - assert(inList.size() == 1); - Replicate *replicate = (Replicate *)node.ptr; - new_op = new Replicate(*this, - inputs[0], - replicate->replicate_dim, - replicate->replicate_degree); - break; - } - case OP_REDUCTION: { - assert(inList.size() == 1); - Reduction *reduction = (Reduction *)node.ptr; - new_op = new Reduction(*this, - inputs[0], - reduction->reduction_dim, - reduction->reduction_degree); - break; - } - case OP_FUSED_PARALLEL: { - assert(inList.size() == 1); - FusedParallelOp *fused = (FusedParallelOp *)node.ptr; - std::vector parallel_ops; - for (int i = 0; i < fused->num_parallel_ops; i++) { - parallel_ops.push_back(fused->parallel_ops[i]); - } - new_op = new FusedParallelOp(*this, inputs[0], parallel_ops); - break; - } - default: { - new_op = node.ptr->materialize(*this, inputs, num_inputs); - break; - } - } - // Set machine view for the output tensors of this operator - assert(optimal_views.find(node) != optimal_views.end()); - MachineView view = optimal_views.find(node)->second; - for (int i = 0; i < new_op->numOutputs; i++) { - new_op->outputs[i]->machine_view = view; - } - // Set machine view for the weight tensors of this operator - for (int i = 0; i < new_op->numWeights; i++) { - new_op->weights[i]->machine_view = view; - } - node_to_op[node] = new_op; - operators.push_back(new_op); - // Decrease the todos - auto const &outList = graph->outEdges.find(node)->second; - for (auto const &it : outList) { - todos[it.dstOp] -= 1; - if (todos[it.dstOp] == 0) { - queue.push_back(it.dstOp); - } - } - } - assert(queue.size() == graph->inEdges.size()); - // Remove the final parallel operators - while (operators[operators.size() - 
1]->is_parallel_op()) { - Op *op = operators[operators.size() - 1]; - if (op->op_type == OP_REDUCTION) { - break; - } - if (op->op_type == OP_FUSED_PARALLEL) { - FusedParallelOp *fused_op = (FusedParallelOp *)op; - bool has_reduction = false; - for (int i = 0; i < fused_op->num_parallel_ops; i++) { - if (fused_op->parallel_ops[i].op_type == OP_REDUCTION) { - has_reduction = true; - } - } - if (has_reduction) { - break; - } - } - operators.pop_back(); - } - return true; -} - -} // namespace FlexFlow diff --git a/lib/compiler/src/old/substitution.h b/lib/compiler/src/old/substitution.h deleted file mode 100644 index 95a59e952c..0000000000 --- a/lib/compiler/src/old/substitution.h +++ /dev/null @@ -1,309 +0,0 @@ -/* Copyright 2023 CMU, Facebook, LANL, MIT, NVIDIA, and Stanford (alphabetical) - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#ifndef _FLEXFLOW_SUBSTITUTION_H_ -#define _FLEXFLOW_SUBSTITUTION_H_ -#include "graph.h" -#include "substitutions/substitutions.h" -#include "tl/optional.hpp" -#include "utils/recursive_logger.h" -#include -#include - -namespace FlexFlow { -namespace ffc { - -/* struct PMConstraint { */ -/* PMConstraint(Compare comp, PMParameter para, int value); */ -/* Compare comp; */ -/* PMParameter para; */ -/* int value; */ -/* }; */ - -struct TNConstraint { - TNConstraint(Compare comp, TNParameter para, DIMParameter dim, int value); - TNConstraint(Compare comp, - TNParameter para1, - DIMParameter dim1, - TNParameter para2, - DIMParameter dim2); - bool singlePara; - Compare comp; - TNParameter para1, para2; - DIMParameter dim1, dim2; - int value; -}; - -/* class Op; */ -/* class OpX; */ -/* class GraphXfer; */ - -struct TensorX { - static const TensorX NO_TX; - TensorX(void) : op(NULL), idx(0) {} - TensorX(OpX *_op, int _idx) : op(_op), idx(_idx) {} - tl::optional - to_tensor(GraphXfer const *xfer) const; - OpX *op; - int idx; - - bool operator==(TensorX const &other) const; - bool operator!=(TensorX const &other) const; -}; - -struct TensorXCompare { - bool operator()(TensorX const &a, TensorX const &b) const { - if (a.op != b.op) { - return a.op < b.op; - } - return a.idx < b.idx; - }; -}; - -/* class OpX { */ -/* public: */ -/* OpX(OperatorType type, */ -/* int numInputs, */ -/* int numOutputs, */ -/* TensorX const &input1 = TensorX::NO_TX, */ -/* TensorX const &input2 = TensorX::NO_TX, */ -/* TensorX const &input3 = TensorX::NO_TX, */ -/* TensorX const &input4 = TensorX::NO_TX); */ -/* OpX(OperatorType type, */ -/* int num_inputs, */ -/* int num_outputs, */ -/* TensorX const *inputs); */ -/* bool add_pm_constraint(Compare, PMParameter para, int value); */ -/* bool add_input_constraint(Compare, TNParameter, DIMParameter, int); */ -/* bool add_input_constraint( */ -/* Compare, TNParameter, DIMParameter, TNParameter, DIMParameter); */ -/* bool 
get_pm_constraint(PMParameter para, int &value) const; */ - -/* public: */ -/* OperatorType type; */ -/* Node mapOp; */ -/* OpX const *matchOpX; */ -/* std::vector inputs, weights, outputs; */ -/* std::vector pmConstraints; */ -/* std::vector tnConstraints; */ -/* }; */ - -OpX *create_opx(substitutions::Operator const &op, - int parallel_degree, - TensorX const &input1 = TensorX::NO_TX, - TensorX const &input2 = TensorX::NO_TX, - TensorX const &input3 = TensorX::NO_TX, - TensorX const &input4 = TensorX::NO_TX); -void create_xfer(GraphXfer &xfer, - substitutions::Rule const &r, - int parallel_degree); -std::vector - create_xfers(substitutions::RuleCollection const &rules, - int parallel_degree); - -class GraphCompare { -public: - bool operator()(Graph *lhs, Graph *rhs) { - return lhs->optimal_cost() > rhs->optimal_cost(); - } -}; - -class GraphXferMatch { -public: - GraphXferMatch(GraphXfer const *); - - void add_mapping(Node const &, OpX *); - void add_mapping(OpX *, Node const &); - void add_input_mapping(int, std::pair const &); - void add_output_mapping(TensorX const &, TensorX const &); - OpX *at(Node const &) const; - Node at(OpX *) const; - void set_graph(Graph const *); - - bool containsNode(Graph const *, Node const &) const; - bool containsEdge(Graph const *, Edge const &) const; - - GraphXfer const *get_xfer() const; - std::unordered_set get_nodes() const; - -private: - std::map nodeToOpX; - std::map opXToNode; - std::map mappedOutputs; - size_t graph_hash; - GraphXfer const *xfer; -}; - -/* class GraphXfer { */ -/* public: */ -/* GraphXfer(); */ -/* TensorX new_tensor(void); */ -/* bool can_match(OpX *srcOp, Node const &op, Graph const *graph); */ -/* void match(OpX *srcOp, Node const &op, Graph const *graph); */ -/* void unmatch(OpX *srcOp, Node const &op, Graph const *graph); */ -/* // Compute Ops */ -/* template */ -/* OpX *create_opx(TensorX const &input, OpX const *matchOpX); */ - -/* OpX *create_noop(TensorX const &input); */ -/* OpX 
*create_concat(TensorX const *inputs, */ -/* int num_inputs, */ -/* OpX const *match_opx, */ -/* int concat_dim); */ -/* OpX *create_element_binary(TensorX const &input1, */ -/* TensorX const &input2, */ -/* OperatorType op_type); */ -/* OpX *create_element_unary(TensorX const &input, OperatorType op_type); */ -/* OpX *create_relu(TensorX const &input); */ -/* OpX *create_linear(TensorX const &input, */ -/* OpX const *match_opx, */ -/* int num_dims, */ -/* ActiMode acti_mode, */ -/* bool use_bias); */ -/* OpX *create_conv2d(TensorX const &input, OpX const *match_opx); */ -/* OpX *create_pool2d(TensorX const &input, OpX const *match_opx); */ -/* OpX *create_attention(TensorX const &query, */ -/* TensorX const &key, */ -/* TensorX const &value, */ -/* OpX const *match_opx, */ -/* int num_heads); */ -/* OpX *create_softmax(TensorX const &input, int softmax_dim); */ -/* // Parallel Ops */ -/* OpX *create_repartition(TensorX const &input, */ -/* int repartition_dim, */ -/* int num_parts); */ -/* OpX *create_replicate(TensorX const &input, int replicate_dim, int - * num_parts); */ -/* OpX *create_reduction(TensorX const &input, int reduction_dim, int - * num_parts); */ -/* OpX *create_combine(TensorX const &input, int combine_dim, int num_parts); - */ -/* bool map_output(TensorX const &src, TensorX const &dst); */ - -/* Graph *create_new_graph(Graph const *graph, */ -/* SimplificationSettings const &settings); */ -/* bool create_new_operator(OpX const *opx, Node &op); */ - -/* std::string get_name() const; */ - -/* void run(int depth, */ -/* Graph *graph, */ -/* std::priority_queue, GraphCompare> - * &, */ -/* std::unordered_set &, */ -/* float threshold, */ -/* int maxNumOps, */ -/* SimplificationSettings const &simplification_settings, */ -/* int &num_matches_found, */ -/* int &num_matches_rejected); */ - -/* void find_matches(Graph const *, std::vector &matches); */ -/* GraphXferMatch get_match_record(Graph const *) const; */ - -/* private: */ -/* void 
find_matches(int depth, */ -/* Graph const *graph, */ -/* std::vector &matches); */ - -/* public: */ -/* tl::optional name = tl::nullopt; */ -/* int tensorId; */ -/* std::map mappedOps; */ -/* std::multimap> mappedInputs; */ -/* std::map mappedOutputs; */ -/* std::vector srcOps; */ -/* std::vector dstOps; */ -/* }; */ - -struct SubstitutionMatch { - std::unordered_map node_assignment; - std::unordered_map edge_assignment; -}; - -std::unordered_set - find_matches(SubstitutionPattern const &pattern, - ParallelComputationGraph const &pcg); - -class GraphSearchHelper { -public: - GraphSearchHelper(); - void graph_optimize(size_t budget, - bool only_data_parallel, - std::unique_ptr &best_graph, - std::unordered_map &optimal_views); - void graph_optimize_no_split( - size_t budget, - bool only_data_parallel, - std::unique_ptr &best_graph, - std::unordered_map &optimal_views); - -private: - template - T generic_sequence_optimize( - Graph const *graph, - Node const &sink_node, - tl::optional const &output_shape, - tl::optional const &input_shape); - - float sequence_optimize(Graph const *graph, - Node const &sink_node, - tl::optional const &output_shape, - tl::optional const &input_shape); - - template - T execute_sequence_split( - std::unique_ptr const &pre_graph, - std::unique_ptr const &post_graph, - tl::optional const &output_shape, - tl::optional const &input_shape, - Node const &sink_node, - Node const &bottleneck, - ParallelTensorShape const &bottleneck_output_shape); - void generate_all_pcg_xfers(); - void load_graph_substitutions(std::vector &xfers) const; - Graph *construct_graph(); - void subgraph_optimize(Graph *subgraph); - - std::unique_ptr - base_optimize(Graph const *, - SimplificationSettings const &simplification_settings); - - std::vector - possible_split_output_tensor_shapes(Node const &) const; - - void find_rewrite_matches(Graph const *graph, - std::vector &matches) const; - tl::optional find_split_node(Graph const *graph, - int 
base_optimize_threshold) const; - - template - tl::optional try_get_cost_from_cache(size_t hash) const; - - template - void try_cache_result(size_t hash, T const &value); - - template - T get_optimal_cost(std::unique_ptr optimized) const; - -private: - std::unordered_map cached_optimized_graphs; - std::vector all_pcg_xfers; - std::unique_ptr logger; -}; - -} // namespace ffc -} // namespace FlexFlow - -#endif diff --git a/lib/compiler/src/unity_algorithm.cc b/lib/compiler/src/unity_algorithm.cc index 86fdd88d92..9d648ed99b 100644 --- a/lib/compiler/src/unity_algorithm.cc +++ b/lib/compiler/src/unity_algorithm.cc @@ -20,7 +20,7 @@ std::unordered_set Strategy graph_optimize(ComputationGraph &cg, - ICostEstimator const &cost_estimator, + CostEstimator const &cost_estimator, MachineSpecification const &resources, std::function( Operator const &, MachineSpecification const &)> const @@ -35,12 +35,8 @@ Strategy DeduplicatedPriorityQueue, StrategyRuntimeCmp> candidates; - Strategy initial_result(pcg, - optimal_cost(pcg, - allowed_machine_views, - cost_estimator, - resources, - cached_subgraph_costs)); + OptimalCostResult initial_pcg_result = optimal_cost(pcg, allowed_machine_views, cost_estimator, resources, cached_subgraph_costs); + Strategy initial_result{pcg, initial_pcg_result.machine_mapping, initial_pcg_result.runtime}; Strategy best_result = initial_result; candidates.push(initial_result); @@ -50,7 +46,7 @@ Strategy Strategy const &current_result = candidates.top(); candidates.pop(); - if (StrategyRuntimeCmp(current_result, best_result)) { + if (StrategyRuntimeCmp{}(current_result, best_result)) { best_result = current_result; } else if (current_result.runtime > best_result.runtime * opt_config.alpha) { @@ -64,9 +60,9 @@ Strategy cost_estimator, resources, cached_subgraph_costs); - Strategy new_result(new_pcg, c.machine_mapping, c.runtime); + Strategy new_result{new_pcg, c.machine_mapping, c.runtime}; if (new_result.runtime <= opt_config.threshold && - 
new_result.pcg.query_nodes({}).size() <= opt_config.max_num_ops) { + get_nodes(new_pcg.value()).size() <= opt_config.max_num_ops) { candidates.push(new_result); } } diff --git a/lib/pcg/include/pcg/machine_view.h b/lib/pcg/include/pcg/machine_view.h index 1a5c2bc3f8..b482e851d8 100644 --- a/lib/pcg/include/pcg/machine_view.h +++ b/lib/pcg/include/pcg/machine_view.h @@ -26,6 +26,8 @@ struct MachineView : public use_visitable_cmp { StridedRectangle rect; }; +FF_VISITABLE_STRUCT(MachineView, start, rect); + std::size_t num_dims(MachineView const &); std::size_t num_devices(MachineView const &); DeviceType get_device_type(MachineView const &); @@ -43,7 +45,4 @@ MachineView make_1d_machine_view(device_id_t start, size_t interval_size); } // namespace FlexFlow -VISITABLE_STRUCT(::FlexFlow::MachineView, start, rect); -MAKE_VISIT_HASHABLE(::FlexFlow::MachineView); - #endif diff --git a/lib/pcg/include/pcg/parallel_computation_graph.h b/lib/pcg/include/pcg/parallel_computation_graph.h index 7e332933c7..2342cd08fa 100644 --- a/lib/pcg/include/pcg/parallel_computation_graph.h +++ b/lib/pcg/include/pcg/parallel_computation_graph.h @@ -15,6 +15,16 @@ struct ParallelComputationGraph }; CHECK_WELL_BEHAVED_VALUE_TYPE_NO_HASH(ParallelComputationGraph); +bool operator==(ParallelComputationGraph const &, ParallelComputationGraph const &); + } // namespace FlexFlow +namespace std { + +template <> +struct hash { + size_t operator()(FlexFlow::ParallelComputationGraph const &g) const; +}; +} + #endif diff --git a/lib/substitutions/include/substitutions/substitution.h b/lib/substitutions/include/substitutions/substitution.h index a52906c612..98471a8fbd 100644 --- a/lib/substitutions/include/substitutions/substitution.h +++ b/lib/substitutions/include/substitutions/substitution.h @@ -28,4 +28,12 @@ SubParallelComputationGraph } // namespace FlexFlow +namespace std{ +template <> +struct hash { + size_t operator()(FlexFlow::Substitution const &) const; +}; + +}; + #endif diff --git 
a/lib/utils/include/utils/graph/algorithms.h b/lib/utils/include/utils/graph/algorithms.h index 5b2e5093bd..3a1444a0f5 100644 --- a/lib/utils/include/utils/graph/algorithms.h +++ b/lib/utils/include/utils/graph/algorithms.h @@ -108,6 +108,9 @@ std::unordered_set get_node_edges(UndirectedGraphView const &, std::unordered_set get_outputs(MultiDiGraphView const &); std::unordered_set get_inputs(MultiDiGraphView const &); +std::unordered_set get_open_outputs(OpenMultiDiGraphView const &); +std::unordered_set get_open_inputs(OpenMultiDiGraphView const &); + std::unordered_set get_incoming_edges(MultiDiGraphView const &, Node const &); std::unordered_set get_incoming_edges(DiGraphView const &, diff --git a/lib/utils/include/utils/graph/labelled/node_labelled.h b/lib/utils/include/utils/graph/labelled/node_labelled.h index 15da6ce2cb..8bfcda9d0f 100644 --- a/lib/utils/include/utils/graph/labelled/node_labelled.h +++ b/lib/utils/include/utils/graph/labelled/node_labelled.h @@ -8,6 +8,7 @@ namespace FlexFlow { template struct INodeLabelledMultiDiGraphView : virtual public IMultiDiGraphView { + INodeLabelledMultiDiGraphView() = default; INodeLabelledMultiDiGraphView(INodeLabelledMultiDiGraphView const &) = delete; INodeLabelledMultiDiGraphView & operator=(INodeLabelledMultiDiGraphView const &) = delete; diff --git a/lib/utils/include/utils/graph/labelled/node_labelled_open.h b/lib/utils/include/utils/graph/labelled/node_labelled_open.h index 4d8c790400..2cbaaf44fd 100644 --- a/lib/utils/include/utils/graph/labelled/node_labelled_open.h +++ b/lib/utils/include/utils/graph/labelled/node_labelled_open.h @@ -9,8 +9,8 @@ template struct INodeLabelledOpenMultiDiGraphView : virtual INodeLabelledMultiDiGraphView, virtual IOpenMultiDiGraphView { - INodeLabelledOpenMultiDiGraphView(INodeLabelledOpenMultiDiGraphView const &) = - delete; + INodeLabelledOpenMultiDiGraphView() = default; + INodeLabelledOpenMultiDiGraphView(INodeLabelledOpenMultiDiGraphView const &) = delete; 
INodeLabelledOpenMultiDiGraphView & operator=(INodeLabelledOpenMultiDiGraphView const &) = delete; }; @@ -82,12 +82,12 @@ struct NodeLabelledOpenMultiDiGraph } std::unordered_set query_nodes(NodeQuery const &q) const { - return get_ptr()->query_nodes(); + return get_ptr()->query_nodes(q); } std::unordered_set query_edges(OpenMultiDiEdge const &q) const { - return get_ptr()->query_edges(); + return get_ptr()->query_edges(q); } Node add_node(NodeLabel const &l) { diff --git a/lib/utils/include/utils/graph/labelled/open_views.h b/lib/utils/include/utils/graph/labelled/open_views.h index 4a4c81aef9..8c8a8b1a1b 100644 --- a/lib/utils/include/utils/graph/labelled/open_views.h +++ b/lib/utils/include/utils/graph/labelled/open_views.h @@ -26,6 +26,10 @@ struct OutputLabelledOpenMultiDiSubgraphView return g.at(n); } + EdgeLabel const &at(InputMultiDiEdge const &i) const override { + return g.at(i); + } + EdgeLabel const &at(MultiDiOutput const &o) const override { return g.at(o); } @@ -39,11 +43,17 @@ struct OutputLabelledOpenMultiDiSubgraphView return SubgraphView(g, nodes).query_edges(q); } + OutputLabelledOpenMultiDiSubgraphView* clone() const override { + return new OutputLabelledOpenMultiDiSubgraphView(g, nodes); + } + private: OutputLabelledOpenMultiDiGraphView const &g; std::unordered_set const &nodes; }; +// CHECK_NOT_ABSTRACT(OutputLabelledOpenMultiDiSubgraphView); + } // namespace FlexFlow #endif diff --git a/lib/utils/include/utils/graph/labelled/output_labelled.h b/lib/utils/include/utils/graph/labelled/output_labelled.h index a9aa1e5251..9b35cdc883 100644 --- a/lib/utils/include/utils/graph/labelled/output_labelled.h +++ b/lib/utils/include/utils/graph/labelled/output_labelled.h @@ -15,6 +15,7 @@ struct IOutputLabelledMultiDiGraphView operator=(IOutputLabelledMultiDiGraphView const &) = delete; virtual OutputLabel const &at(MultiDiOutput const &) = 0; + using INodeLabelledMultiDiGraphView::at; }; CHECK_RC_COPY_VIRTUAL_COMPLIANT(IOutputLabelledMultiDiGraphView); 
@@ -119,10 +120,10 @@ struct OutputLabelledMultiDiGraph } std::unordered_set query_nodes(NodeQuery const &q) const { - return this->ptr->query_nodes(q); + return get_ptr()->query_nodes(q); } std::unordered_set query_edges(MultiDiEdgeQuery const &q) const { - return this->ptr->query_edges(q); + return get_ptr()->query_edges(q); } template diff --git a/lib/utils/include/utils/graph/labelled/output_labelled_open.h b/lib/utils/include/utils/graph/labelled/output_labelled_open.h index b3ecb5f273..28dba47bce 100644 --- a/lib/utils/include/utils/graph/labelled/output_labelled_open.h +++ b/lib/utils/include/utils/graph/labelled/output_labelled_open.h @@ -1,7 +1,7 @@ #ifndef _FLEXFLOW_UTILS_GRAPH_LABELLED_OUTPUT_LABELLED_OPEN #define _FLEXFLOW_UTILS_GRAPH_LABELLED_OUTPUT_LABELLED_OPEN -#include "node_labelled.h" +#include "node_labelled_open.h" #include "utils/graph/adjacency_openmultidigraph.h" namespace FlexFlow { @@ -59,6 +59,7 @@ struct OutputLabelledOpenMultiDiGraphView protected: using NodeLabelledOpenMultiDiGraphView< NodeLabel>::NodeLabelledOpenMultiDiGraphView; + OutputLabelledOpenMultiDiGraphView(cow_ptr_t ptr) : GraphView(ptr) {} private: cow_ptr_t get_ptr() const { @@ -70,7 +71,7 @@ struct OutputLabelledOpenMultiDiGraphView template EdgeLabel at(OutputLabelledOpenMultiDiGraphView const &g, OpenMultiDiEdge const &e) { - return visit([&](auto const e) { return g.at(e); }, e); + return visit([&](auto const &e) { return g.at(e); }, e); } template @@ -173,6 +174,11 @@ struct OutputLabelledOpenMultiDiGraph cow_ptr_t ol; }; +template +void add_label(OutputLabelledOpenMultiDiGraph &g, OpenMultiDiEdge const &e, EdgeLabel const &l) { + visit([&](const auto &e) { g.add_label(e, l); }, e); +} + } // namespace FlexFlow #endif diff --git a/lib/utils/include/utils/graph/labelled/views.h b/lib/utils/include/utils/graph/labelled/views.h index 0426a73e73..5a227d46ec 100644 --- a/lib/utils/include/utils/graph/labelled/views.h +++ b/lib/utils/include/utils/graph/labelled/views.h @@ 
-3,6 +3,7 @@ #include "node_labelled.h" #include "standard_labelled.h" +#include "output_labelled_open.h" namespace FlexFlow { @@ -70,7 +71,7 @@ CHECK_NOT_ABSTRACT(ViewMultiDiGraphAsOutputLabelled Impl materialize_output_labelled_multidigraph_view( - IOutputLabelledMultiDiGraphView const &g) { + OutputLabelledMultiDiGraphView const &g) { Impl result; for (Node const &n : get_nodes(g)) { result.add_node_unsafe(n); From c015efb50979f3f97ec81908868078296223968e Mon Sep 17 00:00:00 2001 From: wmdi Date: Wed, 15 Nov 2023 17:12:53 -0500 Subject: [PATCH 02/37] unity dp works --- lib/compiler/CMakeLists.txt | 1 + lib/compiler/include/compiler/cost_estimate.h | 4 +- .../include/compiler/machine_mapping.h | 2 +- .../include/compiler/unity_algorithm.h | 3 +- lib/compiler/src/graph_utils.cc | 16 + lib/compiler/src/machine_mapping.cc | 26 +- lib/compiler/src/unity_algorithm.cc | 18 +- lib/compiler/test/CMakeLists.txt | 5 +- .../test/{ => src}/test_cost_estimator.h | 0 lib/compiler/test/{ => src}/test_generator.h | 2 +- .../test/src/test_labelled_open_graph.cc | 77 +++++ .../test/{ => src}/test_machine_mapping.cc | 2 +- lib/compiler/test/src/test_open_graph.cc | 80 +++++ lib/compiler/test/src/test_optimal_cost.cc | 60 ++++ .../test/{ => src}/test_unity_algorithm.cc | 0 lib/compiler/test/test_disjoint_set.cc | 19 -- lib/compiler/test/test_dominators.cc | 322 ------------------ lib/compiler/test/test_dot.cc | 23 -- lib/compiler/test/test_dp.cc | 54 --- lib/compiler/test/test_labelled_open_graph.cc | 76 ----- lib/compiler/test/test_machine_view.cc | 33 -- lib/compiler/test/test_open_graph.cc | 102 ------ lib/compiler/test/test_optimal_cost.cc | 24 -- lib/compiler/test/test_parallel_config.cc | 25 -- lib/compiler/test/test_random_utils.cc | 47 --- lib/compiler/test/test_substitution_loader.cc | 144 -------- lib/op-attrs/src/get_output_shapes.cc | 6 + lib/pcg/include/pcg/machine_specification.h | 19 +- lib/pcg/include/pcg/machine_view.h | 6 +- lib/pcg/include/pcg/operator.h | 3 +- 
lib/pcg/include/pcg/strided_rectangle.h | 23 +- lib/pcg/src/machine_view.cc | 4 +- lib/pcg/src/operator.cc | 2 +- lib/pcg/src/parallel_computation_graph.cc | 37 ++ lib/pcg/src/strided_rectangle.cc | 6 +- lib/utils/include/utils/graph/digraph.h | 2 +- .../utils/graph/labelled/node_labelled.h | 8 +- .../utils/graph/labelled/node_labelled_open.h | 7 +- .../include/utils/graph/labelled/open_views.h | 39 +++ .../utils/graph/labelled/output_labelled.h | 30 +- .../graph/labelled/output_labelled_open.h | 20 +- .../utils/graph/labelled/standard_labelled.h | 28 +- .../utils/graph/labelled/unordered_label.h | 3 +- .../include/utils/graph/labelled/views.h | 22 +- lib/utils/include/utils/graph/multidigraph.h | 2 +- lib/utils/include/utils/graph/open_graphs.h | 9 +- lib/utils/include/utils/graph/undirected.h | 2 +- lib/utils/src/graph/algorithms.cc | 35 +- lib/utils/src/graph/digraph.cc | 7 +- lib/utils/src/graph/multidigraph.cc | 7 +- lib/utils/src/graph/node.cc | 2 +- lib/utils/src/graph/open_graphs.cc | 27 +- lib/utils/src/graph/serialparallel.cc | 13 +- lib/utils/src/graph/undirected.cc | 6 +- lib/utils/src/graph/views.cc | 24 +- 55 files changed, 552 insertions(+), 1012 deletions(-) rename lib/compiler/test/{ => src}/test_cost_estimator.h (100%) rename lib/compiler/test/{ => src}/test_generator.h (98%) create mode 100644 lib/compiler/test/src/test_labelled_open_graph.cc rename lib/compiler/test/{ => src}/test_machine_mapping.cc (95%) create mode 100644 lib/compiler/test/src/test_open_graph.cc create mode 100644 lib/compiler/test/src/test_optimal_cost.cc rename lib/compiler/test/{ => src}/test_unity_algorithm.cc (100%) delete mode 100644 lib/compiler/test/test_disjoint_set.cc delete mode 100644 lib/compiler/test/test_dominators.cc delete mode 100644 lib/compiler/test/test_dot.cc delete mode 100644 lib/compiler/test/test_dp.cc delete mode 100644 lib/compiler/test/test_labelled_open_graph.cc delete mode 100644 lib/compiler/test/test_machine_view.cc delete mode 100644 
lib/compiler/test/test_open_graph.cc delete mode 100644 lib/compiler/test/test_optimal_cost.cc delete mode 100644 lib/compiler/test/test_parallel_config.cc delete mode 100644 lib/compiler/test/test_random_utils.cc delete mode 100644 lib/compiler/test/test_substitution_loader.cc create mode 100644 lib/pcg/src/parallel_computation_graph.cc diff --git a/lib/compiler/CMakeLists.txt b/lib/compiler/CMakeLists.txt index 45c369fcdf..6610834eed 100644 --- a/lib/compiler/CMakeLists.txt +++ b/lib/compiler/CMakeLists.txt @@ -18,3 +18,4 @@ ff_add_library( ) add_subdirectory(ffi) +add_subdirectory(test) \ No newline at end of file diff --git a/lib/compiler/include/compiler/cost_estimate.h b/lib/compiler/include/compiler/cost_estimate.h index 27f963db50..3791292529 100644 --- a/lib/compiler/include/compiler/cost_estimate.h +++ b/lib/compiler/include/compiler/cost_estimate.h @@ -16,10 +16,11 @@ struct ICostEstimator { MachineView const &src, MachineView const &dst) const = 0; + ICostEstimator() = default; ICostEstimator(ICostEstimator const &) = delete; ICostEstimator &operator=(ICostEstimator const &) = delete; - virtual ~ICostEstimator(); + virtual ~ICostEstimator() = default; }; CHECK_RC_COPY_VIRTUAL_COMPLIANT(ICostEstimator); @@ -44,6 +45,7 @@ struct CostEstimator { } private: + CostEstimator(std::shared_ptr implementation_ptr) : implementation_ptr(implementation_ptr) {} std::shared_ptr implementation_ptr; }; diff --git a/lib/compiler/include/compiler/machine_mapping.h b/lib/compiler/include/compiler/machine_mapping.h index e8d7457fbf..9f9d97937d 100644 --- a/lib/compiler/include/compiler/machine_mapping.h +++ b/lib/compiler/include/compiler/machine_mapping.h @@ -15,7 +15,7 @@ struct MachineMapping { static MachineMapping combine(MachineMapping const &, MachineMapping const &); static bool nodes_are_disjoint(MachineMapping const &m1, MachineMapping const &m2); - + req> machine_views; }; FF_VISITABLE_STRUCT(MachineMapping, machine_views); diff --git 
a/lib/compiler/include/compiler/unity_algorithm.h b/lib/compiler/include/compiler/unity_algorithm.h index fc068d48c5..81e8375948 100644 --- a/lib/compiler/include/compiler/unity_algorithm.h +++ b/lib/compiler/include/compiler/unity_algorithm.h @@ -37,6 +37,7 @@ Strategy } // namespace FlexFlow VISITABLE_STRUCT(FlexFlow::Strategy, pcg, machine_mapping, runtime); + namespace std { template <> @@ -44,6 +45,6 @@ struct hash { size_t operator()(FlexFlow::Strategy const &) const; }; -}; +} #endif diff --git a/lib/compiler/src/graph_utils.cc b/lib/compiler/src/graph_utils.cc index d7f15e0796..04e96c66ed 100644 --- a/lib/compiler/src/graph_utils.cc +++ b/lib/compiler/src/graph_utils.cc @@ -7,6 +7,22 @@ SerialParallelDecomposition return get_serial_parallel_decomposition(pcg.value()); } +ParallelComputationGraph cg_to_pcg(ComputationGraph const &g) { + NOT_IMPLEMENTED(); +} + +SubParallelComputationGraph pcg_to_subpcg(ParallelComputationGraph const &pcg) { + auto g = pcg.value(); + auto g_ = view_output_labelled_as_output_labelled_open(g); + auto subpcg = materialize_output_labelled_open_multidigraph_view< + AdjacencyOpenMultiDiGraph, + UnorderedLabelling, + UnorderedLabelling, + UnorderedLabelling + >(g_); + return subpcg; +} + std::vector get_sorted_node_input_edges(ParallelComputationGraph const &pcg, Node const &n) { diff --git a/lib/compiler/src/machine_mapping.cc b/lib/compiler/src/machine_mapping.cc index fb04f57eac..671c59a94f 100644 --- a/lib/compiler/src/machine_mapping.cc +++ b/lib/compiler/src/machine_mapping.cc @@ -95,7 +95,7 @@ float estimate_cost(SubParallelComputationGraphView const &g, CostEstimator const &estimator, MachineMapping const &device_mapping, std::unordered_map const &frontier_machine_views) { - NOT_IMPLEMENTED(); + return 0.1; } void minimize_runtime(OptimalCostResult &m1, OptimalCostResult const &m2) { @@ -122,8 +122,8 @@ struct OptimalCost { SubParallelComputationGraphView const &g; CostEstimator const &cost_estimator; MachineSpecification 
const &resource; - std::unordered_map const &given_machine_views; - std::unordered_map const &frontier_machine_views; + std::unordered_map given_machine_views; + std::unordered_map frontier_machine_views; std::function( Operator const &, MachineSpecification const &)> const &allowed_machine_views; @@ -138,7 +138,6 @@ struct OptimalCost { if (cached_result) { return cached_result.value(); } - OptimalCostResult result = this->optimal_cost(t); cached_subgraph_costs.save(state, result); @@ -161,14 +160,15 @@ struct OptimalCost { Node split_point = get_only(post_graph_sources); OutputMultiDiEdge split_edge = get_only(get_open_outputs(pre_graph)); - + OptimalCostResult optimal_result = OptimalCostResult::infinity(); for (MachineView const &mv : allowed_machine_views(g.at(split_point), resource)) { - auto new_given_machine_views = merge_maps(given_machine_views, std::unordered_map{{split_point, mv}}); - auto new_frontier_machine_views = merge_maps(frontier_machine_views, - std::unordered_map{{split_edge, mv}}); + std::unordered_map new_given_machine_views = given_machine_views; + new_given_machine_views.emplace(split_point, mv); + std::unordered_map new_frontier_machine_views = frontier_machine_views; + new_frontier_machine_views.emplace(split_edge, mv); minimize_runtime(optimal_result, OptimalCostResult::sequential_combine( visit(OptimalCost(pre_graph, @@ -269,14 +269,16 @@ OptimalCostResult CostEstimator const &cost_estimator, MachineSpecification const &resources, OptimalCostCache &cached_subgraph_costs) { - return visit(OptimalCost(pcg_to_subpcg(g), + SerialParallelDecomposition sp_decomposition = get_serial_parallel_decomposition(g); + SubParallelComputationGraph subpcg = pcg_to_subpcg(g); + return visit(OptimalCost(subpcg, cost_estimator, resources, - {}, - {}, + std::unordered_map{}, + std::unordered_map{}, allowed_machine_views, cached_subgraph_costs), - get_serial_parallel_decomposition(g)); + sp_decomposition); } } // namespace FlexFlow diff --git 
a/lib/compiler/src/unity_algorithm.cc b/lib/compiler/src/unity_algorithm.cc index 9d648ed99b..16671b080a 100644 --- a/lib/compiler/src/unity_algorithm.cc +++ b/lib/compiler/src/unity_algorithm.cc @@ -10,7 +10,9 @@ bool StrategyRuntimeCmp::operator()(Strategy const &lhs, Strategy const &rhs) { } std::unordered_set - get_all_substitutions(ParallelComputationGraph const &pcg); + get_all_substitutions(ParallelComputationGraph const &pcg) { + NOT_IMPLEMENTED(); +} std::unordered_set apply_substitution(ParallelComputationGraph const &pcg, @@ -73,3 +75,17 @@ Strategy } } // namespace FlexFlow + +namespace std { + +size_t hash::operator()(FlexFlow::Strategy const &s) const { + size_t h = 0; + + hash_combine(h, s.pcg); + // hash_combine(h, s.machine_mapping); + hash_combine(h, s.runtime); + + return h; +} + +} diff --git a/lib/compiler/test/CMakeLists.txt b/lib/compiler/test/CMakeLists.txt index dbbd0a63ec..cc64b15f7d 100644 --- a/lib/compiler/test/CMakeLists.txt +++ b/lib/compiler/test/CMakeLists.txt @@ -2,10 +2,13 @@ ff_add_test_executable( NAME compiler-test SRC_PATTERNS - src/*.cc + src/test_labelled_open_graph.cc + src/test_open_graph.cc + src/test_optimal_cost.cc PRIVATE_INCLUDE src/ DEPS + utils compiler doctest utils-test-common diff --git a/lib/compiler/test/test_cost_estimator.h b/lib/compiler/test/src/test_cost_estimator.h similarity index 100% rename from lib/compiler/test/test_cost_estimator.h rename to lib/compiler/test/src/test_cost_estimator.h diff --git a/lib/compiler/test/test_generator.h b/lib/compiler/test/src/test_generator.h similarity index 98% rename from lib/compiler/test/test_generator.h rename to lib/compiler/test/src/test_generator.h index 374bb89455..23a79abbe0 100644 --- a/lib/compiler/test/test_generator.h +++ b/lib/compiler/test/src/test_generator.h @@ -2,7 +2,7 @@ #define _FLEXFLOW_TEST_GENERATOR_H #include "compiler/machine_mapping.h" -#include "compiler/sub_parallel_computation_graph.h" +#include 
"substitutions/sub_parallel_computation_graph.h" #include "pcg/computation_graph.h" #include "rapidcheck.h" diff --git a/lib/compiler/test/src/test_labelled_open_graph.cc b/lib/compiler/test/src/test_labelled_open_graph.cc new file mode 100644 index 0000000000..78ea1ece55 --- /dev/null +++ b/lib/compiler/test/src/test_labelled_open_graph.cc @@ -0,0 +1,77 @@ +#include "compiler/unity_algorithm.h" +#include "doctest/doctest.h" +#include "rapidcheck.h" + +using namespace FlexFlow; + +TEST_CASE("get_subgraph_open_graph") { + auto g = OpenMultiDiGraph::create(); + + int t0 = 100000; + + Node n0 = g.add_node(); + Node n1 = g.add_node(); + Node n2 = g.add_node(); + Node n3 = g.add_node(); + Node n4 = g.add_node(); + + NodePort p0 = g.add_node_port(); + NodePort p1 = g.add_node_port(); + NodePort p2 = g.add_node_port(); + NodePort p3 = g.add_node_port(); + NodePort p4 = g.add_node_port(); + NodePort p5 = g.add_node_port(); + NodePort p6 = g.add_node_port(); + NodePort p7 = g.add_node_port(); + NodePort p8 = g.add_node_port(); + NodePort p9 = g.add_node_port(); + + MultiDiEdge e0{n1, p1, n0, p0}; + MultiDiEdge e1{n2, p2, n0, p0}; + MultiDiEdge e2{n3, p5, n1, p3}; + MultiDiEdge e3{n3, p6, n2, p4}; + MultiDiEdge e4{n4, p8, n3, p7}; + OutputMultiDiEdge e5{n4, p9, std::make_pair(p9.value(), p9.value())}; + + g.add_edge(e0); + g.add_edge(e1); + g.add_edge(e2); + g.add_edge(e3); + g.add_edge(e4); + g.add_edge(e5); + + std::unordered_set node_set0{n3, n4}; + + auto subgraph0 = get_subgraph(g, node_set0); + auto subgraph1 = get_subgraph(g, node_set0); + auto subgraph2 = get_subgraph(g, node_set0); + auto subgraph3 = get_subgraph(g, node_set0); + + CHECK(get_nodes(subgraph0) == node_set0); + CHECK(get_nodes(subgraph1) == node_set0); + CHECK(get_nodes(subgraph2) == node_set0); + CHECK(get_nodes(subgraph3) == node_set0); + + std::unordered_set input_set{split_edge(e2).second, + split_edge(e3).second}; + std::unordered_set output_set{e5}; + + CHECK(bool(get_open_inputs(subgraph0) == 
input_set)); + CHECK(bool(get_open_inputs(subgraph1) == input_set)); + CHECK(bool(get_open_inputs(subgraph2).empty())); + CHECK(bool(get_open_inputs(subgraph3).empty())); + + CHECK(bool(get_open_outputs(subgraph0) == output_set)); + CHECK(bool(get_open_outputs(subgraph1).empty())); + CHECK(bool(get_open_outputs(subgraph2) == output_set)); + CHECK(bool(get_open_outputs(subgraph3).empty())); + + CHECK(bool(get_edges(subgraph0) == + std::unordered_set{ + split_edge(e2).second, split_edge(e3).second, e4, e5})); + CHECK(bool(get_edges(subgraph1) == + std::unordered_set{ + split_edge(e2).second, split_edge(e3).second, e4})); + CHECK(bool(get_edges(subgraph2) == std::unordered_set{e4, e5})); + CHECK(bool(get_edges(subgraph3) == std::unordered_set{e4})); +} diff --git a/lib/compiler/test/test_machine_mapping.cc b/lib/compiler/test/src/test_machine_mapping.cc similarity index 95% rename from lib/compiler/test/test_machine_mapping.cc rename to lib/compiler/test/src/test_machine_mapping.cc index 4436a992d3..779f8134d9 100644 --- a/lib/compiler/test/test_machine_mapping.cc +++ b/lib/compiler/test/src/test_machine_mapping.cc @@ -1,4 +1,4 @@ -#include "doctest.h" +#include "doctest/doctest.h" #include "test_generator.h" TEST_CASE("MachineMapping::combine") { diff --git a/lib/compiler/test/src/test_open_graph.cc b/lib/compiler/test/src/test_open_graph.cc new file mode 100644 index 0000000000..ea1108c291 --- /dev/null +++ b/lib/compiler/test/src/test_open_graph.cc @@ -0,0 +1,80 @@ +#include "compiler/unity_algorithm.h" +#include "doctest/doctest.h" +#include "utils/graph/algorithms.h" + +using namespace FlexFlow; + +TEST_CASE("get_source_sink_open_graph") { + OpenMultiDiGraph g = OpenMultiDiGraph::create(); + + int s0 = 100000; + + Node n0 = g.add_node(); + NodePort p0 = g.add_node_port(); + InputMultiDiEdge e0{n0, g.add_node_port(), std::make_pair(n0.value(), n0.value())}; + g.add_edge(e0); + + CHECK(bool(get_closed_sources(g) == std::unordered_set{})); + 
CHECK(bool(get_closed_sinks(g) == std::unordered_set{n0})); + + CHECK(bool(get_open_sources(g) == std::unordered_set{n0})); + CHECK(bool(get_open_sinks(g) == std::unordered_set{})); +} + +TEST_CASE("get_source_sink_open_graph:unconnected") { + OpenMultiDiGraph g = OpenMultiDiGraph::create(); + + int s0 = 100000; + int t0 = s0 + 1; + + Node n0 = g.add_node(); + Node n1 = g.add_node(); + + NodePort p0 = g.add_node_port(); + NodePort p1 = g.add_node_port(); + + InputMultiDiEdge e0{n0, p0, std::make_pair(p0.value(), p0.value())}; + OutputMultiDiEdge e1{n1, p1, std::make_pair(p1.value(), p1.value())}; + g.add_edge(e0); + g.add_edge(e1); + + /* + g: ->n0 + n1-> + */ + + CHECK(bool(get_closed_sources(g) == std::unordered_set{n1})); + CHECK(bool(get_closed_sinks(g) == std::unordered_set{n0})); + + CHECK(bool(get_open_sources(g) == std::unordered_set{n0})); + CHECK(bool(get_open_sinks(g) == std::unordered_set{n1})); +} + +TEST_CASE("get_cut") { + auto g = OpenMultiDiGraph::create(); + + std::vector ns; + for (int i = 0; i < 5; ++i) { + ns.push_back(g.add_node()); + } + + MultiDiEdge e0{ns[1], g.add_node_port(), ns[0], g.add_node_port()}; + MultiDiEdge e1{ns[2], g.add_node_port(), ns[1], g.add_node_port()}; + MultiDiEdge e2{ns[3], g.add_node_port(), ns[1], g.add_node_port()}; + MultiDiEdge e3{ns[4], g.add_node_port(), ns[2], g.add_node_port()}; + MultiDiEdge e4{ns[4], g.add_node_port(), ns[3], g.add_node_port()}; + OutputMultiDiEdge e5{ns[4], g.add_node_port(), std::make_pair(ns[4].value(), ns[4].value())}; + + g.add_edge(e0); + g.add_edge(e1); + g.add_edge(e2); + g.add_edge(e3); + g.add_edge(e4); + g.add_edge(e5); + + GraphSplit gs0{{ns[0], ns[1]}, {ns[2], ns[3], ns[4]}}; + CHECK(bool(get_cut_set(g, gs0) == std::unordered_set{e1, e2})); + + GraphSplit gs1{{ns[0], ns[1], ns[2], ns[3]}, {ns[4]}}; + CHECK(bool(get_cut_set(g, gs1) == std::unordered_set{e3, e4})); +} diff --git a/lib/compiler/test/src/test_optimal_cost.cc b/lib/compiler/test/src/test_optimal_cost.cc new file 
mode 100644 index 0000000000..87f9d06342 --- /dev/null +++ b/lib/compiler/test/src/test_optimal_cost.cc @@ -0,0 +1,60 @@ +#include "compiler/unity_algorithm.h" +#include "doctest/doctest.h" +#include "test_cost_estimator.h" + +using namespace FlexFlow; + +/* +Tests whether optimal_cost can give a valid result given random PCG, trivial +allowed machine views, trivial cost estimator and random machine specification. +*/ +// TEST_CASE("optimal_cost") { +// auto test_allowed_machine_views = [](Operator const &, +// MachineSpecification const &) { +// return std::unordered_set{make_1d_machine_view(0, 1, 1)}; +// }; +// rc::check([](ParallelComputationGraph const &g, +// MachineSpecification const &machine_spec) { +// OptimalCostCache cached_subgraph_costs; +// OptimalCostResult result = optimal_cost(g, +// test_allowed_machine_views, +// TestCostEstimator{}, +// machine_spec, +// cached_subgraph_costs); +// RC_ASSERT(result.runtime > 0); +// RC_ASSERT(keys(result.machine_mapping.machine_views) == get_nodes(g)); +// }); +// } + +TEST_CASE("optimal_cost_0") { + auto pcg = OutputLabelledMultiDiGraph::template create< + AdjacencyMultiDiGraph, + UnorderedLabelling, + UnorderedLabelling + >(); + + Node n0 = pcg.add_node(Operator(InputAttrs{}, "input")); + Node n1 = pcg.add_node(Operator(LinearAttrs{1, false, DataType::FLOAT, Activation::RELU, nullopt}, "linear")); + + MultiDiEdge e{n1, pcg.add_node_port(), n0, pcg.add_node_port()}; + pcg.add_edge(e); + pcg.add_output(e, + ParallelTensor(ParallelTensorDims({2, 1}), + DataType::FLOAT, + CreateGrad::YES)); + + auto test_allowed_machine_views = [](Operator const &, + MachineSpecification const &) { + return std::unordered_set{make_1d_machine_view(gpu_id_t(1), gpu_id_t(2))}; + }; + + CostEstimator estimator = CostEstimator::create(); + + MachineSpecification machine_spec{1, 1, 1, 1, 1}; + + OptimalCostCache cached_results; + + OptimalCostResult result = optimal_cost(ParallelComputationGraph(pcg), test_allowed_machine_views, 
estimator, machine_spec, cached_results); + + CHECK(bool(result.runtime > 0)); +} \ No newline at end of file diff --git a/lib/compiler/test/test_unity_algorithm.cc b/lib/compiler/test/src/test_unity_algorithm.cc similarity index 100% rename from lib/compiler/test/test_unity_algorithm.cc rename to lib/compiler/test/src/test_unity_algorithm.cc diff --git a/lib/compiler/test/test_disjoint_set.cc b/lib/compiler/test/test_disjoint_set.cc deleted file mode 100644 index 796605f53f..0000000000 --- a/lib/compiler/test/test_disjoint_set.cc +++ /dev/null @@ -1,19 +0,0 @@ -#include "flexflow/utils/disjoint_set.h" -#include "gtest/gtest.h" - -TEST(disjoint_set, basic) { - int ctr = 0; - int a = ctr++, b = ctr++, c = ctr++, d = ctr++, e = ctr++, f = ctr++; - - disjoint_set ds; - ds.m_union(a, b); - ds.m_union(b, c); - ds.m_union(e, f); - ds.m_union(d, a); - - assert(ds.find(a) == ds.find(b)); - assert(ds.find(a) == ds.find(c)); - assert(ds.find(a) == ds.find(d)); - assert(ds.find(e) == ds.find(f)); - assert(ds.find(e) != ds.find(a)); -} diff --git a/lib/compiler/test/test_dominators.cc b/lib/compiler/test/test_dominators.cc deleted file mode 100644 index 60ac33696f..0000000000 --- a/lib/compiler/test/test_dominators.cc +++ /dev/null @@ -1,322 +0,0 @@ -#include "flexflow/basic_graph.h" -#include "flexflow/dominators.h" -#include "flexflow/utils/hash-utils.h" -#include "gtest/gtest.h" - -using namespace FlexFlow::PCG::Utils; - -namespace FlexFlow::PCG::Utils { -template <> -struct invalid_node<::BasicGraph, GraphStructure<::BasicGraph>> { - int operator()() const { - return -1; - } -}; -} // namespace FlexFlow::PCG::Utils - -TEST(pred_succ_cessors, basic) { - BasicGraph g; - g.add_node(0); - g.add_node(1); - g.add_node(2); - g.add_node(3); - g.add_node(4); - - g.add_edge(0, 2); - g.add_edge(1, 2); - g.add_edge(2, 3); - g.add_edge(2, 4); - - using AnswerMap = std::unordered_map>; - - AnswerMap expected_predecessors; - - expected_predecessors = {{0, {}}, {1, {}}, {2, {0, 1}}, {3, 
{2}}, {4, {2}}}; - - AnswerMap expected_successors = { - {0, {2}}, {1, {2}}, {2, {3, 4}}, {3, {}}, {4, {}}}; - - std::unordered_set answer; - for (auto const &kv : expected_predecessors) { - answer.clear(); - predecessors>(g, kv.first, &answer); - EXPECT_EQ(kv.second, answer) - << "^^^ Predecessors for node " << kv.first << std::endl; - } - for (auto const &kv : expected_successors) { - answer.clear(); - successors>(g, kv.first, &answer); - EXPECT_EQ(kv.second, answer) - << "^^^ Successors for node " << kv.first << std::endl; - } -} - -TEST(topo_sort, basic) { - BasicGraph g; - g.add_nodes({0, 1, 2, 3}); - g.add_edges({{3, 1}, {3, 0}, {1, 0}, {0, 2}}); - - std::vector topo_answer = {3, 1, 0, 2}; - - std::vector topo_result; - topo_sort(g, &topo_result); - EXPECT_EQ(topo_result, topo_answer); -} - -BasicGraph get_dominator_test_graph() { - BasicGraph g; - g.add_nodes({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}); - g.add_edges({{1, 2}, - {1, 7}, - {2, 3}, - {2, 4}, - {3, 6}, - {4, 5}, - {4, 6}, - {5, 6}, - {6, 8}, - {7, 8}, - {8, 9}, - {8, 10}, - {9, 11}, - {10, 11}}); - - return g; -} - -TEST(dominators, basic) { - BasicGraph g = get_dominator_test_graph(); - - std::unordered_map> answer = {{1, {1}}, - {2, {1, 2}}, - {3, {1, 2, 3}}, - {4, {1, 2, 4}}, - {5, {1, 2, 4, 5}}, - {6, {1, 2, 6}}, - {7, {1, 7}}, - {8, {1, 8}}, - {9, {1, 8, 9}}, - {10, {1, 8, 10}}, - {11, {1, 8, 11}}}; - - EXPECT_EQ(dominators(g), answer); -} - -TEST(post_dominators, basic) { - BasicGraph g = get_dominator_test_graph(); - - std::unordered_map> answer = {{1, {1, 8, 11}}, - {2, {2, 6, 8, 11}}, - {3, {3, 6, 8, 11}}, - {4, {4, 6, 8, 11}}, - {5, {5, 6, 8, 11}}, - {6, {6, 8, 11}}, - {7, {7, 8, 11}}, - {8, {8, 11}}, - {9, {9, 11}}, - {10, {10, 11}}, - {11, {11}}}; - - EXPECT_EQ(post_dominators(g), answer); -} - -TEST(imm_dominators, basic) { - BasicGraph g = get_dominator_test_graph(); - - std::unordered_map answer = {{1, 1}, // no immediate dominator - {2, 1}, - {3, 2}, - {4, 2}, - {5, 4}, - {6, 2}, - {7, 
1}, - {8, 1}, - {9, 8}, - {10, 8}, - {11, 8}}; - - EXPECT_EQ(imm_dominators(g), answer); -} - -TEST(imm_post_dominators, basic) { - BasicGraph g = get_dominator_test_graph(); - - std::unordered_map answer = { - {1, 8}, - {2, 6}, - {3, 6}, - {4, 6}, - {5, 6}, - {6, 8}, - {7, 8}, - {8, 11}, - {9, 11}, - {10, 11}, - {11, 11} // no immediate post - // dominator - }; - - EXPECT_EQ(imm_post_dominators(g), answer); -} - -TEST(imm_post_dominators, multisource) { - BasicGraph g; - - g.add_nodes({1, 2, 3, 4, 5}); - g.add_edges({{1, 3}, {2, 3}, {3, 4}, {3, 5}}); - - std::unordered_map answer = { - {-1, 3}, {1, 3}, {2, 3}, {3, 3}, {4, 4}, {5, 5}}; - - auto result = - imm_post_dominators>( - g); - EXPECT_EQ(result, answer); -} - -TEST(transitive_reduction, basic) { - BasicGraph g({1, 2, 3}, {{1, 2}, {2, 3}, {1, 3}}); - - BasicGraph answer({1, 2, 3}, {{1, 2}, {2, 3}}); - - auto result = transitive_reduction(g); - - EXPECT_EQ(result, answer); -} - -TEST(transitive_reduction, medium) { - BasicGraph g({1, 2, 3, 4, 5, 6, 7}, - { - {1, 4}, - {1, 5}, - {2, 3}, - {2, 4}, - {2, 6}, - {3, 4}, - {4, 5}, - {4, 6}, - {5, 6}, - }); - - BasicGraph answer({1, 2, 3, 4, 5, 6, 7}, - { - {1, 4}, - {2, 3}, - {3, 4}, - {4, 5}, - {5, 6}, - }); - - auto result = transitive_reduction(g); - - EXPECT_EQ(result, answer); -} - -TEST(inplace_transitive_reduction, basic) { - BasicGraph g({1, 2, 3, 4, 5, 6, 7}, - { - {1, 4}, - {1, 5}, - {2, 3}, - {2, 4}, - {2, 6}, - {3, 4}, - {4, 5}, - {4, 6}, - {5, 6}, - }); - - BasicGraph answer({1, 2, 3, 4, 5, 6, 7}, - { - {1, 4}, - {2, 3}, - {3, 4}, - {4, 5}, - {5, 6}, - }); - - inplace_transitive_reduction(g); - - EXPECT_EQ(g, answer); -} - -TEST(roots, basic) { - BasicGraph g({1, 2, 3, 4, 5, 6}, - { - {1, 3}, - {2, 3}, - {3, 4}, - {3, 5}, - {3, 6}, - }); - - std::unordered_set answer{1, 2}; - - auto result = roots(g); - - EXPECT_EQ(result, answer); -} - -TEST(leaves, basic) { - BasicGraph g({1, 2, 3, 4, 5, 6}, - {{1, 3}, {2, 3}, {3, 4}, {3, 5}, {3, 6}}); - - 
std::unordered_set answer{4, 5, 6}; - - auto result = leaves(g); - - EXPECT_EQ(result, answer); -} - -TEST(descendants, directed) { - BasicGraph g({1, 2, 3, 4, 5, 6}, - {{1, 2}, {2, 3}, {2, 4}, {3, 5}, {4, 5}}); - - std::unordered_set answer{2, 3, 4, 5}; - - auto result = descendants(g, 2); - - EXPECT_EQ(result, answer); -} - -TEST(descendants, undirected) { - BasicGraph g({1, 2, 3, 4, 5, 6}, - {{1, 2}, {2, 3}, {2, 4}, {3, 5}, {4, 5}}); - - std::unordered_set answer{1, 2, 3, 4, 5}; - - auto result = - descendants>(g, 2); - - EXPECT_EQ(result, answer); -} - -TEST(weakly_connected_components, basic) { - BasicGraph g({1, 2, 3, 4, 5, 6}, {{1, 3}, {2, 3}, {4, 5}, {5, 4}}); - - std::unordered_set component1{1, 2, 3}; - std::unordered_set component2{4, 5}; - std::unordered_set component3{6}; - auto result = weakly_connected_components(g); - - EXPECT_EQ(result.size(), 3); - bool component1_found = false; - bool component2_found = false; - bool component3_found = false; - for (std::unordered_set &component : result) { - if (component.size() == component1.size()) { - component1_found = true; - EXPECT_EQ(component, component1); - } else if (component.size() == component2.size()) { - component2_found = true; - EXPECT_EQ(component, component2); - } else if (component.size() == component3.size()) { - component3_found = true; - EXPECT_EQ(component, component3); - } - } - - EXPECT_TRUE(component1_found); - EXPECT_TRUE(component2_found); - EXPECT_TRUE(component3_found); -} diff --git a/lib/compiler/test/test_dot.cc b/lib/compiler/test/test_dot.cc deleted file mode 100644 index 3212971255..0000000000 --- a/lib/compiler/test/test_dot.cc +++ /dev/null @@ -1,23 +0,0 @@ -#include "flexflow/utils/dot/record_formatter.h" -#include "gtest/gtest.h" - -TEST(record_formatters, basic) { - RecordFormatter rf, rf2, rf3; - std::ostringstream oss; - oss << "Wo" - << "rld"; - rf << "Hello" - << "World" - << (rf2 << "Inner" - << "World" - << (rf3 << "Even" - << "More" - << "Inner World")) - << 
"Goodbye" << oss; - - std::ostringstream oss_final; - oss_final << rf; - EXPECT_EQ(oss_final.str(), - "{ Hello | World | { Inner | World | { Even | More | Inner World } " - "} | Goodbye | World }"); -} diff --git a/lib/compiler/test/test_dp.cc b/lib/compiler/test/test_dp.cc deleted file mode 100644 index 01e4189839..0000000000 --- a/lib/compiler/test/test_dp.cc +++ /dev/null @@ -1,54 +0,0 @@ -#include "compiler/unity_algorithm.h" -#include "doctest.h" - -using namespace FlexFlow; - -struct TestCostEstimator : public ICostEstimator { - float estimate_cost(PCGOperatorAttrs const &op, - std::vector const &inputs, - MachineView const &mv) const override { - return 0.1; - } - float estimate_cost(ParallelTensorShape const &tensor_shape, - MachineView const &src, - MachineView const &dst) const override { - return 1; - } -}; - -TEST_CASE("optimal_cost") { - auto g(NodeLabelledMultiDiGraph::create< - UnorderedNodeLabelledMultiDiGraph>()); - - Node n0 = g.add_node(InputAttrs()); - Node n1 = g.add_node(RepartitionAttrs(ff_dim_t(0), 2)); - Node n2 = g.add_node(ElementScalarUnaryAttrs(OP_SCALAR_ADD, 0)); - Node n3 = g.add_node(ElementScalarUnaryAttrs(OP_SCALAR_ADD, 1)); - Node n4 = g.add_node(ConcatAttrs(ff_dim_t(1))); - Node n5 = g.add_node(CombineAttrs(ff_dim_t(0), 2)); - - MultiDiEdge e0(n0, n1, 0, 0); - MultiDiEdge e1(n1, n2, 0, 0); - MultiDiEdge e2(n1, n3, 1, 0); - MultiDiEdge e3(n2, n4, 0, 0); - MultiDiEdge e4(n3, n4, 0, 1); - MultiDiEdge e5(n4, n5, 0, 0); - - g.add_edge(e0); - g.add_edge(e1); - g.add_edge(e2); - g.add_edge(e3); - g.add_edge(e4); - - OptimizerPCG pcg = infer_tensor_shape(g); - auto allowed_machine_views = [](PCGOperatorAttrs const &, - MachineResource const &) { - // TODO - return std::unordered_set{}; - }; - MachineResource resource(1, 1, 2); - Strategy s = - optimal_cost(pcg, allowed_machine_views, TestCostEstimator{}, resource); - - // TODO: check result -} diff --git a/lib/compiler/test/test_labelled_open_graph.cc 
b/lib/compiler/test/test_labelled_open_graph.cc deleted file mode 100644 index 7d85514816..0000000000 --- a/lib/compiler/test/test_labelled_open_graph.cc +++ /dev/null @@ -1,76 +0,0 @@ -#include "compiler/unity_algorithm.h" -#include "doctest.h" - -using namespace FlexFlow; - -TEST_CASE("get_subgraph_labelled_open_graph") { - auto g = LabelledOpenMultiDiGraph::create< - UnorderedLabelledOpenMultiDiGraph>(); - - int t0 = 100000; - - Node n0 = g.add_node(0); - Node n1 = g.add_node(1); - Node n2 = g.add_node(2); - Node n3 = g.add_node(3); - Node n4 = g.add_node(4); - - MultiDiEdge e0(n0, n1, 0, 0); - MultiDiEdge e1(n0, n2, 1, 0); - MultiDiEdge e2(n1, n3, 0, 0); - MultiDiEdge e3(n2, n3, 0, 1); - MultiDiEdge e4(n3, n4, 0, 0); - OutputMultiDiEdge e5({n4.value(), t0}, n4, 0); - - g.add_edge(e0, 0); - g.add_edge(e1, 1); - g.add_edge(e2, 2); - g.add_edge(e3, 3); - g.add_edge(e4, 4); - g.add_edge(e5, 5); - - auto subgraph0 = get_subgraph(g, - std::unordered_set{n3, n4}, - InputSettings::INCLUDE_INPUTS, - OutputSettings::INCLUDE_OUTPUTS); - auto subgraph1 = get_subgraph(g, - std::unordered_set{n3, n4}, - InputSettings::INCLUDE_INPUTS, - OutputSettings::EXCLUDE_OUTPUTS); - auto subgraph2 = get_subgraph(g, - std::unordered_set{n3, n4}, - InputSettings::EXCLUDE_INPUTS, - OutputSettings::INCLUDE_OUTPUTS); - auto subgraph3 = get_subgraph(g, - std::unordered_set{n3, n4}, - InputSettings::EXCLUDE_INPUTS, - OutputSettings::EXCLUDE_OUTPUTS); - - CHECK(get_nodes(subgraph0) == std::unordered_set{n3, n4}); - CHECK(get_nodes(subgraph1) == std::unordered_set{n3, n4}); - CHECK(get_nodes(subgraph2) == std::unordered_set{n3, n4}); - CHECK(get_nodes(subgraph3) == std::unordered_set{n3, n4}); - - std::unordered_set input_set{split_edge(e2).second, - split_edge(e3).second}; - std::unordered_set output_set{e5}; - - CHECK(get_inputs(subgraph0) == input_set); - CHECK(get_inputs(subgraph1) == input_set); - CHECK(get_inputs(subgraph2).empty()); - CHECK(get_inputs(subgraph3).empty()); - - 
CHECK(get_outputs(subgraph0) == output_set); - CHECK(get_outputs(subgraph1).empty()); - CHECK(get_outputs(subgraph2) == output_set); - CHECK(get_outputs(subgraph3).empty()); - - CHECK(get_edges(subgraph0) == - std::unordered_set{ - split_edge(e2).second, split_edge(e3).second, e4, e5}); - CHECK(get_edges(subgraph1) == - std::unordered_set{ - split_edge(e2).second, split_edge(e3).second, e4}); - CHECK(get_edges(subgraph2) == std::unordered_set{e4, e5}); - CHECK(get_edges(subgraph3) == std::unordered_set{e4}); -} diff --git a/lib/compiler/test/test_machine_view.cc b/lib/compiler/test/test_machine_view.cc deleted file mode 100644 index eea084db48..0000000000 --- a/lib/compiler/test/test_machine_view.cc +++ /dev/null @@ -1,33 +0,0 @@ -#include "flexflow/config.h" -#include "flexflow/machine_view.h" -#include "gtest/gtest.h" - -using namespace Legion; -using namespace FlexFlow; - -TEST(machine_view_get_domain, basic) { - MachineView mv; - mv.ndims = 1; - mv.start_device_id = 2; - mv.dim[0] = 2; - mv.stride[0] = 1; - - Domain d; - d.dim = 1; - d.rect_data[0] = 0; - d.rect_data[0 + d.dim] = - 1; // Domain is includes, MachineView is exclusive on hi - - EXPECT_EQ(mv.get_domain(), d); -} - -TEST(machine_view_get_device_id, basic) { - MachineView mv; - mv.ndims = 1; - mv.start_device_id = 2; - mv.dim[0] = 2; - mv.stride[0] = 1; - - EXPECT_EQ(mv.get_device_id({0}), 2); - EXPECT_EQ(mv.get_device_id({1}), 3); -} diff --git a/lib/compiler/test/test_open_graph.cc b/lib/compiler/test/test_open_graph.cc deleted file mode 100644 index d96cdec467..0000000000 --- a/lib/compiler/test/test_open_graph.cc +++ /dev/null @@ -1,102 +0,0 @@ -#include "compiler/unity_algorithm.h" -#include "doctest.h" - -using namespace FlexFlow; - -TEST_CASE("get_source_sink_open_graph:basic") { - OpenMultiDiGraph g(LabelledOpenMultiDiGraph::create< - UnorderedLabelledOpenMultiDiGraph>()); - - int s0 = 100000; - - Node n0 = g.add_node(); - - g.add_edge(InputMultiDiEdge({s0, n0.value()}, n0, 0)); - - 
CHECK(get_closed_sources(g) == std::unordered_set{}); - CHECK(get_closed_sinks(g) == std::unordered_set{n0}); - - CHECK(get_open_sources(g) == std::unordered_set{n0}); - CHECK(get_open_sinks(g) == std::unordered_set{}); -} - -TEST_CASE("get_source_sink_open_graph:unconnected") { - OpenMultiDiGraph g(LabelledOpenMultiDiGraph::create< - UnorderedLabelledOpenMultiDiGraph>()); - int s0 = 100000; - int t0 = s0 + 1; - - Node n0 = g.add_node(); - Node n1 = g.add_node(); - - g.add_edge(InputMultiDiEdge({s0, n0.value()}, n0, 0)); - g.add_edge(OutputMultiDiEdge({n1.value(), t0}, n1, 0)); - - /* - g: ->n0 - n1-> - */ - - CHECK(get_closed_sources(g) == std::unordered_set{n1}); - CHECK(get_closed_sinks(g) == std::unordered_set{n0}); - - CHECK(get_open_sources(g) == std::unordered_set{n0}); - CHECK(get_open_sinks(g) == std::unordered_set{n1}); -} - -TEST_CASE("get_source_sink_open_graph:complex") { - OpenMultiDiGraph g(LabelledOpenMultiDiGraph::create< - UnorderedLabelledOpenMultiDiGraph>()); - int s0 = 100000; - int s1 = s0 + 1; - int t0 = s1 + 1; - int t1 = t0 + 1; - - std::vector ns; - for (int i = 0; i < 8; ++i) { - ns.push_back(g.add_node()); - } - - g.add_edge(InputMultiDiEdge({s0, ns[0].value()}, ns[0], 0)); - g.add_edge(MultiDiEdge(ns[0], ns[1], 0, 0)); - g.add_edge(OutputMultiDiEdge({ns[1].value(), t0}, ns[1], 0)); - g.add_edge(OutputMultiDiEdge({ns[1].value(), t1}, ns[1], 1)); - - g.add_edge(MultiDiEdge(ns[2], ns[3], 0, 0)); - g.add_edge(MultiDiEdge(ns[2], ns[4], 1, 0)); - g.add_edge(MultiDiEdge(ns[4], ns[3], 0, 1)); - g.add_edge(OutputMultiDiEdge({ns[3].value(), t1}, ns[3], 0)); - - g.add_edge(InputMultiDiEdge({s0, ns[5].value()}, ns[5], 0)); - g.add_edge(InputMultiDiEdge({s1, ns[5].value()}, ns[5], 1)); - g.add_edge(MultiDiEdge(ns[5], ns[6], 0, 0)); - g.add_edge(MultiDiEdge(ns[6], ns[7], 0, 0)); - - CHECK(get_closed_sources(g) == std::unordered_set{ns[2]}); - CHECK(get_closed_sinks(g) == std::unordered_set{ns[7]}); - - CHECK(get_open_sources(g) == 
std::unordered_set{ns[1], ns[5]}); - CHECK(get_open_sinks(g) == std::unordered_set{ns[1], ns[3]}); -} - -TEST_CASE("get_cut") { - auto g = LabelledOpenMultiDiGraph::create< - UnorderedLabelledOpenMultiDiGraph>; - - std::vector ns = add_nodes(g, 5); - - int t0 = 100000; - - MultiDiEdge e0(ns[0], ns[1], 0, 0); - MultiDiEdge e1(ns[1], ns[2], 0, 0); - MultiDiEdge e2(ns[1], ns[3], 1, 0); - MultiDiEdge e3(ns[2], ns[4], 0, 0); - MultiDiEdge e4(ns[3], ns[4], 0, 1); - OutputMultiDiEdge e5({ns[4].value(), t0}, ns[4], 0); - - GraphSplit gs0{{ns[0], ns[1]}, {ns[2], ns[3], ns[4]}}; - CHECK(get_cut(g, gs0) == std::unordered_set{e1, e2}); - - GraphSplit gs1{{ns[0], ns[1], ns[2], ns[3]}, {ns[4]}}; - CHECK(get_cut(g, gs1) == std::unordered_set{e3, e4}); -} diff --git a/lib/compiler/test/test_optimal_cost.cc b/lib/compiler/test/test_optimal_cost.cc deleted file mode 100644 index 2d9414ba27..0000000000 --- a/lib/compiler/test/test_optimal_cost.cc +++ /dev/null @@ -1,24 +0,0 @@ -#include "test_cost_estimator.h" -#include "test_generator.h" - -/* -Tests whether optimal_cost can give a valid result given random PCG, trivial -allowed machine views, trivial cost estimator and random machine specification. 
-*/ -TEST_CASE("optimal_cost") { - auto test_allowed_machine_views = [](Operator const &, - MachineSpecification const &) { - return std::unordered_set{make_1d_machine_view(0, 1, 1)}; - }; - rc::check([](ParallelComputationGraph const &g, - MachineSpecification const &machine_spec) { - OptimalCostCache cached_subgraph_costs; - OptimalCostResult result = optimal_cost(g, - test_allowed_machine_views, - TestCostEstimator{}, - machine_spec, - cached_subgraph_costs); - RC_ASSERT(result.runtime > 0); - RC_ASSERT(keys(result.machine_mapping.machine_views) == get_nodes(g)); - }); -} diff --git a/lib/compiler/test/test_parallel_config.cc b/lib/compiler/test/test_parallel_config.cc deleted file mode 100644 index 843879bb0d..0000000000 --- a/lib/compiler/test/test_parallel_config.cc +++ /dev/null @@ -1,25 +0,0 @@ -#include "flexflow/config.h" -#include "flexflow/model.h" -#include "gtest/gtest.h" - -using namespace FlexFlow; - -TEST(change_data_parallel_dimensionality, basic_reduce) { - ParallelConfig pc = get_basic_data_parallel_config(8, 4); - - ParallelConfig expected = get_basic_data_parallel_config(8, 2); - - ParallelConfig result = pc.change_data_parallel_dimensionality(2); - - EXPECT_EQ(result, expected); -} - -TEST(change_data_parallel_dimensionality, basic_expand) { - ParallelConfig pc = get_basic_data_parallel_config(8, 2); - - ParallelConfig expected = get_basic_data_parallel_config(8, 4); - - ParallelConfig result = pc.change_data_parallel_dimensionality(4); - - EXPECT_EQ(result, expected); -} diff --git a/lib/compiler/test/test_random_utils.cc b/lib/compiler/test/test_random_utils.cc deleted file mode 100644 index c7b4f9e5c2..0000000000 --- a/lib/compiler/test/test_random_utils.cc +++ /dev/null @@ -1,47 +0,0 @@ -#include "flexflow/utils/random_utils.h" -#include "gtest/gtest.h" - -TEST(select_random, basic) { - std::vector values{1, 2, 3, 4}; - std::vector weights{0.1, 0.2, 0.3, 0.4}; - - EXPECT_EQ(select_random_determistic(values, weights, 0.05), 1); - 
EXPECT_EQ(select_random_determistic(values, weights, 0.25), 2); - EXPECT_EQ(select_random_determistic(values, weights, 0.5), 3); - EXPECT_EQ(select_random_determistic(values, weights, 0.9), 4); -} - -TEST(select_random, bounds) { - std::vector values{1, 2, 3}; - std::vector weights{0.2, 0.3, 0.5}; - - EXPECT_EQ(select_random_determistic(values, weights, 0.0), 1); - EXPECT_EQ(select_random_determistic(values, weights, 0.2), 2); - EXPECT_EQ(select_random_determistic(values, weights, 0.5), 3); - EXPECT_EQ(select_random_determistic(values, weights, 1.0), 3); -} - -TEST(select_random, singleton) { - std::vector values{1}; - std::vector weights{1.0}; - - EXPECT_EQ(select_random_determistic(values, weights, 0.0), 1); - EXPECT_EQ(select_random_determistic(values, weights, 0.5), 1); - EXPECT_EQ(select_random_determistic(values, weights, 1.0), 1); -} - -TEST(select_random, empty) { - std::vector values{}; - std::vector weights{}; - EXPECT_THROW(select_random_determistic(values, weights, 0.5), - std::invalid_argument); -} - -TEST(select_random, unnormalized_weights) { - std::vector values{1, 2, 3}; - std::vector weights{1.0, 2.0, 2.0}; - - EXPECT_EQ(select_random_determistic(values, weights, 0.1), 1); - EXPECT_EQ(select_random_determistic(values, weights, 0.5), 2); - EXPECT_EQ(select_random_determistic(values, weights, 0.9), 3); -} diff --git a/lib/compiler/test/test_substitution_loader.cc b/lib/compiler/test/test_substitution_loader.cc deleted file mode 100644 index b0531b598a..0000000000 --- a/lib/compiler/test/test_substitution_loader.cc +++ /dev/null @@ -1,144 +0,0 @@ -#include "flexflow/substitution.h" -#include "flexflow/substitution_loader.h" -#include "gtest/gtest.h" - -namespace sl = FlexFlow::substitution_loader; -// using namespace FlexFlow::substitution_loader; -using json = nlohmann::json; -using FlexFlow::PCG::create_xfer; -using FlexFlow::PCG::create_xfers; -using FlexFlow::PCG::GraphXfer; - -TEST(substitution_loader, basic) { - // Yes, I know this substitution 
is not correct. It's just for testing. - - sl::Rule example_rule; - - example_rule.name = "test_rule"; - - sl::Tensor input_tensor1; - input_tensor1.opId = -1; - input_tensor1.tsId = 0; - - sl::Tensor input_tensor2; - input_tensor2.opId = -2; - input_tensor2.tsId = 0; - - sl::Operator srcOp1; - srcOp1.op_type = OP_EW_ADD; - srcOp1.input = {input_tensor1, input_tensor2}; - srcOp1.para = {}; - - sl::Tensor srcOp1Output; - srcOp1Output.opId = 0; - srcOp1Output.tsId = 0; - - sl::Parameter activation_constraint; - activation_constraint.key = PM_ACTI; - activation_constraint.value = AC_MODE_NONE; - - sl::Operator srcOp2; - srcOp2.op_type = OP_LINEAR; - srcOp2.input = {srcOp1Output}; - srcOp2.para = {activation_constraint}; - - sl::Operator dstOp1; - dstOp1.op_type = OP_LINEAR; - dstOp1.input = {input_tensor1}; - dstOp1.para = {activation_constraint}; - - sl::Tensor dstOp1Output; - dstOp1Output.opId = 0; - dstOp1Output.tsId = 0; - - sl::Operator dstOp2; - dstOp2.op_type = OP_LINEAR; - dstOp2.input = {input_tensor2}; - dstOp2.para = {activation_constraint}; - - sl::Tensor dstOp2Output; - dstOp2Output.opId = 1; - dstOp2Output.tsId = 0; - - sl::Operator dstOp3; - dstOp3.op_type = OP_EW_ADD; - dstOp3.input = {dstOp1Output, dstOp2Output}; - dstOp3.para = {}; - - sl::MapOutput map_output; - map_output.srcOpId = 1; - map_output.srcTsId = 0; - map_output.dstOpId = 2; - map_output.dstTsId = 0; - - example_rule.srcOp = {srcOp1, srcOp2}; - example_rule.dstOp = {dstOp1, dstOp2, dstOp3}; - example_rule.mappedOutput = {map_output}; - - GraphXfer *xfer = new GraphXfer(nullptr); - create_xfer(*xfer, example_rule, 2); - - EXPECT_EQ(xfer->name, "test_rule"); - - EXPECT_EQ(xfer->srcOps.size(), 2); - EXPECT_EQ(xfer->srcOps[0]->type, OP_EW_ADD); - EXPECT_EQ(xfer->srcOps[1]->type, OP_LINEAR); - EXPECT_EQ(xfer->srcOps[0]->inputs.size(), 2); - EXPECT_NE(xfer->srcOps[0]->inputs[0], xfer->srcOps[0]->inputs[1]); - EXPECT_EQ(xfer->srcOps[0]->outputs.size(), 1); - 
EXPECT_EQ(xfer->srcOps[1]->inputs.size(), 1); - EXPECT_EQ(xfer->srcOps[0]->outputs[0], xfer->srcOps[1]->inputs[0]); - EXPECT_EQ(xfer->srcOps[1]->outputs.size(), 1); - - EXPECT_EQ(xfer->dstOps.size(), 3); - EXPECT_EQ(xfer->dstOps[0]->type, OP_LINEAR); - EXPECT_EQ(xfer->dstOps[1]->type, OP_LINEAR); - EXPECT_EQ(xfer->dstOps[2]->type, OP_EW_ADD); - EXPECT_EQ(xfer->dstOps[0]->inputs.size(), 1); - EXPECT_EQ(xfer->dstOps[0]->outputs.size(), 1); - EXPECT_EQ(xfer->dstOps[0]->inputs[0], xfer->srcOps[0]->inputs[0]); - EXPECT_EQ(xfer->dstOps[1]->inputs.size(), 1); - EXPECT_EQ(xfer->dstOps[1]->outputs.size(), 1); - EXPECT_EQ(xfer->dstOps[1]->inputs[0], xfer->srcOps[0]->inputs[1]); - EXPECT_EQ(xfer->dstOps[2]->inputs.size(), 2); - EXPECT_EQ(xfer->dstOps[2]->inputs[0], xfer->dstOps[0]->outputs[0]); - EXPECT_EQ(xfer->dstOps[2]->inputs[1], xfer->dstOps[1]->outputs[0]); - EXPECT_NE(xfer->dstOps[2]->inputs[0], xfer->dstOps[2]->inputs[1]); - EXPECT_EQ(xfer->dstOps[2]->outputs.size(), 1); - - EXPECT_EQ(xfer->mappedOutputs.size(), 1); - EXPECT_NE(xfer->srcOps[1]->outputs[0], xfer->dstOps[2]->outputs[0]); - EXPECT_EQ(xfer->mappedOutputs.at(xfer->srcOps[1]->outputs[0]), - xfer->dstOps[2]->outputs[0]); -} - -TEST(substitution_loader, operator_deserialization) { - json j = { - {"_t", "Operator"}, - {"input", - std::vector{{{"_t", "Tensor"}, {"opId", -2}, {"tsId", 0}}, - {{"_t", "Tensor"}, {"opId", -3}, {"tsId", 0}}}}, - {"para", std::vector{}}, - {"type", "OP_EW_ADD"}, - }; - - sl::Operator o; - from_json(j, o); - - EXPECT_EQ(o.op_type, OP_EW_ADD); - EXPECT_EQ(o.input.size(), 2); - EXPECT_EQ(o.input[0].opId, -2); - EXPECT_EQ(o.input[0].tsId, 0); - EXPECT_EQ(o.input[1].opId, -3); - EXPECT_EQ(o.input[1].tsId, 0); - EXPECT_EQ(o.para.size(), 0); -} - -// TEST(substitution_loader, load_full_file) { -// sl::RuleCollection collection = -// sl::load_rule_collection_from_path("tests/unit/graph_subst_3_v2.json"); -// EXPECT_EQ(collection.rules.size(), 640); - -// std::vector xfers = 
create_xfers(nullptr, collection, 2); -// EXPECT_EQ(xfers.size(), 640); -// } diff --git a/lib/op-attrs/src/get_output_shapes.cc b/lib/op-attrs/src/get_output_shapes.cc index d649856152..9d007e2f45 100644 --- a/lib/op-attrs/src/get_output_shapes.cc +++ b/lib/op-attrs/src/get_output_shapes.cc @@ -5,6 +5,12 @@ namespace FlexFlow { ParallelTensorShape as_parallel(TensorShape const &); std::vector as_parallel(std::vector const &); +std::vector get_output_shapes( + PCGOperatorAttrs const &op_params, + std::vector const &input_tensor_shapes) { + NOT_IMPLEMENTED(); +} + // TensorShape get_output_shape(AggregateAttrs const &attrs, // TensorShape const &gate_preds, // TensorShape const &gate_assign, diff --git a/lib/pcg/include/pcg/machine_specification.h b/lib/pcg/include/pcg/machine_specification.h index 55f80e3cc0..1b2a02b070 100644 --- a/lib/pcg/include/pcg/machine_specification.h +++ b/lib/pcg/include/pcg/machine_specification.h @@ -11,22 +11,21 @@ struct BandwidthNetworkModelConfig int bandwidth; }; -struct MachineSpecification : public use_visitable_cmp { +struct MachineSpecification { int num_nodes; int num_cpus_per_node; int num_gpus_per_node; float inter_node_bandwidth; - float intra_node_bandwidth; + req intra_node_bandwidth; }; -} // namespace FlexFlow +FF_VISITABLE_STRUCT(MachineSpecification, + num_nodes, + num_cpus_per_node, + num_gpus_per_node, + inter_node_bandwidth, + intra_node_bandwidth); -VISITABLE_STRUCT(::FlexFlow::MachineSpecification, - num_nodes, - num_cpus_per_node, - num_gpus_per_node, - inter_node_bandwidth, - intra_node_bandwidth); -MAKE_VISIT_HASHABLE(::FlexFlow::MachineSpecification); +} // namespace FlexFlow #endif diff --git a/lib/pcg/include/pcg/machine_view.h b/lib/pcg/include/pcg/machine_view.h index b482e851d8..afd4206eb1 100644 --- a/lib/pcg/include/pcg/machine_view.h +++ b/lib/pcg/include/pcg/machine_view.h @@ -12,9 +12,9 @@ namespace FlexFlow { -struct MachineView : public use_visitable_cmp { - MachineView() = delete; - 
MachineView(device_id_t const &, StridedRectangle const &); +struct MachineView { + // MachineView() = delete; + // MachineView(device_id_t const &, StridedRectangle const &); std::vector device_ids() const; diff --git a/lib/pcg/include/pcg/operator.h b/lib/pcg/include/pcg/operator.h index c7a49bb57e..d09e25dcf3 100644 --- a/lib/pcg/include/pcg/operator.h +++ b/lib/pcg/include/pcg/operator.h @@ -17,11 +17,12 @@ struct Operator : public use_visitable_cmp { public: PCGOperatorAttrs attrs; + optional name; }; } // namespace FlexFlow -VISITABLE_STRUCT(::FlexFlow::Operator, attrs); +VISITABLE_STRUCT(::FlexFlow::Operator, attrs, name); MAKE_VISIT_HASHABLE(::FlexFlow::Operator); namespace FlexFlow { diff --git a/lib/pcg/include/pcg/strided_rectangle.h b/lib/pcg/include/pcg/strided_rectangle.h index 28331f441c..25f85ffc48 100644 --- a/lib/pcg/include/pcg/strided_rectangle.h +++ b/lib/pcg/include/pcg/strided_rectangle.h @@ -17,7 +17,7 @@ struct side_size_t : public strong_typedef { using strong_typedef::strong_typedef; }; -struct StridedRectangleSide : public use_visitable_cmp { +struct StridedRectangleSide { public: StridedRectangleSide() = delete; StridedRectangleSide(num_points_t const &, int stride); @@ -32,13 +32,15 @@ struct StridedRectangleSide : public use_visitable_cmp { public: num_points_t num_points; - int stride; + req stride; }; -struct StridedRectangle : public use_visitable_cmp { +FF_VISITABLE_STRUCT_NONSTANDARD_CONSTRUCTION(StridedRectangleSide, num_points, stride); + +struct StridedRectangle { public: - StridedRectangle() = delete; - StridedRectangle(std::vector const &); + // StridedRectangle() = delete; + // StridedRectangle(std::vector const &); size_t at(FFOrdered const &) const; StridedRectangleSide at(ff_dim_t const &) const; @@ -47,6 +49,9 @@ struct StridedRectangle : public use_visitable_cmp { public: FFOrdered sides; }; + +FF_VISITABLE_STRUCT(StridedRectangle, sides); + } // namespace FlexFlow MAKE_TYPEDEF_HASHABLE(::FlexFlow::num_points_t); @@ 
-55,10 +60,10 @@ MAKE_TYPEDEF_PRINTABLE(::FlexFlow::num_points_t, "num_points"); MAKE_TYPEDEF_HASHABLE(::FlexFlow::side_size_t); MAKE_TYPEDEF_PRINTABLE(::FlexFlow::side_size_t, "side_size"); -VISITABLE_STRUCT(::FlexFlow::StridedRectangleSide, num_points, stride); -MAKE_VISIT_HASHABLE(::FlexFlow::StridedRectangleSide); +// VISITABLE_STRUCT(::FlexFlow::StridedRectangleSide, num_points, stride); +// MAKE_VISIT_HASHABLE(::FlexFlow::StridedRectangleSide); -VISITABLE_STRUCT(::FlexFlow::StridedRectangle, sides); -MAKE_VISIT_HASHABLE(::FlexFlow::StridedRectangle); +// VISITABLE_STRUCT(::FlexFlow::StridedRectangle, sides); +// MAKE_VISIT_HASHABLE(::FlexFlow::StridedRectangle); #endif diff --git a/lib/pcg/src/machine_view.cc b/lib/pcg/src/machine_view.cc index 9edfb09a8e..688ba1628f 100644 --- a/lib/pcg/src/machine_view.cc +++ b/lib/pcg/src/machine_view.cc @@ -3,8 +3,8 @@ namespace FlexFlow { -MachineView::MachineView(device_id_t const &start, StridedRectangle const &rect) - : start(start), rect(rect) {} +// MachineView::MachineView(device_id_t const &start, StridedRectangle const &rect) +// : start(start), rect(rect) {} static StridedRectangle make_1d_rect(int start, int stop, int stride) { assert(stop > start); diff --git a/lib/pcg/src/operator.cc b/lib/pcg/src/operator.cc index 8c79c67464..5cba8584c9 100644 --- a/lib/pcg/src/operator.cc +++ b/lib/pcg/src/operator.cc @@ -4,7 +4,7 @@ namespace FlexFlow { Operator::Operator(PCGOperatorAttrs const &attrs, optional const &name) - : attrs(attrs) {} + : attrs(attrs), name(name) {} Operator::operator PCGOperatorAttrs() const { return attrs; diff --git a/lib/pcg/src/parallel_computation_graph.cc b/lib/pcg/src/parallel_computation_graph.cc new file mode 100644 index 0000000000..609b10edd2 --- /dev/null +++ b/lib/pcg/src/parallel_computation_graph.cc @@ -0,0 +1,37 @@ +#include "pcg/parallel_computation_graph.h" +#include "utils/graph/algorithms.h" + +namespace FlexFlow { + +bool operator==(ParallelComputationGraph const &lhs, 
ParallelComputationGraph const &rhs) { + return std::hash{}(lhs) == std::hash{}(rhs); +} + +} + +namespace std { + +size_t hash::operator()(FlexFlow::ParallelComputationGraph const &g) const { + using namespace FlexFlow; + + size_t h = 0; + + std::vector ordered_nodes = get_topological_ordering(g.value()); + hash_combine(h, ordered_nodes.size()); + + std::unordered_map node_index; + for (int i = 0; i < ordered_nodes.size(); ++i) { + node_index[ordered_nodes[i]] = i; + hash_combine(h, g.value().at(ordered_nodes[i])); + } + + for (MultiDiEdge const &edge : get_edges(g.value())) { + hash_combine(h, node_index.at(edge.src)); + hash_combine(h, node_index.at(edge.dst)); + hash_combine(h, g.value().at(edge)); + } + + return h; +} + +} \ No newline at end of file diff --git a/lib/pcg/src/strided_rectangle.cc b/lib/pcg/src/strided_rectangle.cc index 29dcae6151..2792db65fe 100644 --- a/lib/pcg/src/strided_rectangle.cc +++ b/lib/pcg/src/strided_rectangle.cc @@ -30,8 +30,8 @@ side_size_t StridedRectangleSide::get_size() const { NOT_IMPLEMENTED(); } -StridedRectangle::StridedRectangle( - std::vector const &sides) - : sides(sides) {} +// StridedRectangle::StridedRectangle( +// std::vector const &sides) +// : sides(sides) {} } // namespace FlexFlow diff --git a/lib/utils/include/utils/graph/digraph.h b/lib/utils/include/utils/graph/digraph.h index 4d0014596e..bfe6884c57 100644 --- a/lib/utils/include/utils/graph/digraph.h +++ b/lib/utils/include/utils/graph/digraph.h @@ -33,7 +33,7 @@ struct DiGraphView : virtual public GraphView { using GraphView::GraphView; private: - IDiGraphView &get_ptr() const; + IDiGraphView const &get_ptr() const; friend struct GraphInternal; }; diff --git a/lib/utils/include/utils/graph/labelled/node_labelled.h b/lib/utils/include/utils/graph/labelled/node_labelled.h index bf037105b5..822973e149 100644 --- a/lib/utils/include/utils/graph/labelled/node_labelled.h +++ b/lib/utils/include/utils/graph/labelled/node_labelled.h @@ -54,9 +54,9 @@ struct 
NodeLabelledMultiDiGraphView : virtual public MultiDiGraphView { using MultiDiGraphView::MultiDiGraphView; private: - Interface &get_ptr() const { - return *std::reinterpret_pointer_cast( - GraphView::ptr.get_mutable()); + Interface const &get_ptr() const { + return *std::dynamic_pointer_cast( + GraphView::ptr.get()); } }; CHECK_WELL_BEHAVED_VALUE_TYPE_NO_EQ(NodeLabelledMultiDiGraphView); @@ -118,7 +118,7 @@ struct NodeLabelledMultiDiGraph : NodeLabelledMultiDiGraphView(ptr), nl(nl) {} Interface &get_ptr() const { - return *std::reinterpret_pointer_cast( + return *std::dynamic_pointer_cast( GraphView::ptr.get_mutable()); } diff --git a/lib/utils/include/utils/graph/labelled/node_labelled_open.h b/lib/utils/include/utils/graph/labelled/node_labelled_open.h index b292a4ef0d..9d83cebac6 100644 --- a/lib/utils/include/utils/graph/labelled/node_labelled_open.h +++ b/lib/utils/include/utils/graph/labelled/node_labelled_open.h @@ -53,9 +53,8 @@ struct NodeLabelledOpenMultiDiGraphView using NodeLabelledMultiDiGraphView::NodeLabelledMultiDiGraphView; private: - Interface &get_ptr() const { - return *std::reinterpret_pointer_cast( - GraphView::ptr.get_mutable()); + Interface const &get_ptr() const { + return *std::dynamic_pointer_cast(GraphView::ptr.get()); } }; @@ -119,7 +118,7 @@ struct NodeLabelledOpenMultiDiGraph : GraphView(ptr), nl(nl) {} Interface &get_ptr() const { - return *std::reinterpret_pointer_cast( + return *std::dynamic_pointer_cast( GraphView::ptr.get_mutable()); } diff --git a/lib/utils/include/utils/graph/labelled/open_views.h b/lib/utils/include/utils/graph/labelled/open_views.h index 8c8a8b1a1b..501aa9caa4 100644 --- a/lib/utils/include/utils/graph/labelled/open_views.h +++ b/lib/utils/include/utils/graph/labelled/open_views.h @@ -54,6 +54,45 @@ struct OutputLabelledOpenMultiDiSubgraphView // CHECK_NOT_ABSTRACT(OutputLabelledOpenMultiDiSubgraphView); +template +struct ViewOutputLabelledAsOutputLabelledOpen : virtual IOutputLabelledOpenMultiDiGraphView { 
+ ViewOutputLabelledAsOutputLabelledOpen(OutputLabelledMultiDiGraphView const &g) : g(g) {} + + NodeLabel const &at(Node const &n) const override { + return g.at(n); + } + + EdgeLabel const &at(InputMultiDiEdge const &i) const override { + assert(false); + } + + EdgeLabel const &at(MultiDiOutput const &o) const override { + return g.at(o); + } + + std::unordered_set query_nodes(NodeQuery const &q) const override { + return g.query_nodes(q); + } + + std::unordered_set + query_edges(OpenMultiDiEdgeQuery const &q) const override { + return transform(g.query_edges(q.standard_edge_query), + [](MultiDiEdge const &e) { return OpenMultiDiEdge(e); }); + } + + ViewOutputLabelledAsOutputLabelledOpen* clone() const override { + return new ViewOutputLabelledAsOutputLabelledOpen(g); + } + +private: + OutputLabelledMultiDiGraphView const &g; +}; + +template +OutputLabelledOpenMultiDiGraphView view_output_labelled_as_output_labelled_open(OutputLabelledMultiDiGraphView const &g) { + return OutputLabelledOpenMultiDiGraphView::template create>(g); +} + } // namespace FlexFlow #endif diff --git a/lib/utils/include/utils/graph/labelled/output_labelled.h b/lib/utils/include/utils/graph/labelled/output_labelled.h index d0f94414b7..9b3d982e75 100644 --- a/lib/utils/include/utils/graph/labelled/output_labelled.h +++ b/lib/utils/include/utils/graph/labelled/output_labelled.h @@ -14,7 +14,7 @@ struct IOutputLabelledMultiDiGraphView IOutputLabelledMultiDiGraphView & operator=(IOutputLabelledMultiDiGraphView const &) = delete; - virtual OutputLabel const &at(MultiDiOutput const &) = 0; + virtual OutputLabel const &at(MultiDiOutput const &) const = 0; using INodeLabelledMultiDiGraphView::at; }; CHECK_RC_COPY_VIRTUAL_COMPLIANT(IOutputLabelledMultiDiGraphView); @@ -31,11 +31,11 @@ struct OutputLabelledMultiDiGraphView OutputLabelledMultiDiGraphView & operator=(OutputLabelledMultiDiGraphView const &) = default; - NodeLabel const &at(Node const &n) const { + virtual NodeLabel const &at(Node const 
&n) const { return get_ptr().at(n); } - OutputLabel const &at(MultiDiOutput const &o) const { + virtual OutputLabel const &at(MultiDiOutput const &o) const { return get_ptr().at(o); } @@ -56,13 +56,11 @@ struct OutputLabelledMultiDiGraphView } protected: - OutputLabelledMultiDiGraphView(cow_ptr_t ptr) - : NodeLabelledMultiDiGraphView(ptr) {} + using NodeLabelledMultiDiGraphView::NodeLabelledMultiDiGraphView; private: - Interface &get_ptr() const { - return *std::reinterpret_pointer_cast( - GraphView::ptr.get_mutable()); + Interface const &get_ptr() const { + return *std::dynamic_pointer_cast(GraphView::ptr.get()); } }; @@ -81,7 +79,7 @@ struct OutputLabelledMultiDiGraph Node add_node(NodeLabel const &l) { Node n = get_ptr().add_node(); - nl->add_label(n, l); + nl.get_mutable()->add_label(n, l); return n; } @@ -93,12 +91,12 @@ struct OutputLabelledMultiDiGraph return nl.get_mutable()->get_label(n); } - NodeLabel const &at(Node const &n) const { + NodeLabel const &at(Node const &n) const override { return nl->get_label(n); } void add_output(MultiDiOutput const &o, OutputLabel const &l) { - ol->add_label(o, l); + ol.get_mutable()->add_label(o, l); }; void add_edge(MultiDiOutput const &o, MultiDiInput const &i) { @@ -110,16 +108,17 @@ struct OutputLabelledMultiDiGraph } OutputLabel &at(MultiDiOutput const &o) { - return ol->get_label(o); + return ol.get_mutable()->get_label(o); } - OutputLabel const &at(MultiDiOutput const &o) const { + OutputLabel const &at(MultiDiOutput const &o) const override { return ol->get_label(o); } std::unordered_set query_nodes(NodeQuery const &q) const { return get_ptr().query_nodes(q); } + std::unordered_set query_edges(MultiDiEdgeQuery const &q) const { return get_ptr().query_edges(q); } @@ -139,12 +138,11 @@ struct OutputLabelledMultiDiGraph OutputLabelledMultiDiGraph(cow_ptr_t ptr, cow_ptr_t nl, cow_ptr_t ol) - : OutputLabelledMultiDiGraphView(ptr), nl(nl), - ol(ol) {} + : GraphView(ptr), nl(nl), ol(ol) {} private: Interface &get_ptr() 
const { - return *std::reinterpret_pointer_cast( + return *std::dynamic_pointer_cast( GraphView::ptr.get_mutable()); } cow_ptr_t nl; diff --git a/lib/utils/include/utils/graph/labelled/output_labelled_open.h b/lib/utils/include/utils/graph/labelled/output_labelled_open.h index 3d2ac9d601..986d337a57 100644 --- a/lib/utils/include/utils/graph/labelled/output_labelled_open.h +++ b/lib/utils/include/utils/graph/labelled/output_labelled_open.h @@ -17,7 +17,8 @@ struct IOutputLabelledOpenMultiDiGraphView template struct OutputLabelledOpenMultiDiGraphView - : virtual NodeLabelledOpenMultiDiGraphView { + : virtual NodeLabelledOpenMultiDiGraphView, + virtual OutputLabelledMultiDiGraphView { private: using Interface = IOutputLabelledOpenMultiDiGraphView; @@ -59,12 +60,10 @@ struct OutputLabelledOpenMultiDiGraphView protected: using NodeLabelledOpenMultiDiGraphView< NodeLabel>::NodeLabelledOpenMultiDiGraphView; - OutputLabelledOpenMultiDiGraphView(cow_ptr_t ptr) : GraphView(ptr) {} private: - Interface &get_ptr() const { - return *std::reinterpret_pointer_cast( - GraphView::ptr.get_mutable()); + Interface const &get_ptr() const { + return *std::dynamic_pointer_cast(GraphView::ptr.get()); } }; @@ -96,6 +95,11 @@ struct OutputLabelledOpenMultiDiGraph return n; } + void add_node_unsafe(Node const &n, NodeLabel const &l) { + get_ptr().add_node_unsafe(n); + nl.get_mutable()->add_label(n, l); + } + NodePort add_node_port() { return get_ptr().add_node_port(); } @@ -121,14 +125,14 @@ struct OutputLabelledOpenMultiDiGraph } EdgeLabel &at(MultiDiOutput const &o) { - return ol->get_label(o); + return ol.get_mutable()->get_label(o); } EdgeLabel const &at(MultiDiOutput const &o) const override { return ol->get_label(o); } EdgeLabel &at(InputMultiDiEdge const &e) { - return il->get_label(e); + return il.get_mutable()->get_label(e); } EdgeLabel const &at(InputMultiDiEdge const &e) const override { @@ -165,7 +169,7 @@ struct OutputLabelledOpenMultiDiGraph : GraphView(ptr), nl(nl), il(il), 
ol(ol) {} Interface &get_ptr() const { - return *std::reinterpret_pointer_cast( + return *std::dynamic_pointer_cast( GraphView::ptr.get_mutable()); } diff --git a/lib/utils/include/utils/graph/labelled/standard_labelled.h b/lib/utils/include/utils/graph/labelled/standard_labelled.h index 70e5e87f93..ae9b02c911 100644 --- a/lib/utils/include/utils/graph/labelled/standard_labelled.h +++ b/lib/utils/include/utils/graph/labelled/standard_labelled.h @@ -32,19 +32,19 @@ struct LabelledMultiDiGraphView operator=(LabelledMultiDiGraphView const &) = default; NodeLabel const &at(Node const &n) const { - return get_ptr()->at(n); + return get_ptr().at(n); } EdgeLabel const &at(MultiDiEdge const &e) const { - return get_ptr()->at(e); + return get_ptr().at(e); } std::unordered_set query_nodes(NodeQuery const &q) const { - return get_ptr()->query_nodes(q); + return get_ptr().query_nodes(q); } std::unordered_set query_edges(MultiDiEdgeQuery const &q) const { - return get_ptr()->query_edges(q); + return get_ptr().query_edges(q); } template @@ -58,8 +58,8 @@ struct LabelledMultiDiGraphView protected: LabelledMultiDiGraphView(cow_ptr_t ptr) : NodeLabelledMultiDiGraphView(ptr) {} - cow_ptr_t get_ptr() const { - return cow_ptr_t(static_cast(*GraphView::ptr)); + Interface const &get_ptr() const { + return *std::dynamic_pointer_cast(GraphView::ptr.get()); } }; CHECK_WELL_BEHAVED_VALUE_TYPE_NO_EQ(LabelledMultiDiGraphView); @@ -84,11 +84,11 @@ struct LabelledMultiDiGraph } NodePort add_node_port() { - return this->get_ptr()->add_node_port(); + return this->get_ptr().add_node_port(); } NodeLabel &at(Node const &n) { - return nl->get_label(n); + return nl.get_mutable()->get_label(n); } NodeLabel const &at(Node const &n) const { @@ -96,20 +96,20 @@ struct LabelledMultiDiGraph } void add_edge(MultiDiEdge const &e, EdgeLabel const &l) { - return this->get_ptr()->add_edge(e, l); + return this->get_ptr().add_edge(e, l); } EdgeLabel &at(MultiDiEdge const &e) { - return el->get_label(e); + return 
el.get_mutable()->get_label(e); } EdgeLabel const &at(MultiDiEdge const &e) const { return el->get_label(e); } std::unordered_set query_nodes(NodeQuery const &q) const { - return this->get_ptr()->query_nodes(q); + return this->get_ptr().query_nodes(q); } std::unordered_set query_edges(MultiDiEdgeQuery const &q) const { - return this->get_ptr()->query_edges(q); + return this->get_ptr().query_edges(q); } template @@ -129,8 +129,8 @@ struct LabelledMultiDiGraph cow_ptr_t el) : LabelledMultiDiGraphView(ptr), nl(nl), el(el) {} - cow_ptr_t get_ptr() const { - return cow_ptr_t(static_cast(*GraphView::ptr)); + Interface& get_ptr() const { + return *std::dynamic_pointer_cast(GraphView::ptr.get_mutable()); } cow_ptr_t nl; diff --git a/lib/utils/include/utils/graph/labelled/unordered_label.h b/lib/utils/include/utils/graph/labelled/unordered_label.h index 94c4bffe11..230e286ef8 100644 --- a/lib/utils/include/utils/graph/labelled/unordered_label.h +++ b/lib/utils/include/utils/graph/labelled/unordered_label.h @@ -19,7 +19,8 @@ struct UnorderedLabelling : virtual public ILabelling { } void add_label(Elem const &e, Label const &l) { - label_map.insert({e, l}); + auto p = std::make_pair(e, l); + label_map.insert(p); } UnorderedLabelling *clone() const { diff --git a/lib/utils/include/utils/graph/labelled/views.h b/lib/utils/include/utils/graph/labelled/views.h index 5a227d46ec..82a45a2ad0 100644 --- a/lib/utils/include/utils/graph/labelled/views.h +++ b/lib/utils/include/utils/graph/labelled/views.h @@ -52,7 +52,7 @@ struct ViewMultiDiGraphAsOutputLabelled return node_label(n); } - virtual OutputLabel &at(MultiDiOutput const &o) override { + virtual OutputLabel const &at(MultiDiOutput const &o) const override { return output_label(o); } @@ -86,6 +86,26 @@ Impl materialize_output_labelled_multidigraph_view( return result; } +template +OutputLabelledOpenMultiDiGraph materialize_output_labelled_open_multidigraph_view(OutputLabelledOpenMultiDiGraphView const &g) { + 
OutputLabelledOpenMultiDiGraph result = OutputLabelledOpenMultiDiGraph::template create(); + for (Node const &n : get_nodes(g)) { + result.add_node_unsafe(n, g.at(n)); + } + for (OpenMultiDiEdge const &e : get_edges(g)) { + result.add_edge(e); + if (is_input_edge(e)) { + InputMultiDiEdge input_edge = get(e); + result.add_label(input_edge, g.at(input_edge)); + } else { + MultiDiOutput output = is_standard_edge(e) ? static_cast(get(e)) : static_cast(get(e)); + auto tensor = g.at(output); + result.add_label(output, tensor); + } + } + return result; +} + } // namespace FlexFlow #endif diff --git a/lib/utils/include/utils/graph/multidigraph.h b/lib/utils/include/utils/graph/multidigraph.h index cfb2c7db21..d5d72bbbd7 100644 --- a/lib/utils/include/utils/graph/multidigraph.h +++ b/lib/utils/include/utils/graph/multidigraph.h @@ -30,7 +30,7 @@ struct MultiDiGraphView : virtual DiGraphView { using DiGraphView::DiGraphView; private: - IMultiDiGraphView &get_ptr() const; + IMultiDiGraphView const &get_ptr() const; friend struct GraphInternal; }; diff --git a/lib/utils/include/utils/graph/open_graphs.h b/lib/utils/include/utils/graph/open_graphs.h index 1f8a3692fa..703ad6778f 100644 --- a/lib/utils/include/utils/graph/open_graphs.h +++ b/lib/utils/include/utils/graph/open_graphs.h @@ -34,7 +34,7 @@ struct OpenMultiDiGraphView : virtual MultiDiGraphView { using MultiDiGraphView::MultiDiGraphView; private: - IOpenMultiDiGraphView &get_ptr() const; + IOpenMultiDiGraphView const &get_ptr() const; friend struct GraphInternal; }; @@ -50,6 +50,7 @@ struct OpenMultiDiGraph : virtual OpenMultiDiGraphView { Node add_node(); void add_node_unsafe(Node const &); void remove_node_unsafe(Node const &); + NodePort add_node_port(); void add_edge(Edge const &); void remove_edge(Edge const &); @@ -60,7 +61,7 @@ struct OpenMultiDiGraph : virtual OpenMultiDiGraphView { static typename std::enable_if::value, OpenMultiDiGraph>::type create() { - return make_cow_ptr(); + return 
OpenMultiDiGraph(make_cow_ptr()); } private: @@ -96,7 +97,7 @@ struct UpwardOpenMultiDiGraphView : virtual MultiDiGraphView { private: using MultiDiGraphView::MultiDiGraphView; - IUpwardOpenMultiDiGraphView &get_ptr() const; + IUpwardOpenMultiDiGraphView const &get_ptr() const; }; CHECK_WELL_BEHAVED_VALUE_TYPE_NO_EQ(UpwardOpenMultiDiGraphView); @@ -158,7 +159,7 @@ struct DownwardOpenMultiDiGraphView : virtual MultiDiGraphView { private: using MultiDiGraphView::MultiDiGraphView; - Interface &get_ptr() const; + Interface const &get_ptr() const; }; CHECK_WELL_BEHAVED_VALUE_TYPE_NO_EQ(DownwardOpenMultiDiGraphView); diff --git a/lib/utils/include/utils/graph/undirected.h b/lib/utils/include/utils/graph/undirected.h index 24cd07caa9..b32b6e3572 100644 --- a/lib/utils/include/utils/graph/undirected.h +++ b/lib/utils/include/utils/graph/undirected.h @@ -55,7 +55,7 @@ struct UndirectedGraphView : virtual GraphView { friend struct GraphInternal; private: - IUndirectedGraphView &get_ptr() const; + IUndirectedGraphView const &get_ptr() const; }; CHECK_WELL_BEHAVED_VALUE_TYPE_NO_EQ(UndirectedGraphView); diff --git a/lib/utils/src/graph/algorithms.cc b/lib/utils/src/graph/algorithms.cc index d62989d65b..d5407adbae 100644 --- a/lib/utils/src/graph/algorithms.cc +++ b/lib/utils/src/graph/algorithms.cc @@ -164,7 +164,9 @@ DiGraphView apply_contraction(DiGraphView const &g, for (auto const &kv : nodes) { Node from = kv.first; Node into = kv.second; - contractedView = contract_node(contractedView, from, into); + if (from != into) { + contractedView = contract_node(contractedView, from, into); + } } return contractedView; } @@ -347,6 +349,13 @@ std::unordered_set }); } +std::unordered_set get_open_outputs(OpenMultiDiGraphView const &g) { + return transform(g.query_edges(OutputMultiDiEdgeQuery::all()), [](OpenMultiDiEdge const &e) { return get(e); }); +} +std::unordered_set get_open_inputs(OpenMultiDiGraphView const &g) { + return transform(g.query_edges(InputMultiDiEdgeQuery::all()), 
[](OpenMultiDiEdge const &e) { return get(e); }); +} + std::unordered_map> get_predecessors(DiGraphView const &g, std::unordered_set const &nodes) { @@ -757,4 +766,28 @@ std::unordered_set> return components; } +std::unordered_set get_closed_sources(OpenMultiDiGraphView const &g) { + return filter(get_nodes(g), [&](Node const &n) { + return get_incoming_edges(g, n).size() == 0; + }); +} + +std::unordered_set get_closed_sinks(OpenMultiDiGraphView const &g) { + return filter(get_nodes(g), [&](Node const &n) { + return get_outgoing_edges(g, n).size() == 0; + }); +} + +std::unordered_set get_open_sources(OpenMultiDiGraphView const &g) { + return filter(get_nodes(g), [&](Node const &n) { + return !g.query_edges(InputMultiDiEdgeQuery::all().with_dst_nodes({n})).empty(); + }); +} + +std::unordered_set get_open_sinks(OpenMultiDiGraphView const &g) { + return filter(get_nodes(g), [&](Node const &n) { + return !g.query_edges(OutputMultiDiEdgeQuery::all().with_src_nodes({n})).empty(); + }); +} + } // namespace FlexFlow diff --git a/lib/utils/src/graph/digraph.cc b/lib/utils/src/graph/digraph.cc index ff65df1cf6..1e2f562c19 100644 --- a/lib/utils/src/graph/digraph.cc +++ b/lib/utils/src/graph/digraph.cc @@ -13,9 +13,8 @@ std::unordered_set return get_ptr().query_edges(query); } -IDiGraphView &DiGraphView::get_ptr() const { - return *std::reinterpret_pointer_cast( - GraphView::ptr.get_mutable()); +IDiGraphView const &DiGraphView::get_ptr() const { + return *std::dynamic_pointer_cast(GraphView::ptr.get()); } Node DiGraph::add_node() { @@ -48,6 +47,6 @@ std::unordered_set } IDiGraph &DiGraph::get_ptr() const { - return *std::reinterpret_pointer_cast(GraphView::ptr.get_mutable()); + return *std::dynamic_pointer_cast(GraphView::ptr.get_mutable()); } } // namespace FlexFlow diff --git a/lib/utils/src/graph/multidigraph.cc b/lib/utils/src/graph/multidigraph.cc index 07d5837b1e..7bbe4cae67 100644 --- a/lib/utils/src/graph/multidigraph.cc +++ b/lib/utils/src/graph/multidigraph.cc @@ 
-23,9 +23,8 @@ std::unordered_set return this->get_ptr().query_edges(q); } -IMultiDiGraphView &MultiDiGraphView::get_ptr() const { - return *std::reinterpret_pointer_cast( - GraphView::ptr.get_mutable()); +IMultiDiGraphView const &MultiDiGraphView::get_ptr() const { + return *std::dynamic_pointer_cast(GraphView::ptr.get()); } Node MultiDiGraph::add_node() { @@ -66,7 +65,7 @@ std::unordered_set MultiDiGraph::query_nodes(NodeQuery const &q) const { } IMultiDiGraph &MultiDiGraph::get_ptr() const { - return *std::reinterpret_pointer_cast( + return *std::dynamic_pointer_cast( GraphView::ptr.get_mutable()); } diff --git a/lib/utils/src/graph/node.cc b/lib/utils/src/graph/node.cc index a9635aa553..836b5513e9 100644 --- a/lib/utils/src/graph/node.cc +++ b/lib/utils/src/graph/node.cc @@ -53,7 +53,7 @@ std::unordered_set Graph::query_nodes(NodeQuery const &q) const { } IGraph &Graph::get_ptr() const { - return *std::reinterpret_pointer_cast(GraphView::ptr.get_mutable()); + return *std::dynamic_pointer_cast(GraphView::ptr.get_mutable()); } } // namespace FlexFlow diff --git a/lib/utils/src/graph/open_graphs.cc b/lib/utils/src/graph/open_graphs.cc index 0acda5e6f6..9bbb1bfa3d 100644 --- a/lib/utils/src/graph/open_graphs.cc +++ b/lib/utils/src/graph/open_graphs.cc @@ -21,9 +21,8 @@ std::unordered_set return this->get_ptr().query_edges(q); } -IOpenMultiDiGraphView &OpenMultiDiGraphView::get_ptr() const { - return *std::reinterpret_pointer_cast( - GraphView::ptr.get_mutable()); +IOpenMultiDiGraphView const &OpenMultiDiGraphView::get_ptr() const { + return *std::dynamic_pointer_cast(GraphView::ptr.get()); } Node OpenMultiDiGraph::add_node() { @@ -51,8 +50,12 @@ std::unordered_set return this->get_ptr().query_edges(q); } +NodePort OpenMultiDiGraph::add_node_port() { + return get_ptr().add_node_port(); +} + IOpenMultiDiGraph &OpenMultiDiGraph::get_ptr() const { - return *std::reinterpret_pointer_cast( + return *std::dynamic_pointer_cast( GraphView::ptr.get_mutable()); } @@ -67,9 
+70,9 @@ std::unordered_set return get_ptr().query_edges(q); } -IUpwardOpenMultiDiGraphView &UpwardOpenMultiDiGraphView::get_ptr() const { - return *std::reinterpret_pointer_cast( - GraphView::ptr.get_mutable()); +IUpwardOpenMultiDiGraphView const &UpwardOpenMultiDiGraphView::get_ptr() const { + return *std::dynamic_pointer_cast( + GraphView::ptr.get()); } Node UpwardOpenMultiDiGraph::add_node() { @@ -98,7 +101,7 @@ std::unordered_set UpwardOpenMultiDiGraph::query_edges( } IUpwardOpenMultiDiGraph &UpwardOpenMultiDiGraph::get_ptr() const { - return *std::reinterpret_pointer_cast( + return *std::dynamic_pointer_cast( GraphView::ptr.get_mutable()); } @@ -113,9 +116,9 @@ std::unordered_set return this->get_ptr().query_edges(q); } -IDownwardOpenMultiDiGraphView &DownwardOpenMultiDiGraphView::get_ptr() const { - return *std::reinterpret_pointer_cast( - GraphView::ptr.get_mutable()); +IDownwardOpenMultiDiGraphView const &DownwardOpenMultiDiGraphView::get_ptr() const { + return *std::dynamic_pointer_cast( + GraphView::ptr.get()); } Node DownwardOpenMultiDiGraph::add_node() { @@ -150,7 +153,7 @@ std::unordered_set } IDownwardOpenMultiDiGraph &DownwardOpenMultiDiGraph::get_ptr() const { - return *std::reinterpret_pointer_cast( + return *std::dynamic_pointer_cast( GraphView::ptr.get_mutable()); } diff --git a/lib/utils/src/graph/serialparallel.cc b/lib/utils/src/graph/serialparallel.cc index 3b3a1b0aed..41ecf3c436 100644 --- a/lib/utils/src/graph/serialparallel.cc +++ b/lib/utils/src/graph/serialparallel.cc @@ -72,7 +72,7 @@ std::unordered_set if (include_src == SourceSettings::INCLUDE_SOURCE_NODES) { result = set_union(result, srcs); } - if (include_sink == SinkSettings::EXCLUDE_SINK_NODES) { + if (include_sink == SinkSettings::INCLUDE_SINK_NODES) { result = set_union(result, sinks); } return result; @@ -103,12 +103,12 @@ SplitAST sp_decomposition(DiGraphView const &g) { sources, {bottleneck.value()}, SourceSettings::INCLUDE_SOURCE_NODES, - 
SinkSettings::INCLUDE_SINK_NODES)), + SinkSettings::EXCLUDE_SINK_NODES)), sp_decomposition(source_to_sink_subgraph( g, {bottleneck.value()}, sinks, - SourceSettings::EXCLUDE_SOURCE_NODES, + SourceSettings::INCLUDE_SOURCE_NODES, SinkSettings::INCLUDE_SINK_NODES))); } else { return parallel_decomposition(g); @@ -195,6 +195,13 @@ struct ToFinalAST { variant to_final_ast(SplitAST const &ast) { return visit(ToFinalAST{}, ast); } + +SerialParallelDecomposition + get_serial_parallel_decomposition(DiGraphView const &g) { + SplitAST ast = sp_decomposition(g); + return to_final_ast(ast); +} + struct GetNodes { template std::unordered_set operator()(T const &t) { diff --git a/lib/utils/src/graph/undirected.cc b/lib/utils/src/graph/undirected.cc index 414b350a89..166a9efa36 100644 --- a/lib/utils/src/graph/undirected.cc +++ b/lib/utils/src/graph/undirected.cc @@ -26,7 +26,7 @@ void UndirectedGraph::remove_edge(UndirectedEdge const &e) { } IUndirectedGraph &UndirectedGraph::get_ptr() const { - return *std::reinterpret_pointer_cast( + return *std::dynamic_pointer_cast( GraphView::ptr.get_mutable()); } @@ -50,8 +50,8 @@ std::unordered_set return this->get_ptr().query_nodes(q); } -IUndirectedGraphView &UndirectedGraphView::get_ptr() const { - return *std::reinterpret_pointer_cast( +IUndirectedGraphView const &UndirectedGraphView::get_ptr() const { + return *std::dynamic_pointer_cast( GraphView::ptr.get_mutable()); } diff --git a/lib/utils/src/graph/views.cc b/lib/utils/src/graph/views.cc index 062dca6858..5a8c6e9f93 100644 --- a/lib/utils/src/graph/views.cc +++ b/lib/utils/src/graph/views.cc @@ -469,7 +469,7 @@ std::unordered_set UpwardOpenMultiDiSubgraphView::UpwardOpenMultiDiSubgraphView( OpenMultiDiGraphView const &g, std::unordered_set const &nodes) - : g(g), nodes(nodes), inputs(inputs) {} + : g(g), nodes(nodes), inputs(transform(get_cut_set(g, nodes), to_inputmultidiedge)) {} UpwardOpenMultiDiSubgraphView *UpwardOpenMultiDiSubgraphView::clone() const { return new 
UpwardOpenMultiDiSubgraphView(g, nodes); @@ -477,11 +477,11 @@ UpwardOpenMultiDiSubgraphView *UpwardOpenMultiDiSubgraphView::clone() const { std::unordered_set UpwardOpenMultiDiSubgraphView::query_edges( OpenMultiDiEdgeQuery const &q) const { - std::unordered_set result = - g.query_edges(OpenMultiDiEdgeQuery( - q.input_edge_query.with_dst_nodes(nodes), - q.standard_edge_query.with_src_nodes(nodes).with_dst_nodes(nodes), - OutputMultiDiEdgeQuery::none())); + OpenMultiDiEdgeQuery subgraph_query( + q.input_edge_query.with_dst_nodes(nodes), + q.standard_edge_query.with_src_nodes(nodes).with_dst_nodes(nodes), + OutputMultiDiEdgeQuery::none()); + std::unordered_set result = g.query_edges(subgraph_query); extend(result, query_edge(inputs, q.input_edge_query.with_dst_nodes(nodes))); return result; } @@ -493,16 +493,16 @@ std::unordered_set DownwardOpenMultiDiSubgraphView::DownwardOpenMultiDiSubgraphView( OpenMultiDiGraphView const &g, std::unordered_set const &nodes) - : g(g), nodes(nodes) {} + : g(g), nodes(nodes), outputs(transform(get_cut_set(g, nodes), to_outputmultidiedge)) {} std::unordered_set DownwardOpenMultiDiSubgraphView::query_edges( OpenMultiDiEdgeQuery const &q) const { - std::unordered_set result = - g.query_edges(OpenMultiDiEdgeQuery( - InputMultiDiEdgeQuery::none(), - q.standard_edge_query.with_src_nodes(nodes).with_dst_nodes(nodes), - q.output_edge_query.with_src_nodes(nodes))); + OpenMultiDiEdgeQuery subgraph_query( + InputMultiDiEdgeQuery::none(), + q.standard_edge_query.with_src_nodes(nodes).with_dst_nodes(nodes), + q.output_edge_query.with_src_nodes(nodes)); + std::unordered_set result = g.query_edges(subgraph_query); extend(result, query_edge(outputs, q.output_edge_query.with_src_nodes(nodes))); return result; From 6211b84e84a8ba471783d7b7c4f6854b4d59c884 Mon Sep 17 00:00:00 2001 From: wmdi Date: Wed, 15 Nov 2023 17:13:45 -0500 Subject: [PATCH 03/37] format --- lib/compiler/include/compiler/cost_estimate.h | 3 +- 
.../include/compiler/machine_mapping.h | 18 ++--- .../include/compiler/unity_algorithm.h | 2 +- lib/compiler/src/graph_utils.cc | 9 ++- lib/compiler/src/machine_mapping.cc | 67 +++++++++++-------- lib/compiler/src/unity_algorithm.cc | 11 ++- lib/compiler/test/src/test_generator.h | 2 +- .../test/src/test_labelled_open_graph.cc | 11 +-- lib/compiler/test/src/test_open_graph.cc | 6 +- lib/compiler/test/src/test_optimal_cost.cc | 31 +++++---- lib/op-attrs/src/get_output_shapes.cc | 2 +- .../include/pcg/parallel_computation_graph.h | 5 +- lib/pcg/include/pcg/strided_rectangle.h | 4 +- lib/pcg/src/machine_view.cc | 3 +- lib/pcg/src/parallel_computation_graph.cc | 13 ++-- .../include/substitutions/substitution.h | 4 +- lib/utils/include/utils/graph/algorithms.h | 6 +- .../utils/graph/labelled/node_labelled.h | 6 +- .../utils/graph/labelled/node_labelled_open.h | 6 +- .../include/utils/graph/labelled/open_views.h | 21 ++++-- .../utils/graph/labelled/output_labelled.h | 5 +- .../graph/labelled/output_labelled_open.h | 9 +-- .../utils/graph/labelled/standard_labelled.h | 2 +- .../include/utils/graph/labelled/views.h | 25 +++++-- lib/utils/src/graph/algorithms.cc | 20 ++++-- lib/utils/src/graph/multidigraph.cc | 3 +- lib/utils/src/graph/open_graphs.cc | 6 +- lib/utils/src/graph/views.cc | 6 +- 28 files changed, 189 insertions(+), 117 deletions(-) diff --git a/lib/compiler/include/compiler/cost_estimate.h b/lib/compiler/include/compiler/cost_estimate.h index 3791292529..557f51a7ca 100644 --- a/lib/compiler/include/compiler/cost_estimate.h +++ b/lib/compiler/include/compiler/cost_estimate.h @@ -45,7 +45,8 @@ struct CostEstimator { } private: - CostEstimator(std::shared_ptr implementation_ptr) : implementation_ptr(implementation_ptr) {} + CostEstimator(std::shared_ptr implementation_ptr) + : implementation_ptr(implementation_ptr) {} std::shared_ptr implementation_ptr; }; diff --git a/lib/compiler/include/compiler/machine_mapping.h 
b/lib/compiler/include/compiler/machine_mapping.h index 9f9d97937d..aa9152dcd6 100644 --- a/lib/compiler/include/compiler/machine_mapping.h +++ b/lib/compiler/include/compiler/machine_mapping.h @@ -9,13 +9,14 @@ namespace FlexFlow { -using SubParallelComputationGraphView = OutputLabelledOpenMultiDiGraphView; +using SubParallelComputationGraphView = + OutputLabelledOpenMultiDiGraphView; struct MachineMapping { static MachineMapping combine(MachineMapping const &, MachineMapping const &); static bool nodes_are_disjoint(MachineMapping const &m1, MachineMapping const &m2); - + req> machine_views; }; FF_VISITABLE_STRUCT(MachineMapping, machine_views); @@ -24,11 +25,10 @@ struct OptimalCostState { SerialParallelDecomposition subgraph; req resource; // req> given_machine_views; - // req> frontier_machine_views; + // req> + // frontier_machine_views; }; -FF_VISITABLE_STRUCT(OptimalCostState, - subgraph, - resource); +FF_VISITABLE_STRUCT(OptimalCostState, subgraph, resource); struct OptimalCostResult { static OptimalCostResult sequential_combine(OptimalCostResult const &s1, @@ -72,9 +72,11 @@ namespace std { template <> struct hash> { - size_t operator()(std::unordered_map const &g) const; + size_t operator()( + std::unordered_map const &g) + const; }; -}; +}; // namespace std #endif diff --git a/lib/compiler/include/compiler/unity_algorithm.h b/lib/compiler/include/compiler/unity_algorithm.h index 81e8375948..a87bddcc3a 100644 --- a/lib/compiler/include/compiler/unity_algorithm.h +++ b/lib/compiler/include/compiler/unity_algorithm.h @@ -45,6 +45,6 @@ struct hash { size_t operator()(FlexFlow::Strategy const &) const; }; -} +} // namespace std #endif diff --git a/lib/compiler/src/graph_utils.cc b/lib/compiler/src/graph_utils.cc index 04e96c66ed..e0134d6dd8 100644 --- a/lib/compiler/src/graph_utils.cc +++ b/lib/compiler/src/graph_utils.cc @@ -15,11 +15,10 @@ SubParallelComputationGraph pcg_to_subpcg(ParallelComputationGraph const &pcg) { auto g = pcg.value(); auto g_ = 
view_output_labelled_as_output_labelled_open(g); auto subpcg = materialize_output_labelled_open_multidigraph_view< - AdjacencyOpenMultiDiGraph, - UnorderedLabelling, - UnorderedLabelling, - UnorderedLabelling - >(g_); + AdjacencyOpenMultiDiGraph, + UnorderedLabelling, + UnorderedLabelling, + UnorderedLabelling>(g_); return subpcg; } diff --git a/lib/compiler/src/machine_mapping.cc b/lib/compiler/src/machine_mapping.cc index 671c59a94f..2bdd7de1e2 100644 --- a/lib/compiler/src/machine_mapping.cc +++ b/lib/compiler/src/machine_mapping.cc @@ -94,7 +94,8 @@ GraphSplit float estimate_cost(SubParallelComputationGraphView const &g, CostEstimator const &estimator, MachineMapping const &device_mapping, - std::unordered_map const &frontier_machine_views) { + std::unordered_map const + &frontier_machine_views) { return 0.1; } @@ -103,19 +104,20 @@ void minimize_runtime(OptimalCostResult &m1, OptimalCostResult const &m2) { } struct OptimalCost { - OptimalCost( - SubParallelComputationGraphView const &g, - CostEstimator const &cost_estimator, - MachineSpecification const &resource, - std::unordered_map const &given_machine_views, - std::unordered_map const &frontier_machine_views, - std::function( - Operator const &, MachineSpecification const &)> const - &allowed_machine_views, - OptimalCostCache &cached_subgraph_costs) + OptimalCost(SubParallelComputationGraphView const &g, + CostEstimator const &cost_estimator, + MachineSpecification const &resource, + std::unordered_map const &given_machine_views, + std::unordered_map const + &frontier_machine_views, + std::function( + Operator const &, MachineSpecification const &)> const + &allowed_machine_views, + OptimalCostCache &cached_subgraph_costs) : g(g), cost_estimator(cost_estimator), resource(resource), given_machine_views(restrict_keys(given_machine_views, get_nodes(g))), - frontier_machine_views(restrict_keys(frontier_machine_views, get_edges(g))), + frontier_machine_views( + restrict_keys(frontier_machine_views, 
get_edges(g))), allowed_machine_views(allowed_machine_views), cached_subgraph_costs(cached_subgraph_costs) {} @@ -131,7 +133,8 @@ struct OptimalCost { template OptimalCostResult operator()(T const &t) const { - OptimalCostState state{t, resource/*, given_machine_views, frontier_machine_views*/}; + OptimalCostState state{ + t, resource /*, given_machine_views, frontier_machine_views*/}; optional cached_result = cached_subgraph_costs.load(state); @@ -150,8 +153,10 @@ struct OptimalCost { SerialParallelDecomposition post_decompn = decomposed.second; GraphSplit graph_split = get_graph_split(pre_decompn, post_decompn); - SubParallelComputationGraphView pre_graph = get_subgraph(g, graph_split.first); - SubParallelComputationGraphView post_graph = get_subgraph(g, graph_split.second); + SubParallelComputationGraphView pre_graph = + get_subgraph(g, graph_split.first); + SubParallelComputationGraphView post_graph = + get_subgraph(g, graph_split.second); std::unordered_set post_graph_sources = get_closed_sources(post_graph); @@ -165,9 +170,11 @@ struct OptimalCost { for (MachineView const &mv : allowed_machine_views(g.at(split_point), resource)) { - std::unordered_map new_given_machine_views = given_machine_views; + std::unordered_map new_given_machine_views = + given_machine_views; new_given_machine_views.emplace(split_point, mv); - std::unordered_map new_frontier_machine_views = frontier_machine_views; + std::unordered_map + new_frontier_machine_views = frontier_machine_views; new_frontier_machine_views.emplace(split_edge, mv); minimize_runtime(optimal_result, OptimalCostResult::sequential_combine( @@ -198,8 +205,10 @@ struct OptimalCost { SerialParallelDecomposition decompn2 = decomposed.second; GraphSplit graph_split = get_graph_split(decompn1, decompn2); - SubParallelComputationGraphView g1 = get_subgraph(g, graph_split.first), - g2 = get_subgraph(g, graph_split.second); + SubParallelComputationGraphView g1 = get_subgraph( + g, graph_split.first), + g2 = get_subgraph( + 
g, graph_split.second); OptimalCostResult optimal_result = OptimalCostResult::sequential_combine( visit(OptimalCost(g1, @@ -225,16 +234,16 @@ struct OptimalCost { visit(OptimalCost(g1, cost_estimator, resource_split.first, - given_machine_views, - frontier_machine_views, + given_machine_views, + frontier_machine_views, allowed_machine_views, cached_subgraph_costs), decompn1), visit(OptimalCost(g2, cost_estimator, resource_split.second, - given_machine_views, - frontier_machine_views, + given_machine_views, + frontier_machine_views, allowed_machine_views, cached_subgraph_costs), decompn2))); @@ -248,13 +257,16 @@ struct OptimalCost { assert(contains(allowed_machine_views(g.at(node), resource), source_machine_view.value())); MachineMapping mv_map{given_machine_views}; - return {estimate_cost(g, cost_estimator, mv_map, frontier_machine_views), mv_map}; + return {estimate_cost(g, cost_estimator, mv_map, frontier_machine_views), + mv_map}; } else { OptimalCostResult optimal_result = OptimalCostResult::infinity(); for (auto mv : allowed_machine_views(g.at(node), resource)) { MachineMapping mv_map{{{node, mv}}}; - minimize_runtime(optimal_result, - {estimate_cost(g, cost_estimator, mv_map, frontier_machine_views), mv_map}); + minimize_runtime( + optimal_result, + {estimate_cost(g, cost_estimator, mv_map, frontier_machine_views), + mv_map}); } return optimal_result; } @@ -269,7 +281,8 @@ OptimalCostResult CostEstimator const &cost_estimator, MachineSpecification const &resources, OptimalCostCache &cached_subgraph_costs) { - SerialParallelDecomposition sp_decomposition = get_serial_parallel_decomposition(g); + SerialParallelDecomposition sp_decomposition = + get_serial_parallel_decomposition(g); SubParallelComputationGraph subpcg = pcg_to_subpcg(g); return visit(OptimalCost(subpcg, cost_estimator, diff --git a/lib/compiler/src/unity_algorithm.cc b/lib/compiler/src/unity_algorithm.cc index 16671b080a..c89bf04b25 100644 --- a/lib/compiler/src/unity_algorithm.cc +++ 
b/lib/compiler/src/unity_algorithm.cc @@ -37,8 +37,13 @@ Strategy DeduplicatedPriorityQueue, StrategyRuntimeCmp> candidates; - OptimalCostResult initial_pcg_result = optimal_cost(pcg, allowed_machine_views, cost_estimator, resources, cached_subgraph_costs); - Strategy initial_result{pcg, initial_pcg_result.machine_mapping, initial_pcg_result.runtime}; + OptimalCostResult initial_pcg_result = optimal_cost(pcg, + allowed_machine_views, + cost_estimator, + resources, + cached_subgraph_costs); + Strategy initial_result{ + pcg, initial_pcg_result.machine_mapping, initial_pcg_result.runtime}; Strategy best_result = initial_result; candidates.push(initial_result); @@ -88,4 +93,4 @@ size_t hash::operator()(FlexFlow::Strategy const &s) const { return h; } -} +} // namespace std diff --git a/lib/compiler/test/src/test_generator.h b/lib/compiler/test/src/test_generator.h index 23a79abbe0..6566c8c2de 100644 --- a/lib/compiler/test/src/test_generator.h +++ b/lib/compiler/test/src/test_generator.h @@ -2,9 +2,9 @@ #define _FLEXFLOW_TEST_GENERATOR_H #include "compiler/machine_mapping.h" -#include "substitutions/sub_parallel_computation_graph.h" #include "pcg/computation_graph.h" #include "rapidcheck.h" +#include "substitutions/sub_parallel_computation_graph.h" using namespace FlexFlow; diff --git a/lib/compiler/test/src/test_labelled_open_graph.cc b/lib/compiler/test/src/test_labelled_open_graph.cc index 78ea1ece55..82c247e0d2 100644 --- a/lib/compiler/test/src/test_labelled_open_graph.cc +++ b/lib/compiler/test/src/test_labelled_open_graph.cc @@ -67,11 +67,12 @@ TEST_CASE("get_subgraph_open_graph") { CHECK(bool(get_open_outputs(subgraph3).empty())); CHECK(bool(get_edges(subgraph0) == - std::unordered_set{ - split_edge(e2).second, split_edge(e3).second, e4, e5})); + std::unordered_set{ + split_edge(e2).second, split_edge(e3).second, e4, e5})); CHECK(bool(get_edges(subgraph1) == - std::unordered_set{ - split_edge(e2).second, split_edge(e3).second, e4})); - 
CHECK(bool(get_edges(subgraph2) == std::unordered_set{e4, e5})); + std::unordered_set{ + split_edge(e2).second, split_edge(e3).second, e4})); + CHECK(bool(get_edges(subgraph2) == + std::unordered_set{e4, e5})); CHECK(bool(get_edges(subgraph3) == std::unordered_set{e4})); } diff --git a/lib/compiler/test/src/test_open_graph.cc b/lib/compiler/test/src/test_open_graph.cc index ea1108c291..00cb4ca890 100644 --- a/lib/compiler/test/src/test_open_graph.cc +++ b/lib/compiler/test/src/test_open_graph.cc @@ -11,7 +11,8 @@ TEST_CASE("get_source_sink_open_graph") { Node n0 = g.add_node(); NodePort p0 = g.add_node_port(); - InputMultiDiEdge e0{n0, g.add_node_port(), std::make_pair(n0.value(), n0.value())}; + InputMultiDiEdge e0{ + n0, g.add_node_port(), std::make_pair(n0.value(), n0.value())}; g.add_edge(e0); CHECK(bool(get_closed_sources(g) == std::unordered_set{})); @@ -63,7 +64,8 @@ TEST_CASE("get_cut") { MultiDiEdge e2{ns[3], g.add_node_port(), ns[1], g.add_node_port()}; MultiDiEdge e3{ns[4], g.add_node_port(), ns[2], g.add_node_port()}; MultiDiEdge e4{ns[4], g.add_node_port(), ns[3], g.add_node_port()}; - OutputMultiDiEdge e5{ns[4], g.add_node_port(), std::make_pair(ns[4].value(), ns[4].value())}; + OutputMultiDiEdge e5{ + ns[4], g.add_node_port(), std::make_pair(ns[4].value(), ns[4].value())}; g.add_edge(e0); g.add_edge(e1); diff --git a/lib/compiler/test/src/test_optimal_cost.cc b/lib/compiler/test/src/test_optimal_cost.cc index 87f9d06342..7eeb118c57 100644 --- a/lib/compiler/test/src/test_optimal_cost.cc +++ b/lib/compiler/test/src/test_optimal_cost.cc @@ -27,25 +27,28 @@ allowed machine views, trivial cost estimator and random machine specification. 
// } TEST_CASE("optimal_cost_0") { - auto pcg = OutputLabelledMultiDiGraph::template create< - AdjacencyMultiDiGraph, - UnorderedLabelling, - UnorderedLabelling - >(); + auto pcg = + OutputLabelledMultiDiGraph::template create< + AdjacencyMultiDiGraph, + UnorderedLabelling, + UnorderedLabelling>(); Node n0 = pcg.add_node(Operator(InputAttrs{}, "input")); - Node n1 = pcg.add_node(Operator(LinearAttrs{1, false, DataType::FLOAT, Activation::RELU, nullopt}, "linear")); + Node n1 = pcg.add_node(Operator( + LinearAttrs{1, false, DataType::FLOAT, Activation::RELU, nullopt}, + "linear")); MultiDiEdge e{n1, pcg.add_node_port(), n0, pcg.add_node_port()}; pcg.add_edge(e); pcg.add_output(e, - ParallelTensor(ParallelTensorDims({2, 1}), - DataType::FLOAT, - CreateGrad::YES)); + ParallelTensor(ParallelTensorDims({2, 1}), + DataType::FLOAT, + CreateGrad::YES)); auto test_allowed_machine_views = [](Operator const &, MachineSpecification const &) { - return std::unordered_set{make_1d_machine_view(gpu_id_t(1), gpu_id_t(2))}; + return std::unordered_set{ + make_1d_machine_view(gpu_id_t(1), gpu_id_t(2))}; }; CostEstimator estimator = CostEstimator::create(); @@ -54,7 +57,11 @@ TEST_CASE("optimal_cost_0") { OptimalCostCache cached_results; - OptimalCostResult result = optimal_cost(ParallelComputationGraph(pcg), test_allowed_machine_views, estimator, machine_spec, cached_results); + OptimalCostResult result = optimal_cost(ParallelComputationGraph(pcg), + test_allowed_machine_views, + estimator, + machine_spec, + cached_results); CHECK(bool(result.runtime > 0)); -} \ No newline at end of file +} diff --git a/lib/op-attrs/src/get_output_shapes.cc b/lib/op-attrs/src/get_output_shapes.cc index 9d007e2f45..c20d4be34c 100644 --- a/lib/op-attrs/src/get_output_shapes.cc +++ b/lib/op-attrs/src/get_output_shapes.cc @@ -8,7 +8,7 @@ std::vector as_parallel(std::vector const &); std::vector get_output_shapes( PCGOperatorAttrs const &op_params, std::vector const &input_tensor_shapes) { - 
NOT_IMPLEMENTED(); + NOT_IMPLEMENTED(); } // TensorShape get_output_shape(AggregateAttrs const &attrs, diff --git a/lib/pcg/include/pcg/parallel_computation_graph.h b/lib/pcg/include/pcg/parallel_computation_graph.h index 2342cd08fa..39a69a80ab 100644 --- a/lib/pcg/include/pcg/parallel_computation_graph.h +++ b/lib/pcg/include/pcg/parallel_computation_graph.h @@ -15,7 +15,8 @@ struct ParallelComputationGraph }; CHECK_WELL_BEHAVED_VALUE_TYPE_NO_HASH(ParallelComputationGraph); -bool operator==(ParallelComputationGraph const &, ParallelComputationGraph const &); +bool operator==(ParallelComputationGraph const &, + ParallelComputationGraph const &); } // namespace FlexFlow @@ -25,6 +26,6 @@ template <> struct hash { size_t operator()(FlexFlow::ParallelComputationGraph const &g) const; }; -} +} // namespace std #endif diff --git a/lib/pcg/include/pcg/strided_rectangle.h b/lib/pcg/include/pcg/strided_rectangle.h index 25f85ffc48..179fff080f 100644 --- a/lib/pcg/include/pcg/strided_rectangle.h +++ b/lib/pcg/include/pcg/strided_rectangle.h @@ -35,7 +35,9 @@ struct StridedRectangleSide { req stride; }; -FF_VISITABLE_STRUCT_NONSTANDARD_CONSTRUCTION(StridedRectangleSide, num_points, stride); +FF_VISITABLE_STRUCT_NONSTANDARD_CONSTRUCTION(StridedRectangleSide, + num_points, + stride); struct StridedRectangle { public: diff --git a/lib/pcg/src/machine_view.cc b/lib/pcg/src/machine_view.cc index 688ba1628f..f146482141 100644 --- a/lib/pcg/src/machine_view.cc +++ b/lib/pcg/src/machine_view.cc @@ -3,7 +3,8 @@ namespace FlexFlow { -// MachineView::MachineView(device_id_t const &start, StridedRectangle const &rect) +// MachineView::MachineView(device_id_t const &start, StridedRectangle const +// &rect) // : start(start), rect(rect) {} static StridedRectangle make_1d_rect(int start, int stop, int stride) { diff --git a/lib/pcg/src/parallel_computation_graph.cc b/lib/pcg/src/parallel_computation_graph.cc index 609b10edd2..011c40eb4c 100644 --- 
a/lib/pcg/src/parallel_computation_graph.cc +++ b/lib/pcg/src/parallel_computation_graph.cc @@ -3,15 +3,18 @@ namespace FlexFlow { -bool operator==(ParallelComputationGraph const &lhs, ParallelComputationGraph const &rhs) { - return std::hash{}(lhs) == std::hash{}(rhs); +bool operator==(ParallelComputationGraph const &lhs, + ParallelComputationGraph const &rhs) { + return std::hash{}(lhs) == + std::hash{}(rhs); } -} +} // namespace FlexFlow namespace std { -size_t hash::operator()(FlexFlow::ParallelComputationGraph const &g) const { +size_t hash::operator()( + FlexFlow::ParallelComputationGraph const &g) const { using namespace FlexFlow; size_t h = 0; @@ -34,4 +37,4 @@ size_t hash::operator()(FlexFlow::ParallelCo return h; } -} \ No newline at end of file +} // namespace std diff --git a/lib/substitutions/include/substitutions/substitution.h b/lib/substitutions/include/substitutions/substitution.h index 98471a8fbd..8dbe4e66cf 100644 --- a/lib/substitutions/include/substitutions/substitution.h +++ b/lib/substitutions/include/substitutions/substitution.h @@ -28,12 +28,12 @@ SubParallelComputationGraph } // namespace FlexFlow -namespace std{ +namespace std { template <> struct hash { size_t operator()(FlexFlow::Substitution const &) const; }; -}; +}; // namespace std #endif diff --git a/lib/utils/include/utils/graph/algorithms.h b/lib/utils/include/utils/graph/algorithms.h index fc3d219e57..cee5445190 100644 --- a/lib/utils/include/utils/graph/algorithms.h +++ b/lib/utils/include/utils/graph/algorithms.h @@ -106,8 +106,10 @@ std::unordered_set get_node_edges(UndirectedGraphView const &, std::unordered_set get_outputs(MultiDiGraphView const &); std::unordered_set get_inputs(MultiDiGraphView const &); -std::unordered_set get_open_outputs(OpenMultiDiGraphView const &); -std::unordered_set get_open_inputs(OpenMultiDiGraphView const &); +std::unordered_set + get_open_outputs(OpenMultiDiGraphView const &); +std::unordered_set + get_open_inputs(OpenMultiDiGraphView const &); 
std::unordered_set get_incoming_edges(MultiDiGraphView const &, Node const &); diff --git a/lib/utils/include/utils/graph/labelled/node_labelled.h b/lib/utils/include/utils/graph/labelled/node_labelled.h index 822973e149..64de380f9c 100644 --- a/lib/utils/include/utils/graph/labelled/node_labelled.h +++ b/lib/utils/include/utils/graph/labelled/node_labelled.h @@ -55,8 +55,7 @@ struct NodeLabelledMultiDiGraphView : virtual public MultiDiGraphView { private: Interface const &get_ptr() const { - return *std::dynamic_pointer_cast( - GraphView::ptr.get()); + return *std::dynamic_pointer_cast(GraphView::ptr.get()); } }; CHECK_WELL_BEHAVED_VALUE_TYPE_NO_EQ(NodeLabelledMultiDiGraphView); @@ -118,8 +117,7 @@ struct NodeLabelledMultiDiGraph : NodeLabelledMultiDiGraphView(ptr), nl(nl) {} Interface &get_ptr() const { - return *std::dynamic_pointer_cast( - GraphView::ptr.get_mutable()); + return *std::dynamic_pointer_cast(GraphView::ptr.get_mutable()); } cow_ptr_t nl; diff --git a/lib/utils/include/utils/graph/labelled/node_labelled_open.h b/lib/utils/include/utils/graph/labelled/node_labelled_open.h index 9d83cebac6..ff069060ca 100644 --- a/lib/utils/include/utils/graph/labelled/node_labelled_open.h +++ b/lib/utils/include/utils/graph/labelled/node_labelled_open.h @@ -10,7 +10,8 @@ struct INodeLabelledOpenMultiDiGraphView : virtual INodeLabelledMultiDiGraphView, virtual IOpenMultiDiGraphView { INodeLabelledOpenMultiDiGraphView() = default; - INodeLabelledOpenMultiDiGraphView(INodeLabelledOpenMultiDiGraphView const &) = delete; + INodeLabelledOpenMultiDiGraphView(INodeLabelledOpenMultiDiGraphView const &) = + delete; INodeLabelledOpenMultiDiGraphView & operator=(INodeLabelledOpenMultiDiGraphView const &) = delete; }; @@ -118,8 +119,7 @@ struct NodeLabelledOpenMultiDiGraph : GraphView(ptr), nl(nl) {} Interface &get_ptr() const { - return *std::dynamic_pointer_cast( - GraphView::ptr.get_mutable()); + return *std::dynamic_pointer_cast(GraphView::ptr.get_mutable()); } cow_ptr_t nl; 
diff --git a/lib/utils/include/utils/graph/labelled/open_views.h b/lib/utils/include/utils/graph/labelled/open_views.h index 501aa9caa4..8323fcf9dc 100644 --- a/lib/utils/include/utils/graph/labelled/open_views.h +++ b/lib/utils/include/utils/graph/labelled/open_views.h @@ -43,7 +43,7 @@ struct OutputLabelledOpenMultiDiSubgraphView return SubgraphView(g, nodes).query_edges(q); } - OutputLabelledOpenMultiDiSubgraphView* clone() const override { + OutputLabelledOpenMultiDiSubgraphView *clone() const override { return new OutputLabelledOpenMultiDiSubgraphView(g, nodes); } @@ -55,8 +55,11 @@ struct OutputLabelledOpenMultiDiSubgraphView // CHECK_NOT_ABSTRACT(OutputLabelledOpenMultiDiSubgraphView); template -struct ViewOutputLabelledAsOutputLabelledOpen : virtual IOutputLabelledOpenMultiDiGraphView { - ViewOutputLabelledAsOutputLabelledOpen(OutputLabelledMultiDiGraphView const &g) : g(g) {} +struct ViewOutputLabelledAsOutputLabelledOpen + : virtual IOutputLabelledOpenMultiDiGraphView { + ViewOutputLabelledAsOutputLabelledOpen( + OutputLabelledMultiDiGraphView const &g) + : g(g) {} NodeLabel const &at(Node const &n) const override { return g.at(n); @@ -77,10 +80,10 @@ struct ViewOutputLabelledAsOutputLabelledOpen : virtual IOutputLabelledOpenMulti std::unordered_set query_edges(OpenMultiDiEdgeQuery const &q) const override { return transform(g.query_edges(q.standard_edge_query), - [](MultiDiEdge const &e) { return OpenMultiDiEdge(e); }); + [](MultiDiEdge const &e) { return OpenMultiDiEdge(e); }); } - ViewOutputLabelledAsOutputLabelledOpen* clone() const override { + ViewOutputLabelledAsOutputLabelledOpen *clone() const override { return new ViewOutputLabelledAsOutputLabelledOpen(g); } @@ -89,8 +92,12 @@ struct ViewOutputLabelledAsOutputLabelledOpen : virtual IOutputLabelledOpenMulti }; template -OutputLabelledOpenMultiDiGraphView view_output_labelled_as_output_labelled_open(OutputLabelledMultiDiGraphView const &g) { - return OutputLabelledOpenMultiDiGraphView::template 
create>(g); +OutputLabelledOpenMultiDiGraphView + view_output_labelled_as_output_labelled_open( + OutputLabelledMultiDiGraphView const &g) { + return OutputLabelledOpenMultiDiGraphView:: + template create< + ViewOutputLabelledAsOutputLabelledOpen>(g); } } // namespace FlexFlow diff --git a/lib/utils/include/utils/graph/labelled/output_labelled.h b/lib/utils/include/utils/graph/labelled/output_labelled.h index 9b3d982e75..03f44fc5ee 100644 --- a/lib/utils/include/utils/graph/labelled/output_labelled.h +++ b/lib/utils/include/utils/graph/labelled/output_labelled.h @@ -118,7 +118,7 @@ struct OutputLabelledMultiDiGraph std::unordered_set query_nodes(NodeQuery const &q) const { return get_ptr().query_nodes(q); } - + std::unordered_set query_edges(MultiDiEdgeQuery const &q) const { return get_ptr().query_edges(q); } @@ -142,8 +142,7 @@ struct OutputLabelledMultiDiGraph private: Interface &get_ptr() const { - return *std::dynamic_pointer_cast( - GraphView::ptr.get_mutable()); + return *std::dynamic_pointer_cast(GraphView::ptr.get_mutable()); } cow_ptr_t nl; cow_ptr_t ol; diff --git a/lib/utils/include/utils/graph/labelled/output_labelled_open.h b/lib/utils/include/utils/graph/labelled/output_labelled_open.h index 986d337a57..17cf7fa7af 100644 --- a/lib/utils/include/utils/graph/labelled/output_labelled_open.h +++ b/lib/utils/include/utils/graph/labelled/output_labelled_open.h @@ -169,8 +169,7 @@ struct OutputLabelledOpenMultiDiGraph : GraphView(ptr), nl(nl), il(il), ol(ol) {} Interface &get_ptr() const { - return *std::dynamic_pointer_cast( - GraphView::ptr.get_mutable()); + return *std::dynamic_pointer_cast(GraphView::ptr.get_mutable()); } cow_ptr_t nl; @@ -179,8 +178,10 @@ struct OutputLabelledOpenMultiDiGraph }; template -void add_label(OutputLabelledOpenMultiDiGraph &g, OpenMultiDiEdge const &e, EdgeLabel const &l) { - visit([&](const auto &e) { g.add_label(e, l); }, e); +void add_label(OutputLabelledOpenMultiDiGraph &g, + OpenMultiDiEdge const &e, + EdgeLabel const 
&l) { + visit([&](auto const &e) { g.add_label(e, l); }, e); } } // namespace FlexFlow diff --git a/lib/utils/include/utils/graph/labelled/standard_labelled.h b/lib/utils/include/utils/graph/labelled/standard_labelled.h index ae9b02c911..d47d7fdbc0 100644 --- a/lib/utils/include/utils/graph/labelled/standard_labelled.h +++ b/lib/utils/include/utils/graph/labelled/standard_labelled.h @@ -129,7 +129,7 @@ struct LabelledMultiDiGraph cow_ptr_t el) : LabelledMultiDiGraphView(ptr), nl(nl), el(el) {} - Interface& get_ptr() const { + Interface &get_ptr() const { return *std::dynamic_pointer_cast(GraphView::ptr.get_mutable()); } diff --git a/lib/utils/include/utils/graph/labelled/views.h b/lib/utils/include/utils/graph/labelled/views.h index 82a45a2ad0..84bb20d327 100644 --- a/lib/utils/include/utils/graph/labelled/views.h +++ b/lib/utils/include/utils/graph/labelled/views.h @@ -2,8 +2,8 @@ #define _FLEXFLOW_UTILS_INCLUDE_UTILS_GRAPH_LABELLED_VIEWS_H #include "node_labelled.h" -#include "standard_labelled.h" #include "output_labelled_open.h" +#include "standard_labelled.h" namespace FlexFlow { @@ -86,9 +86,21 @@ Impl materialize_output_labelled_multidigraph_view( return result; } -template -OutputLabelledOpenMultiDiGraph materialize_output_labelled_open_multidigraph_view(OutputLabelledOpenMultiDiGraphView const &g) { - OutputLabelledOpenMultiDiGraph result = OutputLabelledOpenMultiDiGraph::template create(); +template +OutputLabelledOpenMultiDiGraph + materialize_output_labelled_open_multidigraph_view( + OutputLabelledOpenMultiDiGraphView const &g) { + OutputLabelledOpenMultiDiGraph result = + OutputLabelledOpenMultiDiGraph::template create< + Impl, + NodeLabelImpl, + InputLabelImpl, + OutputLabelImpl>(); for (Node const &n : get_nodes(g)) { result.add_node_unsafe(n, g.at(n)); } @@ -98,7 +110,10 @@ OutputLabelledOpenMultiDiGraph materialize_output_labell InputMultiDiEdge input_edge = get(e); result.add_label(input_edge, g.at(input_edge)); } else { - MultiDiOutput output = 
is_standard_edge(e) ? static_cast(get(e)) : static_cast(get(e)); + MultiDiOutput output = + is_standard_edge(e) + ? static_cast(get(e)) + : static_cast(get(e)); auto tensor = g.at(output); result.add_label(output, tensor); } diff --git a/lib/utils/src/graph/algorithms.cc b/lib/utils/src/graph/algorithms.cc index d5407adbae..72242709e2 100644 --- a/lib/utils/src/graph/algorithms.cc +++ b/lib/utils/src/graph/algorithms.cc @@ -349,11 +349,17 @@ std::unordered_set }); } -std::unordered_set get_open_outputs(OpenMultiDiGraphView const &g) { - return transform(g.query_edges(OutputMultiDiEdgeQuery::all()), [](OpenMultiDiEdge const &e) { return get(e); }); +std::unordered_set + get_open_outputs(OpenMultiDiGraphView const &g) { + return transform( + g.query_edges(OutputMultiDiEdgeQuery::all()), + [](OpenMultiDiEdge const &e) { return get(e); }); } -std::unordered_set get_open_inputs(OpenMultiDiGraphView const &g) { - return transform(g.query_edges(InputMultiDiEdgeQuery::all()), [](OpenMultiDiEdge const &e) { return get(e); }); +std::unordered_set + get_open_inputs(OpenMultiDiGraphView const &g) { + return transform( + g.query_edges(InputMultiDiEdgeQuery::all()), + [](OpenMultiDiEdge const &e) { return get(e); }); } std::unordered_map> @@ -780,13 +786,15 @@ std::unordered_set get_closed_sinks(OpenMultiDiGraphView const &g) { std::unordered_set get_open_sources(OpenMultiDiGraphView const &g) { return filter(get_nodes(g), [&](Node const &n) { - return !g.query_edges(InputMultiDiEdgeQuery::all().with_dst_nodes({n})).empty(); + return !g.query_edges(InputMultiDiEdgeQuery::all().with_dst_nodes({n})) + .empty(); }); } std::unordered_set get_open_sinks(OpenMultiDiGraphView const &g) { return filter(get_nodes(g), [&](Node const &n) { - return !g.query_edges(OutputMultiDiEdgeQuery::all().with_src_nodes({n})).empty(); + return !g.query_edges(OutputMultiDiEdgeQuery::all().with_src_nodes({n})) + .empty(); }); } diff --git a/lib/utils/src/graph/multidigraph.cc 
b/lib/utils/src/graph/multidigraph.cc index 7bbe4cae67..d0ed98b29b 100644 --- a/lib/utils/src/graph/multidigraph.cc +++ b/lib/utils/src/graph/multidigraph.cc @@ -24,7 +24,8 @@ std::unordered_set } IMultiDiGraphView const &MultiDiGraphView::get_ptr() const { - return *std::dynamic_pointer_cast(GraphView::ptr.get()); + return *std::dynamic_pointer_cast( + GraphView::ptr.get()); } Node MultiDiGraph::add_node() { diff --git a/lib/utils/src/graph/open_graphs.cc b/lib/utils/src/graph/open_graphs.cc index 9bbb1bfa3d..d545b45a31 100644 --- a/lib/utils/src/graph/open_graphs.cc +++ b/lib/utils/src/graph/open_graphs.cc @@ -22,7 +22,8 @@ std::unordered_set } IOpenMultiDiGraphView const &OpenMultiDiGraphView::get_ptr() const { - return *std::dynamic_pointer_cast(GraphView::ptr.get()); + return *std::dynamic_pointer_cast( + GraphView::ptr.get()); } Node OpenMultiDiGraph::add_node() { @@ -116,7 +117,8 @@ std::unordered_set return this->get_ptr().query_edges(q); } -IDownwardOpenMultiDiGraphView const &DownwardOpenMultiDiGraphView::get_ptr() const { +IDownwardOpenMultiDiGraphView const & + DownwardOpenMultiDiGraphView::get_ptr() const { return *std::dynamic_pointer_cast( GraphView::ptr.get()); } diff --git a/lib/utils/src/graph/views.cc b/lib/utils/src/graph/views.cc index 5a8c6e9f93..bf4f7351c0 100644 --- a/lib/utils/src/graph/views.cc +++ b/lib/utils/src/graph/views.cc @@ -469,7 +469,8 @@ std::unordered_set UpwardOpenMultiDiSubgraphView::UpwardOpenMultiDiSubgraphView( OpenMultiDiGraphView const &g, std::unordered_set const &nodes) - : g(g), nodes(nodes), inputs(transform(get_cut_set(g, nodes), to_inputmultidiedge)) {} + : g(g), nodes(nodes), + inputs(transform(get_cut_set(g, nodes), to_inputmultidiedge)) {} UpwardOpenMultiDiSubgraphView *UpwardOpenMultiDiSubgraphView::clone() const { return new UpwardOpenMultiDiSubgraphView(g, nodes); @@ -493,7 +494,8 @@ std::unordered_set DownwardOpenMultiDiSubgraphView::DownwardOpenMultiDiSubgraphView( OpenMultiDiGraphView const &g, 
std::unordered_set const &nodes) - : g(g), nodes(nodes), outputs(transform(get_cut_set(g, nodes), to_outputmultidiedge)) {} + : g(g), nodes(nodes), + outputs(transform(get_cut_set(g, nodes), to_outputmultidiedge)) {} std::unordered_set DownwardOpenMultiDiSubgraphView::query_edges( From fb58a99913681970c01233a139ea3bfb34c00fea Mon Sep 17 00:00:00 2001 From: wmdi Date: Tue, 23 Jan 2024 22:46:15 -0500 Subject: [PATCH 04/37] fmt --- lib/utils/include/utils/graph/labelled/node_labelled.h | 9 +++------ .../include/utils/graph/labelled/node_labelled_open.h | 9 +++------ lib/utils/include/utils/graph/labelled/output_labelled.h | 9 +++------ .../include/utils/graph/labelled/output_labelled_open.h | 9 +++------ .../include/utils/graph/labelled/standard_labelled.h | 9 +++------ lib/utils/src/graph/multidigraph.cc | 3 +-- lib/utils/src/graph/undirected.cc | 3 ++- 7 files changed, 18 insertions(+), 33 deletions(-) diff --git a/lib/utils/include/utils/graph/labelled/node_labelled.h b/lib/utils/include/utils/graph/labelled/node_labelled.h index 109855965d..ded049f224 100644 --- a/lib/utils/include/utils/graph/labelled/node_labelled.h +++ b/lib/utils/include/utils/graph/labelled/node_labelled.h @@ -54,8 +54,7 @@ struct NodeLabelledMultiDiGraphView : virtual public MultiDiGraphView { private: Interface const &get_ptr() const { - return *std::dynamic_pointer_cast( - GraphView::ptr.get()); + return *std::dynamic_pointer_cast(GraphView::ptr.get()); } }; CHECK_WELL_BEHAVED_VALUE_TYPE_NO_EQ(NodeLabelledMultiDiGraphView); @@ -117,13 +116,11 @@ struct NodeLabelledMultiDiGraph : NodeLabelledMultiDiGraphView(ptr), nl(nl) {} Interface &get_ptr() { - return *std::dynamic_pointer_cast( - GraphView::ptr.get_mutable()); + return *std::dynamic_pointer_cast(GraphView::ptr.get_mutable()); } Interface const &get_ptr() const { - return *std::dynamic_pointer_cast( - GraphView::ptr.get()); + return *std::dynamic_pointer_cast(GraphView::ptr.get()); } cow_ptr_t nl; diff --git 
a/lib/utils/include/utils/graph/labelled/node_labelled_open.h b/lib/utils/include/utils/graph/labelled/node_labelled_open.h index c77d75c37a..fab6695070 100644 --- a/lib/utils/include/utils/graph/labelled/node_labelled_open.h +++ b/lib/utils/include/utils/graph/labelled/node_labelled_open.h @@ -55,8 +55,7 @@ struct NodeLabelledOpenMultiDiGraphView private: Interface const &get_ptr() const { - return *std::dynamic_pointer_cast( - GraphView::ptr.get()); + return *std::dynamic_pointer_cast(GraphView::ptr.get()); } }; @@ -120,13 +119,11 @@ struct NodeLabelledOpenMultiDiGraph : GraphView(ptr), nl(nl) {} Interface &get_ptr() { - return *std::dynamic_pointer_cast( - GraphView::ptr.get_mutable()); + return *std::dynamic_pointer_cast(GraphView::ptr.get_mutable()); } Interface const &get_ptr() const { - return *std::dynamic_pointer_cast( - GraphView::ptr.get()); + return *std::dynamic_pointer_cast(GraphView::ptr.get()); } cow_ptr_t nl; diff --git a/lib/utils/include/utils/graph/labelled/output_labelled.h b/lib/utils/include/utils/graph/labelled/output_labelled.h index a675632a55..58b4ef23fd 100644 --- a/lib/utils/include/utils/graph/labelled/output_labelled.h +++ b/lib/utils/include/utils/graph/labelled/output_labelled.h @@ -60,8 +60,7 @@ struct OutputLabelledMultiDiGraphView private: Interface const &get_ptr() const { - return *std::dynamic_pointer_cast( - GraphView::ptr.get()); + return *std::dynamic_pointer_cast(GraphView::ptr.get()); } }; @@ -143,13 +142,11 @@ struct OutputLabelledMultiDiGraph private: Interface &get_ptr() { - return *std::dynamic_pointer_cast( - GraphView::ptr.get_mutable()); + return *std::dynamic_pointer_cast(GraphView::ptr.get_mutable()); } Interface const &get_ptr() const { - return *std::dynamic_pointer_cast( - GraphView::ptr.get()); + return *std::dynamic_pointer_cast(GraphView::ptr.get()); } cow_ptr_t nl; diff --git a/lib/utils/include/utils/graph/labelled/output_labelled_open.h b/lib/utils/include/utils/graph/labelled/output_labelled_open.h 
index 18ffbf569d..231d70db74 100644 --- a/lib/utils/include/utils/graph/labelled/output_labelled_open.h +++ b/lib/utils/include/utils/graph/labelled/output_labelled_open.h @@ -63,8 +63,7 @@ struct OutputLabelledOpenMultiDiGraphView private: Interface const &get_ptr() const { - return *std::dynamic_pointer_cast( - GraphView::ptr.get()); + return *std::dynamic_pointer_cast(GraphView::ptr.get()); } }; @@ -174,13 +173,11 @@ struct OutputLabelledOpenMultiDiGraph : GraphView(ptr), nl(nl), il(il), ol(ol) {} Interface &get_ptr() { - return *std::dynamic_pointer_cast( - GraphView::ptr.get_mutable()); + return *std::dynamic_pointer_cast(GraphView::ptr.get_mutable()); } Interface const &get_ptr() const { - return *std::dynamic_pointer_cast( - GraphView::ptr.get()); + return *std::dynamic_pointer_cast(GraphView::ptr.get()); } cow_ptr_t nl; diff --git a/lib/utils/include/utils/graph/labelled/standard_labelled.h b/lib/utils/include/utils/graph/labelled/standard_labelled.h index 1a98701811..8af69e18fc 100644 --- a/lib/utils/include/utils/graph/labelled/standard_labelled.h +++ b/lib/utils/include/utils/graph/labelled/standard_labelled.h @@ -60,8 +60,7 @@ struct LabelledMultiDiGraphView : NodeLabelledMultiDiGraphView(ptr) {} Interface const &get_ptr() const { - return *std::dynamic_pointer_cast( - GraphView::ptr.get()); + return *std::dynamic_pointer_cast(GraphView::ptr.get()); } }; CHECK_WELL_BEHAVED_VALUE_TYPE_NO_EQ(LabelledMultiDiGraphView); @@ -132,13 +131,11 @@ struct LabelledMultiDiGraph : LabelledMultiDiGraphView(ptr), nl(nl), el(el) {} Interface &get_ptr() { - return *std::dynamic_pointer_cast( - GraphView::ptr.get_mutable()); + return *std::dynamic_pointer_cast(GraphView::ptr.get_mutable()); } Interface const &get_ptr() const { - return *std::dynamic_pointer_cast( - GraphView::ptr.get()); + return *std::dynamic_pointer_cast(GraphView::ptr.get()); } cow_ptr_t nl; diff --git a/lib/utils/src/graph/multidigraph.cc b/lib/utils/src/graph/multidigraph.cc index 
41ae3e1aa3..771e01e573 100644 --- a/lib/utils/src/graph/multidigraph.cc +++ b/lib/utils/src/graph/multidigraph.cc @@ -66,8 +66,7 @@ std::unordered_set MultiDiGraph::query_nodes(NodeQuery const &q) const { } IMultiDiGraph const &MultiDiGraph::get_ptr() const { - return *std::dynamic_pointer_cast( - GraphView::ptr.get()); + return *std::dynamic_pointer_cast(GraphView::ptr.get()); } IMultiDiGraph &MultiDiGraph::get_ptr() { diff --git a/lib/utils/src/graph/undirected.cc b/lib/utils/src/graph/undirected.cc index ab13cb5ef7..b1e8be7f14 100644 --- a/lib/utils/src/graph/undirected.cc +++ b/lib/utils/src/graph/undirected.cc @@ -56,7 +56,8 @@ std::unordered_set } IUndirectedGraphView const &UndirectedGraphView::get_ptr() const { - return *std::dynamic_pointer_cast(GraphView::ptr.get()); + return *std::dynamic_pointer_cast( + GraphView::ptr.get()); } } // namespace FlexFlow From 02937e1e584d110d2e1e89301889c393ea6526d2 Mon Sep 17 00:00:00 2001 From: wmdi Date: Wed, 24 Jan 2024 16:32:36 -0500 Subject: [PATCH 05/37] fix --- lib/compiler/src/graph_utils.cc | 8 +++----- .../include/utils/graph/labelled/output_labelled.h | 10 ++++++---- 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/lib/compiler/src/graph_utils.cc b/lib/compiler/src/graph_utils.cc index e0134d6dd8..12d5c99a34 100644 --- a/lib/compiler/src/graph_utils.cc +++ b/lib/compiler/src/graph_utils.cc @@ -12,14 +12,12 @@ ParallelComputationGraph cg_to_pcg(ComputationGraph const &g) { } SubParallelComputationGraph pcg_to_subpcg(ParallelComputationGraph const &pcg) { - auto g = pcg.value(); - auto g_ = view_output_labelled_as_output_labelled_open(g); - auto subpcg = materialize_output_labelled_open_multidigraph_view< + return materialize_output_labelled_open_multidigraph_view< AdjacencyOpenMultiDiGraph, UnorderedLabelling, UnorderedLabelling, - UnorderedLabelling>(g_); - return subpcg; + UnorderedLabelling>( + view_output_labelled_as_output_labelled_open(pcg.value())); } std::vector diff --git 
a/lib/utils/include/utils/graph/labelled/output_labelled.h b/lib/utils/include/utils/graph/labelled/output_labelled.h index 58b4ef23fd..f3cf14022b 100644 --- a/lib/utils/include/utils/graph/labelled/output_labelled.h +++ b/lib/utils/include/utils/graph/labelled/output_labelled.h @@ -39,11 +39,12 @@ struct OutputLabelledMultiDiGraphView return get_ptr().at(o); } - std::unordered_set query_nodes(NodeQuery const &q) const { + virtual std::unordered_set query_nodes(NodeQuery const &q) const { return get_ptr().query_nodes(q); } - std::unordered_set query_edges(MultiDiEdgeQuery const &q) const { + virtual std::unordered_set + query_edges(MultiDiEdgeQuery const &q) const { return get_ptr().query_edges(q); } @@ -115,11 +116,12 @@ struct OutputLabelledMultiDiGraph return ol->get_label(o); } - std::unordered_set query_nodes(NodeQuery const &q) const { + std::unordered_set query_nodes(NodeQuery const &q) const override { return get_ptr().query_nodes(q); } - std::unordered_set query_edges(MultiDiEdgeQuery const &q) const { + std::unordered_set + query_edges(MultiDiEdgeQuery const &q) const override { return get_ptr().query_edges(q); } From 6402ed0a538a232a16d8d634f39c131e1ae9a495 Mon Sep 17 00:00:00 2001 From: wmdi Date: Thu, 25 Jan 2024 15:10:25 -0500 Subject: [PATCH 06/37] add substitutions, compiler, and their unit tests to CI --- .github/workflows/per-lib-check.yml | 25 +++++++++++++++++++++++++ 1 file changed, 25 insertions(+) diff --git a/.github/workflows/per-lib-check.yml b/.github/workflows/per-lib-check.yml index f21621b265..4cbcdf6afb 100644 --- a/.github/workflows/per-lib-check.yml +++ b/.github/workflows/per-lib-check.yml @@ -83,3 +83,28 @@ jobs: run: | cd build make -j $(( $(nproc) < 2 ? 1 : $(nproc)-1 )) kernels + + - name: Build substitutions + run: | + cd build + make -j $(( $(nproc) < 2 ? 1 : $(nproc)-1 )) substitutions + + - name: Build compiler + run: | + cd build + make -j $(( $(nproc) < 2 ? 
1 : $(nproc)-1 )) compiler + + - name: Build substitutions-test + run: | + cd build + make -j $(( $(nproc) < 2 ? 1 : $(nproc)-1 )) substitutions-test + + - name: Build compiler-test + run: | + cd build + make -j $(( $(nproc) < 2 ? 1 : $(nproc)-1 )) compiler-test + + - name: Unit tests + run: | + cd build + ctest From 0c45f61f114414215c5eff7d39006f07735e2fe7 Mon Sep 17 00:00:00 2001 From: wmdi Date: Thu, 25 Jan 2024 16:04:47 -0500 Subject: [PATCH 07/37] disable runtime unit test --- lib/runtime/CMakeLists.txt | 26 +++++++++++++------------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/lib/runtime/CMakeLists.txt b/lib/runtime/CMakeLists.txt index 49b052ec2b..fd5b4991ef 100644 --- a/lib/runtime/CMakeLists.txt +++ b/lib/runtime/CMakeLists.txt @@ -17,18 +17,18 @@ ff_add_library( pcg ) -ff_add_test_executable( - NAME - runtime-test - SRC_PATTERNS - test/src/*.cc - PUBLIC_INCLUDE - include/ - PRIVATE_INCLUDE - test/src/ src/ - DEPS - runtime - doctest -) +# ff_add_test_executable( +# NAME +# runtime-test +# SRC_PATTERNS +# test/src/*.cc +# PUBLIC_INCLUDE +# include/ +# PRIVATE_INCLUDE +# test/src/ src/ +# DEPS +# runtime +# doctest +# ) add_subdirectory(ffi) From 95fa427b8680e68acf33cccf96bbc66bb37cd1fa Mon Sep 17 00:00:00 2001 From: wmdi Date: Thu, 15 Feb 2024 17:06:11 -0500 Subject: [PATCH 08/37] minor fix --- lib/compiler/src/machine_mapping.cc | 34 +++++- lib/compiler/src/unity_algorithm.cc | 8 +- lib/compiler/test/CMakeLists.txt | 4 +- .../test/src/test_labelled_open_graph.cc | 2 +- lib/compiler/test/src/test_open_graph.cc | 10 +- lib/compiler/test/src/test_optimal_cost.cc | 6 +- lib/pcg/include/pcg/machine_view.h | 3 - lib/pcg/include/pcg/operator.h | 19 ++-- lib/pcg/include/pcg/parallel_tensor.h | 2 + lib/pcg/include/pcg/strided_rectangle.h | 9 -- lib/pcg/src/machine_view.cc | 4 - lib/pcg/src/operator.cc | 6 +- lib/pcg/src/parallel_tensor.cc | 4 + lib/pcg/src/strided_rectangle.cc | 4 - lib/substitutions/src/substitution.cc | 106 +++++++++--------- 
lib/utils/include/utils/graph/algorithms.h | 1 + .../include/utils/graph/labelled/open_views.h | 2 - .../graph/labelled/output_labelled_open.h | 45 +++++--- .../utils/graph/labelled/unordered_label.h | 3 +- lib/utils/include/utils/variant.h | 8 ++ lib/utils/src/graph/algorithms.cc | 10 +- lib/utils/src/graph/open_graphs.cc | 2 +- lib/utils/src/graph/views.cc | 10 +- 23 files changed, 166 insertions(+), 136 deletions(-) diff --git a/lib/compiler/src/machine_mapping.cc b/lib/compiler/src/machine_mapping.cc index 2bdd7de1e2..3cabd972bf 100644 --- a/lib/compiler/src/machine_mapping.cc +++ b/lib/compiler/src/machine_mapping.cc @@ -96,7 +96,39 @@ float estimate_cost(SubParallelComputationGraphView const &g, MachineMapping const &device_mapping, std::unordered_map const &frontier_machine_views) { - return 0.1; + float cost = 0; + for (Node const &node : get_nodes(g)) { + std::unordered_set incoming_edges = + get_incoming_edges(g, node); + std::vector inputs = + transform(as_vector(incoming_edges), + [&](UpwardOpenMultiDiEdge const &input_edge) { + return g.at(input_edge).get_shape(); + }); + cost += estimator.estimate_cost( + g.at(node).attrs, inputs, device_mapping.machine_views.at(node)); + } + + for (OpenMultiDiEdge const &edge : get_edges(g)) { + if (holds_alternative(edge)) { + cost += estimator.estimate_cost( + g.at(edge).get_shape(), + frontier_machine_views.at(edge), + device_mapping.machine_views.at(get(edge).dst)); + } else if (holds_alternative(edge)) { + cost += estimator.estimate_cost( + g.at(edge).get_shape(), + device_mapping.machine_views.at(get(edge).src), + frontier_machine_views.at(edge)); + } else { + assert(holds_alternative(edge)); + cost += estimator.estimate_cost( + g.at(edge).get_shape(), + device_mapping.machine_views.at(get(edge).src), + device_mapping.machine_views.at(get(edge).dst)); + } + } + return cost; } void minimize_runtime(OptimalCostResult &m1, OptimalCostResult const &m2) { diff --git a/lib/compiler/src/unity_algorithm.cc 
b/lib/compiler/src/unity_algorithm.cc index c89bf04b25..3363aecc2f 100644 --- a/lib/compiler/src/unity_algorithm.cc +++ b/lib/compiler/src/unity_algorithm.cc @@ -9,11 +9,17 @@ bool StrategyRuntimeCmp::operator()(Strategy const &lhs, Strategy const &rhs) { return lhs.runtime < rhs.runtime; } +/* + * Gets all substitutions applicable to a PCG + */ std::unordered_set get_all_substitutions(ParallelComputationGraph const &pcg) { NOT_IMPLEMENTED(); } +/* + * Applies a substitution to all possible positions in PCG + */ std::unordered_set apply_substitution(ParallelComputationGraph const &pcg, Substitution const &) { @@ -53,7 +59,7 @@ Strategy Strategy const &current_result = candidates.top(); candidates.pop(); - if (StrategyRuntimeCmp{}(current_result, best_result)) { + if (current_result.runtime < best_result.runtime) { best_result = current_result; } else if (current_result.runtime > best_result.runtime * opt_config.alpha) { diff --git a/lib/compiler/test/CMakeLists.txt b/lib/compiler/test/CMakeLists.txt index cc64b15f7d..cbd7e233c0 100644 --- a/lib/compiler/test/CMakeLists.txt +++ b/lib/compiler/test/CMakeLists.txt @@ -2,9 +2,7 @@ ff_add_test_executable( NAME compiler-test SRC_PATTERNS - src/test_labelled_open_graph.cc - src/test_open_graph.cc - src/test_optimal_cost.cc + src/*.cc PRIVATE_INCLUDE src/ DEPS diff --git a/lib/compiler/test/src/test_labelled_open_graph.cc b/lib/compiler/test/src/test_labelled_open_graph.cc index 82c247e0d2..1cae9a0cd1 100644 --- a/lib/compiler/test/src/test_labelled_open_graph.cc +++ b/lib/compiler/test/src/test_labelled_open_graph.cc @@ -4,7 +4,7 @@ using namespace FlexFlow; -TEST_CASE("get_subgraph_open_graph") { +TEST_CASE("get_subgraph(OpenMultiDiGraphView)") { auto g = OpenMultiDiGraph::create(); int t0 = 100000; diff --git a/lib/compiler/test/src/test_open_graph.cc b/lib/compiler/test/src/test_open_graph.cc index 00cb4ca890..7436f213d7 100644 --- a/lib/compiler/test/src/test_open_graph.cc +++ b/lib/compiler/test/src/test_open_graph.cc @@ 
-7,8 +7,6 @@ using namespace FlexFlow; TEST_CASE("get_source_sink_open_graph") { OpenMultiDiGraph g = OpenMultiDiGraph::create(); - int s0 = 100000; - Node n0 = g.add_node(); NodePort p0 = g.add_node_port(); InputMultiDiEdge e0{ @@ -25,9 +23,6 @@ TEST_CASE("get_source_sink_open_graph") { TEST_CASE("get_source_sink_open_graph:unconnected") { OpenMultiDiGraph g = OpenMultiDiGraph::create(); - int s0 = 100000; - int t0 = s0 + 1; - Node n0 = g.add_node(); Node n1 = g.add_node(); @@ -54,10 +49,7 @@ TEST_CASE("get_source_sink_open_graph:unconnected") { TEST_CASE("get_cut") { auto g = OpenMultiDiGraph::create(); - std::vector ns; - for (int i = 0; i < 5; ++i) { - ns.push_back(g.add_node()); - } + std::vector ns = add_nodes(g, 5); MultiDiEdge e0{ns[1], g.add_node_port(), ns[0], g.add_node_port()}; MultiDiEdge e1{ns[2], g.add_node_port(), ns[1], g.add_node_port()}; diff --git a/lib/compiler/test/src/test_optimal_cost.cc b/lib/compiler/test/src/test_optimal_cost.cc index 7eeb118c57..a6cd88a006 100644 --- a/lib/compiler/test/src/test_optimal_cost.cc +++ b/lib/compiler/test/src/test_optimal_cost.cc @@ -33,10 +33,10 @@ TEST_CASE("optimal_cost_0") { UnorderedLabelling, UnorderedLabelling>(); - Node n0 = pcg.add_node(Operator(InputAttrs{}, "input")); - Node n1 = pcg.add_node(Operator( + Node n0 = pcg.add_node(Operator{InputAttrs{}, "input"}); + Node n1 = pcg.add_node(Operator{ LinearAttrs{1, false, DataType::FLOAT, Activation::RELU, nullopt}, - "linear")); + "linear"}); MultiDiEdge e{n1, pcg.add_node_port(), n0, pcg.add_node_port()}; pcg.add_edge(e); diff --git a/lib/pcg/include/pcg/machine_view.h b/lib/pcg/include/pcg/machine_view.h index afd4206eb1..7521cd209a 100644 --- a/lib/pcg/include/pcg/machine_view.h +++ b/lib/pcg/include/pcg/machine_view.h @@ -13,9 +13,6 @@ namespace FlexFlow { struct MachineView { - // MachineView() = delete; - // MachineView(device_id_t const &, StridedRectangle const &); - std::vector device_ids() const; device_id_t at(FFOrdered const &coord) const; 
diff --git a/lib/pcg/include/pcg/operator.h b/lib/pcg/include/pcg/operator.h index d09e25dcf3..3eb7fb2a43 100644 --- a/lib/pcg/include/pcg/operator.h +++ b/lib/pcg/include/pcg/operator.h @@ -2,31 +2,26 @@ #define _FLEXFLOW_PCG_INCLUDE_PCG_OPERATOR_H #include "op-attrs/operator_attrs.h" -#include "utils/optional.h" #include "utils/stack_string.h" #include "utils/visitable.h" +#include + namespace FlexFlow { -struct Operator : public use_visitable_cmp { +struct Operator { public: - Operator() = delete; - Operator(PCGOperatorAttrs const &attrs, optional const &name); - operator PCGOperatorAttrs() const; public: PCGOperatorAttrs attrs; - optional name; + req> name; }; -} // namespace FlexFlow - -VISITABLE_STRUCT(::FlexFlow::Operator, attrs, name); -MAKE_VISIT_HASHABLE(::FlexFlow::Operator); +FF_VISITABLE_STRUCT(Operator, attrs, name); -namespace FlexFlow { static_assert(is_well_behaved_value_type::value, ""); -} + +} // namespace FlexFlow #endif diff --git a/lib/pcg/include/pcg/parallel_tensor.h b/lib/pcg/include/pcg/parallel_tensor.h index eadc83d9fd..4594e849cf 100644 --- a/lib/pcg/include/pcg/parallel_tensor.h +++ b/lib/pcg/include/pcg/parallel_tensor.h @@ -47,6 +47,8 @@ struct ParallelTensor : public use_visitable_cmp { optional sync_type = nullopt, optional initializer = nullopt); + ParallelTensorShape get_shape() const; + public: ParallelTensorDims dims; DataType data_type; diff --git a/lib/pcg/include/pcg/strided_rectangle.h b/lib/pcg/include/pcg/strided_rectangle.h index 179fff080f..d123d7c6ac 100644 --- a/lib/pcg/include/pcg/strided_rectangle.h +++ b/lib/pcg/include/pcg/strided_rectangle.h @@ -41,9 +41,6 @@ FF_VISITABLE_STRUCT_NONSTANDARD_CONSTRUCTION(StridedRectangleSide, struct StridedRectangle { public: - // StridedRectangle() = delete; - // StridedRectangle(std::vector const &); - size_t at(FFOrdered const &) const; StridedRectangleSide at(ff_dim_t const &) const; size_t num_dims() const; @@ -62,10 +59,4 @@ MAKE_TYPEDEF_PRINTABLE(::FlexFlow::num_points_t, 
"num_points"); MAKE_TYPEDEF_HASHABLE(::FlexFlow::side_size_t); MAKE_TYPEDEF_PRINTABLE(::FlexFlow::side_size_t, "side_size"); -// VISITABLE_STRUCT(::FlexFlow::StridedRectangleSide, num_points, stride); -// MAKE_VISIT_HASHABLE(::FlexFlow::StridedRectangleSide); - -// VISITABLE_STRUCT(::FlexFlow::StridedRectangle, sides); -// MAKE_VISIT_HASHABLE(::FlexFlow::StridedRectangle); - #endif diff --git a/lib/pcg/src/machine_view.cc b/lib/pcg/src/machine_view.cc index f146482141..46f87833f0 100644 --- a/lib/pcg/src/machine_view.cc +++ b/lib/pcg/src/machine_view.cc @@ -3,10 +3,6 @@ namespace FlexFlow { -// MachineView::MachineView(device_id_t const &start, StridedRectangle const -// &rect) -// : start(start), rect(rect) {} - static StridedRectangle make_1d_rect(int start, int stop, int stride) { assert(stop > start); assert(stride > 0); diff --git a/lib/pcg/src/operator.cc b/lib/pcg/src/operator.cc index 5cba8584c9..81e7326a76 100644 --- a/lib/pcg/src/operator.cc +++ b/lib/pcg/src/operator.cc @@ -2,9 +2,9 @@ namespace FlexFlow { -Operator::Operator(PCGOperatorAttrs const &attrs, - optional const &name) - : attrs(attrs), name(name) {} +// Operator::Operator(PCGOperatorAttrs const &attrs, +// std::optional const &name) +// : attrs(attrs), name(name) {} Operator::operator PCGOperatorAttrs() const { return attrs; diff --git a/lib/pcg/src/parallel_tensor.cc b/lib/pcg/src/parallel_tensor.cc index a8d7b15ea9..8cc79d7293 100644 --- a/lib/pcg/src/parallel_tensor.cc +++ b/lib/pcg/src/parallel_tensor.cc @@ -10,4 +10,8 @@ ParallelTensor::ParallelTensor(ParallelTensorDims const &dims, : dims(dims), data_type(data_type), sync_type(sync_type), initializer(initializer), create_gradients(create_gradients) {} +ParallelTensorShape ParallelTensor::get_shape() const { + return ParallelTensorShape(dims, data_type); +} + } // namespace FlexFlow diff --git a/lib/pcg/src/strided_rectangle.cc b/lib/pcg/src/strided_rectangle.cc index 2792db65fe..7f612b743b 100644 --- a/lib/pcg/src/strided_rectangle.cc 
+++ b/lib/pcg/src/strided_rectangle.cc @@ -30,8 +30,4 @@ side_size_t StridedRectangleSide::get_size() const { NOT_IMPLEMENTED(); } -// StridedRectangle::StridedRectangle( -// std::vector const &sides) -// : sides(sides) {} - } // namespace FlexFlow diff --git a/lib/substitutions/src/substitution.cc b/lib/substitutions/src/substitution.cc index 8e99624acb..635083b780 100644 --- a/lib/substitutions/src/substitution.cc +++ b/lib/substitutions/src/substitution.cc @@ -119,27 +119,27 @@ Operator get_operator_attrs(SubParallelComputationGraph const &graph, get(assignments.at(OperatorAttributeKey::OP_TYPE)); switch (op_type) { case Op::BATCHMATMUL: - return Operator( + return Operator{ BatchMatmulAttrs{ get(assignments.at(OperatorAttributeKey::A_SEQ_LENGTH_DIM)), get(assignments.at(OperatorAttributeKey::B_SEQ_LENGTH_DIM))}, - nullopt); + std::nullopt}; case Op::BATCHNORM: - return Operator( + return Operator{ BatchNormAttrs{get(assignments.at(OperatorAttributeKey::RELU))}, - nullopt); + std::nullopt}; case Op::CAST: - return Operator(CastAttrs{get( + return Operator{CastAttrs{get( assignments.at(OperatorAttributeKey::DATA_TYPE))}, - nullopt); + std::nullopt}; case Op::CONCAT: - return Operator( + return Operator{ ConcatAttrs{ get(assignments.at(OperatorAttributeKey::AXIS)), get(assignments.at(OperatorAttributeKey::NUM_INPUTS))}, - nullopt); + std::nullopt}; case Op::CONV2D: - return Operator( + return Operator{ Conv2DAttrs{ get(assignments.at(OperatorAttributeKey::OUT_CHANNELS)), get(assignments.at(OperatorAttributeKey::KERNEL_H)), @@ -151,13 +151,13 @@ Operator get_operator_attrs(SubParallelComputationGraph const &graph, get(assignments.at(OperatorAttributeKey::GROUPS)), get(assignments.at(OperatorAttributeKey::ACTIVATION)), get(assignments.at(OperatorAttributeKey::USE_BIAS))}, - nullopt); + std::nullopt}; case Op::DROPOUT: - return Operator( + return Operator{ DropoutAttrs{get(assignments.at(OperatorAttributeKey::RATE)), get( assignments.at(OperatorAttributeKey::SEED))}, 
- nullopt); + std::nullopt}; case Op::EW_ADD: case Op::EW_DIV: case Op::EW_EQUAL: @@ -167,7 +167,7 @@ Operator get_operator_attrs(SubParallelComputationGraph const &graph, case Op::EW_MIN: case Op::EW_MUL: case Op::EW_SUB: - return Operator( + return Operator{ ElementBinaryAttrs{ op_type, get(assignments.at(OperatorAttributeKey::DATA_TYPE)), @@ -175,44 +175,44 @@ Operator get_operator_attrs(SubParallelComputationGraph const &graph, assignments.at(OperatorAttributeKey::SHOULD_BROADCAST_LHS)), get( assignments.at(OperatorAttributeKey::SHOULD_BROADCAST_RHS))}, - nullopt); + std::nullopt}; case Op::SCALAR_ADD: case Op::SCALAR_FLOOR_DIV: case Op::SCALAR_MULTIPLY: case Op::SCALAR_SUB: case Op::SCALAR_TRUE_DIV: - return Operator( + return Operator{ ElementScalarUnaryAttrs{ op_type, get(assignments.at(OperatorAttributeKey::SCALAR))}, - nullopt); + std::nullopt}; case Op::EMBEDDING: - return Operator( + return Operator{ EmbeddingAttrs{ get(assignments.at(OperatorAttributeKey::NUM_ENTRIES)), get(assignments.at(OperatorAttributeKey::OUT_CHANNELS)), get(assignments.at(OperatorAttributeKey::AGGR)), get(assignments.at(OperatorAttributeKey::OP_TYPE))}, - nullopt); + std::nullopt}; case Op::FLAT: - return Operator(FlatAttrs{}, nullopt); + return Operator{FlatAttrs{}, std::nullopt}; case Op::GATHER: - return Operator( + return Operator{ GatherAttrs{get(assignments.at(OperatorAttributeKey::DIM))}, - nullopt); + std::nullopt}; case Op::INPUT: - return Operator(InputAttrs{}, nullopt); + return Operator{InputAttrs{}, std::nullopt}; case Op::LAYERNORM: - return Operator( + return Operator{ LayerNormAttrs{ get>( assignments.at(OperatorAttributeKey::AXES)), get( assignments.at(OperatorAttributeKey::ELEMENTWISE_AFFINE)), get(assignments.at(OperatorAttributeKey::EPSILON))}, - nullopt); + std::nullopt}; case Op::LINEAR: - return Operator( + return Operator{ LinearAttrs{ get(assignments.at(OperatorAttributeKey::OUT_CHANNELS)), get(assignments.at(OperatorAttributeKey::USE_BIAS)), @@ -220,9 
+220,9 @@ Operator get_operator_attrs(SubParallelComputationGraph const &graph, get(assignments.at(OperatorAttributeKey::ACTIVATION)), get>( assignments.at(OperatorAttributeKey::REGULARIZER))}, - nullopt); + std::nullopt}; case Op::MULTIHEAD_ATTENTION: - return Operator( + return Operator{ MultiHeadAttentionAttrs{ get(assignments.at(OperatorAttributeKey::EMBED_DIM)), get(assignments.at(OperatorAttributeKey::NUM_HEADS)), @@ -232,11 +232,11 @@ Operator get_operator_attrs(SubParallelComputationGraph const &graph, get(assignments.at(OperatorAttributeKey::BIAS)), get(assignments.at(OperatorAttributeKey::ADD_BIAS_KV)), get(assignments.at(OperatorAttributeKey::ADD_ZERO_ATTN))}, - nullopt); + std::nullopt}; case Op::NOOP: - return Operator(NoopAttrs{}, nullopt); + return Operator{NoopAttrs{}, std::nullopt}; case Op::POOL2D: - return Operator( + return Operator{ Pool2DAttrs{ get(assignments.at(OperatorAttributeKey::KERNEL_H)), get(assignments.at(OperatorAttributeKey::KERNEL_W)), @@ -247,7 +247,7 @@ Operator get_operator_attrs(SubParallelComputationGraph const &graph, get(assignments.at(OperatorAttributeKey::POOL_TYPE)), get( assignments.at(OperatorAttributeKey::ACTIVATION))}, - nullopt); + std::nullopt}; case Op::REDUCE_ARGMAX: case Op::REDUCE_ARGMIN: case Op::REDUCE_MAX: @@ -255,65 +255,65 @@ Operator get_operator_attrs(SubParallelComputationGraph const &graph, case Op::REDUCE_MIN: case Op::REDUCE_PROD: case Op::REDUCE_SUM: - return Operator( + return Operator{ ReduceAttrs{ get>( assignments.at(OperatorAttributeKey::AXES)), op_type, get(assignments.at(OperatorAttributeKey::KEEP_DIMS))}, - nullopt); + std::nullopt}; case Op::REVERSE: - return Operator(ReverseAttrs{get( + return Operator{ReverseAttrs{get( assignments.at(OperatorAttributeKey::AXIS))}, - nullopt); + std::nullopt}; case Op::RESHAPE: - return Operator(ReshapeAttrs{get( + return Operator{ReshapeAttrs{get( assignments.at(OperatorAttributeKey::SHAPE))}, - nullopt); + std::nullopt}; case Op::SPLIT: - return 
Operator( + return Operator{ SplitAttrs{get>( assignments.at(OperatorAttributeKey::SPLITS)), get(assignments.at(OperatorAttributeKey::AXIS))}, - nullopt); + std::nullopt}; case Op::SOFTMAX: - return Operator(SoftmaxAttrs{get( + return Operator{SoftmaxAttrs{get( assignments.at(OperatorAttributeKey::DIM))}, - nullopt); + std::nullopt}; case Op::TOPK: - return Operator( + return Operator{ TopKAttrs{get(assignments.at(OperatorAttributeKey::K)), get(assignments.at(OperatorAttributeKey::SORTED))}, - nullopt); + std::nullopt}; case Op::TRANSPOSE: - return Operator( + return Operator{ TransposeAttrs{get>( assignments.at(OperatorAttributeKey::PERMUTATION))}, - nullopt); + std::nullopt}; case Op::COMBINE: - return Operator( + return Operator{ CombineAttrs{ get(assignments.at(OperatorAttributeKey::PARALLEL_DIM)), get(assignments.at(OperatorAttributeKey::PARALLEL_DEGREE))}, - nullopt); + std::nullopt}; case Op::REDUCTION: - return Operator( + return Operator{ ReductionAttrs{ get(assignments.at(OperatorAttributeKey::PARALLEL_DIM)), get(assignments.at(OperatorAttributeKey::PARALLEL_DEGREE))}, - nullopt); + std::nullopt}; case Op::REPARTITION: - return Operator( + return Operator{ RepartitionAttrs{ get(assignments.at(OperatorAttributeKey::PARALLEL_DIM)), get(assignments.at(OperatorAttributeKey::PARALLEL_DEGREE))}, - nullopt); + std::nullopt}; case Op::REPLICATE: - return Operator( + return Operator{ ReplicateAttrs{ get(assignments.at(OperatorAttributeKey::PARALLEL_DIM)), get(assignments.at(OperatorAttributeKey::PARALLEL_DEGREE))}, - nullopt); + std::nullopt}; default: mk_runtime_error("Unknown Operator"); } diff --git a/lib/utils/include/utils/graph/algorithms.h b/lib/utils/include/utils/graph/algorithms.h index cee5445190..12aa2dccb0 100644 --- a/lib/utils/include/utils/graph/algorithms.h +++ b/lib/utils/include/utils/graph/algorithms.h @@ -23,6 +23,7 @@ std::vector add_nodes(Graph &, int); std::vector add_nodes(UndirectedGraph &, int); std::vector add_nodes(DiGraph &, int); 
std::vector add_nodes(MultiDiGraph &, int); +std::vector add_nodes(OpenMultiDiGraph &g, int num_nodes); std::vector add_node_ports(MultiDiGraph &, int); diff --git a/lib/utils/include/utils/graph/labelled/open_views.h b/lib/utils/include/utils/graph/labelled/open_views.h index 8323fcf9dc..a24c2b940b 100644 --- a/lib/utils/include/utils/graph/labelled/open_views.h +++ b/lib/utils/include/utils/graph/labelled/open_views.h @@ -52,8 +52,6 @@ struct OutputLabelledOpenMultiDiSubgraphView std::unordered_set const &nodes; }; -// CHECK_NOT_ABSTRACT(OutputLabelledOpenMultiDiSubgraphView); - template struct ViewOutputLabelledAsOutputLabelledOpen : virtual IOutputLabelledOpenMultiDiGraphView { diff --git a/lib/utils/include/utils/graph/labelled/output_labelled_open.h b/lib/utils/include/utils/graph/labelled/output_labelled_open.h index 231d70db74..1c1b28c6d6 100644 --- a/lib/utils/include/utils/graph/labelled/output_labelled_open.h +++ b/lib/utils/include/utils/graph/labelled/output_labelled_open.h @@ -40,6 +40,16 @@ struct OutputLabelledOpenMultiDiGraphView return get_ptr().at(o); } + template + EdgeLabel const &at(variant const &e) const { + return visit([&](auto const &e) -> auto const & { return this->at(e); }, e); + } + + template + EdgeLabel &at(variant const &e) { + return visit([&](auto const &e) -> auto & { return this->at(e); }, e); + } + std::unordered_set query_nodes(NodeQuery const &q) const { return get_ptr().query_nodes(q); } @@ -85,52 +95,52 @@ struct OutputLabelledOpenMultiDiGraph Node add_node(NodeLabel const &l) { Node n = get_ptr().add_node(); - nl.get_mutable()->add_label(n, l); + this->node_labelling.get_mutable()->add_label(n, l); return n; } void add_node_unsafe(Node const &n, NodeLabel const &l) { - get_ptr().add_node_unsafe(n); - nl.get_mutable()->add_label(n, l); + this->get_ptr().add_node_unsafe(n); + this->node_labelling.get_mutable()->add_label(n, l); } NodePort add_node_port() { - return get_ptr().add_node_port(); + return 
this->get_ptr().add_node_port(); } NodeLabel &at(Node const &n) { - return nl.get_mutable()->get_label(n); + return this->node_labelling.get_mutable()->get_label(n); } NodeLabel const &at(Node const &n) const { - return nl->get_label(n); + return this->node_labelling->get_label(n); } void add_label(MultiDiOutput const &o, EdgeLabel const &l) { - ol.get_mutable()->add_label(o, l); + this->output_labelling.get_mutable()->add_label(o, l); }; void add_label(InputMultiDiEdge const &e, EdgeLabel const &l) { - il.get_mutable()->add_label(e, l); + this->input_labelling.get_mutable()->add_label(e, l); } void add_edge(OpenMultiDiEdge const &e) { - return get_ptr().add_edge(e); + return this->get_ptr().add_edge(e); } EdgeLabel &at(MultiDiOutput const &o) { - return ol.get_mutable()->get_label(o); + return this->output_labelling.get_mutable()->get_label(o); } EdgeLabel const &at(MultiDiOutput const &o) const { - return ol->get_label(o); + return this->output_labelling->get_label(o); } EdgeLabel &at(InputMultiDiEdge const &e) { - return il.get_mutable()->get_label(e); + return this->input_labelling.get_mutable()->get_label(e); } EdgeLabel const &at(InputMultiDiEdge const &e) const { - return il->get_label(e); + return this->input_labelling->get_label(e); } template @@ -170,7 +180,8 @@ struct OutputLabelledOpenMultiDiGraph cow_ptr_t nl, cow_ptr_t il, cow_ptr_t ol) - : GraphView(ptr), nl(nl), il(il), ol(ol) {} + : GraphView(ptr), node_labelling(nl), input_labelling(il), + output_labelling(ol) {} Interface &get_ptr() { return *std::dynamic_pointer_cast(GraphView::ptr.get_mutable()); @@ -180,9 +191,9 @@ struct OutputLabelledOpenMultiDiGraph return *std::dynamic_pointer_cast(GraphView::ptr.get()); } - cow_ptr_t nl; - cow_ptr_t il; - cow_ptr_t ol; + cow_ptr_t node_labelling; + cow_ptr_t input_labelling; + cow_ptr_t output_labelling; }; template diff --git a/lib/utils/include/utils/graph/labelled/unordered_label.h b/lib/utils/include/utils/graph/labelled/unordered_label.h index 
230e286ef8..94c4bffe11 100644 --- a/lib/utils/include/utils/graph/labelled/unordered_label.h +++ b/lib/utils/include/utils/graph/labelled/unordered_label.h @@ -19,8 +19,7 @@ struct UnorderedLabelling : virtual public ILabelling { } void add_label(Elem const &e, Label const &l) { - auto p = std::make_pair(e, l); - label_map.insert(p); + label_map.insert({e, l}); } UnorderedLabelling *clone() const { diff --git a/lib/utils/include/utils/variant.h b/lib/utils/include/utils/variant.h index b1a1dc1081..bb78719c9e 100644 --- a/lib/utils/include/utils/variant.h +++ b/lib/utils/include/utils/variant.h @@ -198,6 +198,14 @@ auto narrow(Container const &c) -> decltype(transform( return transform(c, [](VariantIn const &i) { return narrow(i); }); } +template ::value>> +auto narrow(Container const &c) { + return transform(c, [](VariantIn const &e) { return get(e); }); +} + template add_nodes(MultiDiGraph &g, int num_nodes) { return add_nodes_impl(g, num_nodes); } +std::vector add_nodes(OpenMultiDiGraph &g, int num_nodes) { + return add_nodes_impl(g, num_nodes); +} + std::vector add_node_ports(MultiDiGraph &g, int num_node_ports) { std::vector node_ports; for (int i = 0; i < num_node_ports; i++) { @@ -786,15 +790,13 @@ std::unordered_set get_closed_sinks(OpenMultiDiGraphView const &g) { std::unordered_set get_open_sources(OpenMultiDiGraphView const &g) { return filter(get_nodes(g), [&](Node const &n) { - return !g.query_edges(InputMultiDiEdgeQuery::all().with_dst_nodes({n})) - .empty(); + return !get_incoming_edges(g, n).empty(); }); } std::unordered_set get_open_sinks(OpenMultiDiGraphView const &g) { return filter(get_nodes(g), [&](Node const &n) { - return !g.query_edges(OutputMultiDiEdgeQuery::all().with_src_nodes({n})) - .empty(); + return !get_outgoing_edges(g, n).empty(); }); } diff --git a/lib/utils/src/graph/open_graphs.cc b/lib/utils/src/graph/open_graphs.cc index 5ab5858fd2..8355713506 100644 --- a/lib/utils/src/graph/open_graphs.cc +++ 
b/lib/utils/src/graph/open_graphs.cc @@ -52,7 +52,7 @@ std::unordered_set } NodePort OpenMultiDiGraph::add_node_port() { - return get_ptr().add_node_port(); + return this->get_ptr().add_node_port(); } IOpenMultiDiGraph &OpenMultiDiGraph::get_ptr() { diff --git a/lib/utils/src/graph/views.cc b/lib/utils/src/graph/views.cc index bf4f7351c0..dc823f7da4 100644 --- a/lib/utils/src/graph/views.cc +++ b/lib/utils/src/graph/views.cc @@ -469,8 +469,9 @@ std::unordered_set UpwardOpenMultiDiSubgraphView::UpwardOpenMultiDiSubgraphView( OpenMultiDiGraphView const &g, std::unordered_set const &nodes) - : g(g), nodes(nodes), - inputs(transform(get_cut_set(g, nodes), to_inputmultidiedge)) {} + : g(g), nodes(nodes) { + inputs = transform(get_cut_set(g, nodes), to_inputmultidiedge); +} UpwardOpenMultiDiSubgraphView *UpwardOpenMultiDiSubgraphView::clone() const { return new UpwardOpenMultiDiSubgraphView(g, nodes); @@ -494,8 +495,9 @@ std::unordered_set DownwardOpenMultiDiSubgraphView::DownwardOpenMultiDiSubgraphView( OpenMultiDiGraphView const &g, std::unordered_set const &nodes) - : g(g), nodes(nodes), - outputs(transform(get_cut_set(g, nodes), to_outputmultidiedge)) {} + : g(g), nodes(nodes) { + outputs = transform(get_cut_set(g, nodes), to_outputmultidiedge); +} std::unordered_set DownwardOpenMultiDiSubgraphView::query_edges( From 1f7e2b6aebe0e22cbb8c64db07667d359f980553 Mon Sep 17 00:00:00 2001 From: wmdi Date: Sun, 18 Feb 2024 17:27:51 -0500 Subject: [PATCH 09/37] (not compilable) visitable issue for OptimalCostState --- lib/compiler/include/compiler/machine_mapping.h | 10 +++++----- lib/compiler/src/machine_mapping.cc | 2 +- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/lib/compiler/include/compiler/machine_mapping.h b/lib/compiler/include/compiler/machine_mapping.h index aa9152dcd6..7299404d90 100644 --- a/lib/compiler/include/compiler/machine_mapping.h +++ b/lib/compiler/include/compiler/machine_mapping.h @@ -23,12 +23,12 @@ FF_VISITABLE_STRUCT(MachineMapping, 
machine_views); struct OptimalCostState { SerialParallelDecomposition subgraph; - req resource; - // req> given_machine_views; - // req> - // frontier_machine_views; + MachineSpecification resource; + std::unordered_map given_machine_views; + req> + frontier_machine_views; }; -FF_VISITABLE_STRUCT(OptimalCostState, subgraph, resource); +FF_VISITABLE_STRUCT(OptimalCostState, subgraph, resource, given_machine_views, frontier_machine_views); struct OptimalCostResult { static OptimalCostResult sequential_combine(OptimalCostResult const &s1, diff --git a/lib/compiler/src/machine_mapping.cc b/lib/compiler/src/machine_mapping.cc index 3cabd972bf..fc89ff1306 100644 --- a/lib/compiler/src/machine_mapping.cc +++ b/lib/compiler/src/machine_mapping.cc @@ -166,7 +166,7 @@ struct OptimalCost { template OptimalCostResult operator()(T const &t) const { OptimalCostState state{ - t, resource /*, given_machine_views, frontier_machine_views*/}; + t, resource , given_machine_views, frontier_machine_views}; optional cached_result = cached_subgraph_costs.load(state); From a9a64020d9d7295065d66914ce3ffe301ed75aeb Mon Sep 17 00:00:00 2001 From: wmdi Date: Tue, 27 Feb 2024 15:43:22 -0500 Subject: [PATCH 10/37] fix machine mapping hash & refactor dp algorithm --- .../include/compiler/machine_mapping.h | 9 +- lib/compiler/src/graph_utils.cc | 10 +- lib/compiler/src/graph_utils.h | 3 +- lib/compiler/src/machine_mapping.cc | 206 ++++++++++-------- lib/compiler/src/unity_algorithm.cc | 6 +- lib/utils/include/utils/containers.decl.h | 2 +- lib/utils/include/utils/containers.h | 2 +- lib/utils/include/utils/fmt.h | 8 + lib/utils/include/utils/hash-utils.h | 4 +- lib/utils/test/src/test_hash.cc | 18 ++ 10 files changed, 156 insertions(+), 112 deletions(-) create mode 100644 lib/utils/test/src/test_hash.cc diff --git a/lib/compiler/include/compiler/machine_mapping.h b/lib/compiler/include/compiler/machine_mapping.h index 7299404d90..185f2706ef 100644 --- 
a/lib/compiler/include/compiler/machine_mapping.h +++ b/lib/compiler/include/compiler/machine_mapping.h @@ -25,10 +25,13 @@ struct OptimalCostState { SerialParallelDecomposition subgraph; MachineSpecification resource; std::unordered_map given_machine_views; - req> - frontier_machine_views; + req> frontier_machine_views; }; -FF_VISITABLE_STRUCT(OptimalCostState, subgraph, resource, given_machine_views, frontier_machine_views); +FF_VISITABLE_STRUCT(OptimalCostState, + subgraph, + resource, + given_machine_views, + frontier_machine_views); struct OptimalCostResult { static OptimalCostResult sequential_combine(OptimalCostResult const &s1, diff --git a/lib/compiler/src/graph_utils.cc b/lib/compiler/src/graph_utils.cc index 12d5c99a34..069ae4a41f 100644 --- a/lib/compiler/src/graph_utils.cc +++ b/lib/compiler/src/graph_utils.cc @@ -11,13 +11,9 @@ ParallelComputationGraph cg_to_pcg(ComputationGraph const &g) { NOT_IMPLEMENTED(); } -SubParallelComputationGraph pcg_to_subpcg(ParallelComputationGraph const &pcg) { - return materialize_output_labelled_open_multidigraph_view< - AdjacencyOpenMultiDiGraph, - UnorderedLabelling, - UnorderedLabelling, - UnorderedLabelling>( - view_output_labelled_as_output_labelled_open(pcg.value())); +SubParallelComputationGraphView + pcg_to_subpcg(ParallelComputationGraph const &pcg) { + return view_output_labelled_as_output_labelled_open(pcg.value()); } std::vector diff --git a/lib/compiler/src/graph_utils.h b/lib/compiler/src/graph_utils.h index 88515ef950..711a253b61 100644 --- a/lib/compiler/src/graph_utils.h +++ b/lib/compiler/src/graph_utils.h @@ -9,7 +9,8 @@ SerialParallelDecomposition get_serial_parallel_decomposition(ParallelComputationGraph const &pcg); ParallelComputationGraph cg_to_pcg(ComputationGraph const &g); -SubParallelComputationGraph pcg_to_subpcg(ParallelComputationGraph const &g); +SubParallelComputationGraphView + pcg_to_subpcg(ParallelComputationGraph const &g); // NOTE(@wmdi): I think we should have the following 
interfaces in the graph // library eventually. diff --git a/lib/compiler/src/machine_mapping.cc b/lib/compiler/src/machine_mapping.cc index fc89ff1306..5ce988b951 100644 --- a/lib/compiler/src/machine_mapping.cc +++ b/lib/compiler/src/machine_mapping.cc @@ -45,12 +45,10 @@ bool OptimalCostRuntimeCmp::operator()(OptimalCostResult const &lhs, optional OptimalCostCache::load(OptimalCostState const &state) const { - auto it = cache.find(state); - // if (contains_key(cache, state)) { - // // auto result = cache.at(state); - // OptimalCostResult result = OptimalCostResult::infinity(); - // return make_optional(result); - // } + if (contains_key(cache, state)) { + OptimalCostResult result = cache.at(state); + return make_optional(result); + } return nullopt; } @@ -135,51 +133,74 @@ void minimize_runtime(OptimalCostResult &m1, OptimalCostResult const &m2) { minimize(m1, m2, OptimalCostRuntimeCmp{}); } -struct OptimalCost { - OptimalCost(SubParallelComputationGraphView const &g, - CostEstimator const &cost_estimator, - MachineSpecification const &resource, - std::unordered_map const &given_machine_views, - std::unordered_map const - &frontier_machine_views, - std::function( - Operator const &, MachineSpecification const &)> const - &allowed_machine_views, - OptimalCostCache &cached_subgraph_costs) - : g(g), cost_estimator(cost_estimator), resource(resource), - given_machine_views(restrict_keys(given_machine_views, get_nodes(g))), - frontier_machine_views( - restrict_keys(frontier_machine_views, get_edges(g))), +struct MachineMappingSearcher { + MachineMappingSearcher( + CostEstimator cost_estimator, + std::function( + Operator const &, MachineSpecification const &)> const + &allowed_machine_views, + OptimalCostCache &cached_subgraph_costs) + : cost_estimator(cost_estimator), allowed_machine_views(allowed_machine_views), cached_subgraph_costs(cached_subgraph_costs) {} - SubParallelComputationGraphView const &g; - CostEstimator const &cost_estimator; - MachineSpecification 
const &resource; - std::unordered_map given_machine_views; - std::unordered_map frontier_machine_views; - std::function( - Operator const &, MachineSpecification const &)> const - &allowed_machine_views; + CostEstimator cost_estimator; + std::function(Operator const &, + MachineSpecification const &)> + allowed_machine_views; OptimalCostCache &cached_subgraph_costs; - template - OptimalCostResult operator()(T const &t) const { - OptimalCostState state{ - t, resource , given_machine_views, frontier_machine_views}; - optional cached_result = - cached_subgraph_costs.load(state); + struct OptimalCostFunctor { + OptimalCostFunctor( + MachineMappingSearcher *searcher, + SubParallelComputationGraphView const &g, + MachineSpecification resource, + std::unordered_map given_machine_views, + std::unordered_map frontier_machine_views) + : searcher(searcher), g(g), resource(resource), + given_machine_views(given_machine_views), + frontier_machine_views(frontier_machine_views) {} + + MachineMappingSearcher *searcher; + SubParallelComputationGraphView const &g; + MachineSpecification resource; + std::unordered_map given_machine_views; + std::unordered_map frontier_machine_views; + + template + OptimalCostResult operator()(T const &t) { + OptimalCostState state{ + t, resource, given_machine_views, frontier_machine_views}; + optional cached_result = + searcher->cached_subgraph_costs.load(state); + + if (cached_result) { + return cached_result.value(); + } + OptimalCostResult result = searcher->optimal_cost( + t, g, resource, given_machine_views, frontier_machine_views); - if (cached_result) { - return cached_result.value(); + searcher->cached_subgraph_costs.save(state, result); + return result; } - OptimalCostResult result = this->optimal_cost(t); - - cached_subgraph_costs.save(state, result); - return result; + }; + + OptimalCostResult + optimal_cost(SubParallelComputationGraphView const &g, + MachineSpecification resource, + SerialParallelDecomposition const &sp_decomposition) { 
+ return visit(OptimalCostFunctor(this, g, resource, {}, {}), + sp_decomposition); } - OptimalCostResult optimal_cost(Serial const &serial) const { + OptimalCostResult optimal_cost( + Serial const &serial, + SubParallelComputationGraphView const &g, + MachineSpecification const &resource, + std::unordered_map const &given_machine_views, + std::unordered_map const + &frontier_machine_views) { + auto decomposed = decompose(serial); SerialParallelDecomposition pre_decompn = decomposed.first; SerialParallelDecomposition post_decompn = decomposed.second; @@ -210,28 +231,30 @@ struct OptimalCost { new_frontier_machine_views.emplace(split_edge, mv); minimize_runtime(optimal_result, OptimalCostResult::sequential_combine( - visit(OptimalCost(pre_graph, - cost_estimator, - resource, - given_machine_views, - new_frontier_machine_views, - allowed_machine_views, - cached_subgraph_costs), + visit(OptimalCostFunctor(this, + pre_graph, + resource, + given_machine_views, + new_frontier_machine_views), pre_decompn), - visit(OptimalCost(post_graph, - cost_estimator, - resource, - new_given_machine_views, - frontier_machine_views, - allowed_machine_views, - cached_subgraph_costs), + visit(OptimalCostFunctor(this, + post_graph, + resource, + new_given_machine_views, + frontier_machine_views), post_decompn))); } return optimal_result; } - OptimalCostResult optimal_cost(Parallel const ¶llel) const { + OptimalCostResult optimal_cost( + Parallel const ¶llel, + SubParallelComputationGraphView const &g, + MachineSpecification const &resource, + std::unordered_map const &given_machine_views, + std::unordered_map const + &frontier_machine_views) { auto decomposed = decompose(parallel); SerialParallelDecomposition decompn1 = decomposed.first; SerialParallelDecomposition decompn2 = decomposed.second; @@ -243,48 +266,46 @@ struct OptimalCost { g, graph_split.second); OptimalCostResult optimal_result = OptimalCostResult::sequential_combine( - visit(OptimalCost(g1, - cost_estimator, - resource, - 
given_machine_views, - frontier_machine_views, - allowed_machine_views, - cached_subgraph_costs), + visit(OptimalCostFunctor(this, + g1, + resource, + given_machine_views, + frontier_machine_views), decompn1), - visit(OptimalCost(g2, - cost_estimator, - resource, - given_machine_views, - frontier_machine_views, - allowed_machine_views, - cached_subgraph_costs), + visit(OptimalCostFunctor(this, + g2, + resource, + given_machine_views, + frontier_machine_views), decompn2)); for (auto const &resource_split : get_resource_split(resource)) { minimize_runtime(optimal_result, OptimalCostResult::parallel_combine( - visit(OptimalCost(g1, - cost_estimator, - resource_split.first, - given_machine_views, - frontier_machine_views, - allowed_machine_views, - cached_subgraph_costs), + visit(OptimalCostFunctor(this, + g1, + resource_split.first, + given_machine_views, + frontier_machine_views), decompn1), - visit(OptimalCost(g2, - cost_estimator, - resource_split.second, - given_machine_views, - frontier_machine_views, - allowed_machine_views, - cached_subgraph_costs), + visit(OptimalCostFunctor(this, + g2, + resource_split.second, + given_machine_views, + frontier_machine_views), decompn2))); } return optimal_result; } - OptimalCostResult optimal_cost(Node const &node) const { + OptimalCostResult optimal_cost( + Node const &node, + SubParallelComputationGraphView const &g, + MachineSpecification const &resource, + std::unordered_map const &given_machine_views, + std::unordered_map const + &frontier_machine_views) { if (contains_key(given_machine_views, node)) { assert(contains(allowed_machine_views(g.at(node), resource), source_machine_view.value())); @@ -315,15 +336,10 @@ OptimalCostResult OptimalCostCache &cached_subgraph_costs) { SerialParallelDecomposition sp_decomposition = get_serial_parallel_decomposition(g); - SubParallelComputationGraph subpcg = pcg_to_subpcg(g); - return visit(OptimalCost(subpcg, - cost_estimator, - resources, - std::unordered_map{}, - 
std::unordered_map{}, - allowed_machine_views, - cached_subgraph_costs), - sp_decomposition); + SubParallelComputationGraphView subpcg = pcg_to_subpcg(g); + MachineMappingSearcher searcher( + cost_estimator, allowed_machine_views, cached_subgraph_costs); + return searcher.optimal_cost(subpcg, resources, sp_decomposition); } } // namespace FlexFlow diff --git a/lib/compiler/src/unity_algorithm.cc b/lib/compiler/src/unity_algorithm.cc index 3363aecc2f..9fcde4dcca 100644 --- a/lib/compiler/src/unity_algorithm.cc +++ b/lib/compiler/src/unity_algorithm.cc @@ -13,7 +13,7 @@ bool StrategyRuntimeCmp::operator()(Strategy const &lhs, Strategy const &rhs) { * Gets all substitutions applicable to a PCG */ std::unordered_set - get_all_substitutions(ParallelComputationGraph const &pcg) { + get_all_applicable_substitutions(ParallelComputationGraph const &pcg) { NOT_IMPLEMENTED(); } @@ -37,7 +37,7 @@ Strategy ParallelComputationGraph pcg = cg_to_pcg(cg); - std::unordered_set subs = get_all_substitutions(pcg); + std::unordered_set subs = get_all_applicable_substitutions(pcg); OptimalCostCache cached_subgraph_costs; DeduplicatedPriorityQueue, StrategyRuntimeCmp> @@ -93,7 +93,7 @@ size_t hash::operator()(FlexFlow::Strategy const &s) const { size_t h = 0; hash_combine(h, s.pcg); - // hash_combine(h, s.machine_mapping); + hash_combine(h, s.machine_mapping); hash_combine(h, s.runtime); return h; diff --git a/lib/utils/include/utils/containers.decl.h b/lib/utils/include/utils/containers.decl.h index 8ad65a4488..430da61ff9 100644 --- a/lib/utils/include/utils/containers.decl.h +++ b/lib/utils/include/utils/containers.decl.h @@ -109,7 +109,7 @@ template std::vector values(C const &c); template -std::unordered_set> +std::unordered_set> items(C const &c); template diff --git a/lib/utils/include/utils/containers.h b/lib/utils/include/utils/containers.h index 679586ba69..1d0151c38a 100644 --- a/lib/utils/include/utils/containers.h +++ b/lib/utils/include/utils/containers.h @@ -228,7 +228,7 @@ 
std::vector values(C const &c) { } template -std::unordered_set> +std::unordered_set> items(C const &c) { return {c.begin(), c.end()}; } diff --git a/lib/utils/include/utils/fmt.h b/lib/utils/include/utils/fmt.h index ddf5b00355..c44cb88b61 100644 --- a/lib/utils/include/utils/fmt.h +++ b/lib/utils/include/utils/fmt.h @@ -6,6 +6,8 @@ #include "utils/test_types.h" #include "utils/type_traits_core.h" +#include + namespace FlexFlow { template @@ -26,6 +28,12 @@ struct already_has_ostream_operator : std::true_type {}; template <> struct already_has_ostream_operator : std::true_type {}; +template <> +struct already_has_ostream_operator> : std::true_type {}; + +template <> +struct already_has_ostream_operator : std::true_type {}; + // This will create an error /* template diff --git a/lib/utils/include/utils/hash-utils.h b/lib/utils/include/utils/hash-utils.h index 923c8df840..d56ff34644 100644 --- a/lib/utils/include/utils/hash-utils.h +++ b/lib/utils/include/utils/hash-utils.h @@ -4,6 +4,8 @@ #include "containers.h" #include "hash-utils-core.h" +using namespace FlexFlow; + namespace std { template struct hash> { @@ -18,7 +20,7 @@ struct hash> { template struct hash> { size_t operator()(std::unordered_map const &m) const { - return get_std_hash(items(m)); + return get_std_hash(::FlexFlow::items(m)); } }; diff --git a/lib/utils/test/src/test_hash.cc b/lib/utils/test/src/test_hash.cc new file mode 100644 index 0000000000..f0d907b741 --- /dev/null +++ b/lib/utils/test/src/test_hash.cc @@ -0,0 +1,18 @@ +#include "test/utils/doctest.h" +#include "utils/hash-utils.h" + +using namespace FlexFlow; + +TEST_CASE("hash:unordered_map") { + std::unordered_map map1{{1, 2}}; + std::unordered_map map2{{1, 2}, {3, 4}}; + + size_t hash1 = get_std_hash(map1); + size_t hash2 = get_std_hash(map2); + + CHECK(hash1 != hash2); + + map1.insert({1, 2}); + hash1 = get_std_hash(map1); + CHECK(hash1 == hash2); +} From d8bbcb883103c9ca046ff700be9e1655f80e4892 Mon Sep 17 00:00:00 2001 From: wmdi 
Date: Tue, 27 Feb 2024 16:34:09 -0500 Subject: [PATCH 11/37] minor fix --- lib/compiler/include/compiler/unity_algorithm.h | 13 ++----------- lib/compiler/src/unity_algorithm.cc | 14 -------------- lib/compiler/test/src/test_labelled_open_graph.cc | 2 -- lib/compiler/test/src/test_optimal_cost.cc | 1 + lib/pcg/include/pcg/operator.h | 2 +- lib/pcg/src/operator.cc | 4 ---- .../include/utils/graph/labelled/node_labelled.h | 9 ++++++--- .../utils/graph/labelled/node_labelled_open.h | 9 ++++++--- .../include/utils/graph/labelled/output_labelled.h | 9 ++++++--- .../utils/graph/labelled/output_labelled_open.h | 9 ++++++--- .../utils/graph/labelled/standard_labelled.h | 9 ++++++--- lib/utils/src/graph/digraph.cc | 7 ++++--- lib/utils/src/graph/multidigraph.cc | 7 ++++--- lib/utils/src/graph/node.cc | 4 ++-- lib/utils/src/graph/open_graphs.cc | 14 +++++++------- lib/utils/src/graph/undirected.cc | 6 +++--- 16 files changed, 54 insertions(+), 65 deletions(-) diff --git a/lib/compiler/include/compiler/unity_algorithm.h b/lib/compiler/include/compiler/unity_algorithm.h index a87bddcc3a..7d7a7a74dc 100644 --- a/lib/compiler/include/compiler/unity_algorithm.h +++ b/lib/compiler/include/compiler/unity_algorithm.h @@ -14,6 +14,8 @@ struct Strategy { req runtime; }; +FF_VISITABLE_STRUCT(Strategy, pcg, machine_mapping, runtime); + struct StrategyRuntimeCmp { bool operator()(Strategy const &, Strategy const &); }; @@ -36,15 +38,4 @@ Strategy } // namespace FlexFlow -VISITABLE_STRUCT(FlexFlow::Strategy, pcg, machine_mapping, runtime); - -namespace std { - -template <> -struct hash { - size_t operator()(FlexFlow::Strategy const &) const; -}; - -} // namespace std - #endif diff --git a/lib/compiler/src/unity_algorithm.cc b/lib/compiler/src/unity_algorithm.cc index 9fcde4dcca..c9666851db 100644 --- a/lib/compiler/src/unity_algorithm.cc +++ b/lib/compiler/src/unity_algorithm.cc @@ -86,17 +86,3 @@ Strategy } } // namespace FlexFlow - -namespace std { - -size_t 
hash::operator()(FlexFlow::Strategy const &s) const { - size_t h = 0; - - hash_combine(h, s.pcg); - hash_combine(h, s.machine_mapping); - hash_combine(h, s.runtime); - - return h; -} - -} // namespace std diff --git a/lib/compiler/test/src/test_labelled_open_graph.cc b/lib/compiler/test/src/test_labelled_open_graph.cc index 1cae9a0cd1..a360d86ee7 100644 --- a/lib/compiler/test/src/test_labelled_open_graph.cc +++ b/lib/compiler/test/src/test_labelled_open_graph.cc @@ -7,8 +7,6 @@ using namespace FlexFlow; TEST_CASE("get_subgraph(OpenMultiDiGraphView)") { auto g = OpenMultiDiGraph::create(); - int t0 = 100000; - Node n0 = g.add_node(); Node n1 = g.add_node(); Node n2 = g.add_node(); diff --git a/lib/compiler/test/src/test_optimal_cost.cc b/lib/compiler/test/src/test_optimal_cost.cc index a6cd88a006..c5f74ff392 100644 --- a/lib/compiler/test/src/test_optimal_cost.cc +++ b/lib/compiler/test/src/test_optimal_cost.cc @@ -4,6 +4,7 @@ using namespace FlexFlow; +// Rapidcheck infrastructures for graphs does not work for now /* Tests whether optimal_cost can give a valid result given random PCG, trivial allowed machine views, trivial cost estimator and random machine specification. 
diff --git a/lib/pcg/include/pcg/operator.h b/lib/pcg/include/pcg/operator.h index 3eb7fb2a43..bb9a4cf5e4 100644 --- a/lib/pcg/include/pcg/operator.h +++ b/lib/pcg/include/pcg/operator.h @@ -20,7 +20,7 @@ struct Operator { FF_VISITABLE_STRUCT(Operator, attrs, name); -static_assert(is_well_behaved_value_type::value, ""); +static_assert(is_well_behaved_value_type::value); } // namespace FlexFlow diff --git a/lib/pcg/src/operator.cc b/lib/pcg/src/operator.cc index 81e7326a76..9d36ae1b25 100644 --- a/lib/pcg/src/operator.cc +++ b/lib/pcg/src/operator.cc @@ -2,10 +2,6 @@ namespace FlexFlow { -// Operator::Operator(PCGOperatorAttrs const &attrs, -// std::optional const &name) -// : attrs(attrs), name(name) {} - Operator::operator PCGOperatorAttrs() const { return attrs; } diff --git a/lib/utils/include/utils/graph/labelled/node_labelled.h b/lib/utils/include/utils/graph/labelled/node_labelled.h index ded049f224..1ecd87226c 100644 --- a/lib/utils/include/utils/graph/labelled/node_labelled.h +++ b/lib/utils/include/utils/graph/labelled/node_labelled.h @@ -54,7 +54,8 @@ struct NodeLabelledMultiDiGraphView : virtual public MultiDiGraphView { private: Interface const &get_ptr() const { - return *std::dynamic_pointer_cast(GraphView::ptr.get()); + return *std::reinterpret_pointer_cast( + GraphView::ptr.get()); } }; CHECK_WELL_BEHAVED_VALUE_TYPE_NO_EQ(NodeLabelledMultiDiGraphView); @@ -116,11 +117,13 @@ struct NodeLabelledMultiDiGraph : NodeLabelledMultiDiGraphView(ptr), nl(nl) {} Interface &get_ptr() { - return *std::dynamic_pointer_cast(GraphView::ptr.get_mutable()); + return *std::reinterpret_pointer_cast( + GraphView::ptr.get_mutable()); } Interface const &get_ptr() const { - return *std::dynamic_pointer_cast(GraphView::ptr.get()); + return *std::reinterpret_pointer_cast( + GraphView::ptr.get()); } cow_ptr_t nl; diff --git a/lib/utils/include/utils/graph/labelled/node_labelled_open.h b/lib/utils/include/utils/graph/labelled/node_labelled_open.h index fab6695070..2162ee0384 
100644 --- a/lib/utils/include/utils/graph/labelled/node_labelled_open.h +++ b/lib/utils/include/utils/graph/labelled/node_labelled_open.h @@ -55,7 +55,8 @@ struct NodeLabelledOpenMultiDiGraphView private: Interface const &get_ptr() const { - return *std::dynamic_pointer_cast(GraphView::ptr.get()); + return *std::reinterpret_pointer_cast( + GraphView::ptr.get()); } }; @@ -119,11 +120,13 @@ struct NodeLabelledOpenMultiDiGraph : GraphView(ptr), nl(nl) {} Interface &get_ptr() { - return *std::dynamic_pointer_cast(GraphView::ptr.get_mutable()); + return *std::reinterpret_pointer_cast( + GraphView::ptr.get_mutable()); } Interface const &get_ptr() const { - return *std::dynamic_pointer_cast(GraphView::ptr.get()); + return *std::reinterpret_pointer_cast( + GraphView::ptr.get()); } cow_ptr_t nl; diff --git a/lib/utils/include/utils/graph/labelled/output_labelled.h b/lib/utils/include/utils/graph/labelled/output_labelled.h index f3cf14022b..882fca8df0 100644 --- a/lib/utils/include/utils/graph/labelled/output_labelled.h +++ b/lib/utils/include/utils/graph/labelled/output_labelled.h @@ -61,7 +61,8 @@ struct OutputLabelledMultiDiGraphView private: Interface const &get_ptr() const { - return *std::dynamic_pointer_cast(GraphView::ptr.get()); + return *std::reinterpret_pointer_cast( + GraphView::ptr.get()); } }; @@ -144,11 +145,13 @@ struct OutputLabelledMultiDiGraph private: Interface &get_ptr() { - return *std::dynamic_pointer_cast(GraphView::ptr.get_mutable()); + return *std::reinterpret_pointer_cast( + GraphView::ptr.get_mutable()); } Interface const &get_ptr() const { - return *std::dynamic_pointer_cast(GraphView::ptr.get()); + return *std::reinterpret_pointer_cast( + GraphView::ptr.get()); } cow_ptr_t nl; diff --git a/lib/utils/include/utils/graph/labelled/output_labelled_open.h b/lib/utils/include/utils/graph/labelled/output_labelled_open.h index 1c1b28c6d6..23dd9c190c 100644 --- a/lib/utils/include/utils/graph/labelled/output_labelled_open.h +++ 
b/lib/utils/include/utils/graph/labelled/output_labelled_open.h @@ -73,7 +73,8 @@ struct OutputLabelledOpenMultiDiGraphView private: Interface const &get_ptr() const { - return *std::dynamic_pointer_cast(GraphView::ptr.get()); + return *std::reinterpret_pointer_cast( + GraphView::ptr.get()); } }; @@ -184,11 +185,13 @@ struct OutputLabelledOpenMultiDiGraph output_labelling(ol) {} Interface &get_ptr() { - return *std::dynamic_pointer_cast(GraphView::ptr.get_mutable()); + return *std::reinterpret_pointer_cast( + GraphView::ptr.get_mutable()); } Interface const &get_ptr() const { - return *std::dynamic_pointer_cast(GraphView::ptr.get()); + return *std::reinterpret_pointer_cast( + GraphView::ptr.get()); } cow_ptr_t node_labelling; diff --git a/lib/utils/include/utils/graph/labelled/standard_labelled.h b/lib/utils/include/utils/graph/labelled/standard_labelled.h index 8af69e18fc..3c69d62ae9 100644 --- a/lib/utils/include/utils/graph/labelled/standard_labelled.h +++ b/lib/utils/include/utils/graph/labelled/standard_labelled.h @@ -60,7 +60,8 @@ struct LabelledMultiDiGraphView : NodeLabelledMultiDiGraphView(ptr) {} Interface const &get_ptr() const { - return *std::dynamic_pointer_cast(GraphView::ptr.get()); + return *std::reinterpret_pointer_cast( + GraphView::ptr.get()); } }; CHECK_WELL_BEHAVED_VALUE_TYPE_NO_EQ(LabelledMultiDiGraphView); @@ -131,11 +132,13 @@ struct LabelledMultiDiGraph : LabelledMultiDiGraphView(ptr), nl(nl), el(el) {} Interface &get_ptr() { - return *std::dynamic_pointer_cast(GraphView::ptr.get_mutable()); + return *std::reinterpret_pointer_cast( + GraphView::ptr.get_mutable()); } Interface const &get_ptr() const { - return *std::dynamic_pointer_cast(GraphView::ptr.get()); + return *std::reinterpret_pointer_cast( + GraphView::ptr.get()); } cow_ptr_t nl; diff --git a/lib/utils/src/graph/digraph.cc b/lib/utils/src/graph/digraph.cc index bdfe5ff599..dda9eef5e0 100644 --- a/lib/utils/src/graph/digraph.cc +++ b/lib/utils/src/graph/digraph.cc @@ -14,7 +14,8 @@ 
std::unordered_set } IDiGraphView const &DiGraphView::get_ptr() const { - return *std::dynamic_pointer_cast(GraphView::ptr.get()); + return *std::reinterpret_pointer_cast( + GraphView::ptr.get()); } Node DiGraph::add_node() { @@ -47,11 +48,11 @@ std::unordered_set } IDiGraph &DiGraph::get_ptr() { - return *std::dynamic_pointer_cast(GraphView::ptr.get_mutable()); + return *std::reinterpret_pointer_cast(GraphView::ptr.get_mutable()); } IDiGraph const &DiGraph::get_ptr() const { - return *std::dynamic_pointer_cast( + return *std::reinterpret_pointer_cast( GraphView::ptr.get_mutable()); } } // namespace FlexFlow diff --git a/lib/utils/src/graph/multidigraph.cc b/lib/utils/src/graph/multidigraph.cc index 771e01e573..99a7ea86fa 100644 --- a/lib/utils/src/graph/multidigraph.cc +++ b/lib/utils/src/graph/multidigraph.cc @@ -24,7 +24,7 @@ std::unordered_set } IMultiDiGraphView const &MultiDiGraphView::get_ptr() const { - return *std::dynamic_pointer_cast( + return *std::reinterpret_pointer_cast( GraphView::ptr.get()); } @@ -66,11 +66,12 @@ std::unordered_set MultiDiGraph::query_nodes(NodeQuery const &q) const { } IMultiDiGraph const &MultiDiGraph::get_ptr() const { - return *std::dynamic_pointer_cast(GraphView::ptr.get()); + return *std::reinterpret_pointer_cast( + GraphView::ptr.get()); } IMultiDiGraph &MultiDiGraph::get_ptr() { - return *std::dynamic_pointer_cast( + return *std::reinterpret_pointer_cast( GraphView::ptr.get_mutable()); } diff --git a/lib/utils/src/graph/node.cc b/lib/utils/src/graph/node.cc index 72caa3136e..9854afffbf 100644 --- a/lib/utils/src/graph/node.cc +++ b/lib/utils/src/graph/node.cc @@ -53,11 +53,11 @@ std::unordered_set Graph::query_nodes(NodeQuery const &q) const { } IGraph const &Graph::get_ptr() const { - return *std::dynamic_pointer_cast(GraphView::ptr.get()); + return *std::reinterpret_pointer_cast(GraphView::ptr.get()); } IGraph &Graph::get_ptr() { - return *std::dynamic_pointer_cast(GraphView::ptr.get_mutable()); + return 
*std::reinterpret_pointer_cast(GraphView::ptr.get_mutable()); } } // namespace FlexFlow diff --git a/lib/utils/src/graph/open_graphs.cc b/lib/utils/src/graph/open_graphs.cc index 8355713506..c32ff6ded5 100644 --- a/lib/utils/src/graph/open_graphs.cc +++ b/lib/utils/src/graph/open_graphs.cc @@ -22,7 +22,7 @@ std::unordered_set } IOpenMultiDiGraphView const &OpenMultiDiGraphView::get_ptr() const { - return *std::dynamic_pointer_cast( + return *std::reinterpret_pointer_cast( GraphView::ptr.get()); } @@ -56,7 +56,7 @@ NodePort OpenMultiDiGraph::add_node_port() { } IOpenMultiDiGraph &OpenMultiDiGraph::get_ptr() { - return *std::dynamic_pointer_cast( + return *std::reinterpret_pointer_cast( GraphView::ptr.get_mutable()); } @@ -77,7 +77,7 @@ std::unordered_set } IUpwardOpenMultiDiGraphView const &UpwardOpenMultiDiGraphView::get_ptr() const { - return *std::dynamic_pointer_cast( + return *std::reinterpret_pointer_cast( GraphView::ptr.get()); } @@ -107,12 +107,12 @@ std::unordered_set UpwardOpenMultiDiGraph::query_edges( } IUpwardOpenMultiDiGraph const &UpwardOpenMultiDiGraph::get_ptr() const { - return *std::dynamic_pointer_cast( + return *std::reinterpret_pointer_cast( GraphView::ptr.get()); } IUpwardOpenMultiDiGraph &UpwardOpenMultiDiGraph::get_ptr() { - return *std::dynamic_pointer_cast( + return *std::reinterpret_pointer_cast( GraphView::ptr.get_mutable()); } @@ -129,7 +129,7 @@ std::unordered_set IDownwardOpenMultiDiGraphView const & DownwardOpenMultiDiGraphView::get_ptr() const { - return *std::dynamic_pointer_cast( + return *std::reinterpret_pointer_cast( GraphView::ptr.get()); } @@ -165,7 +165,7 @@ std::unordered_set } IDownwardOpenMultiDiGraph &DownwardOpenMultiDiGraph::get_ptr() { - return *std::dynamic_pointer_cast( + return *std::reinterpret_pointer_cast( GraphView::ptr.get_mutable()); } diff --git a/lib/utils/src/graph/undirected.cc b/lib/utils/src/graph/undirected.cc index b1e8be7f14..ce42cfe22c 100644 --- a/lib/utils/src/graph/undirected.cc +++ 
b/lib/utils/src/graph/undirected.cc @@ -26,12 +26,12 @@ void UndirectedGraph::remove_edge(UndirectedEdge const &e) { } IUndirectedGraph const &UndirectedGraph::get_ptr() const { - return *std::dynamic_pointer_cast( + return *std::reinterpret_pointer_cast( GraphView::ptr.get()); } IUndirectedGraph &UndirectedGraph::get_ptr() { - return *std::dynamic_pointer_cast( + return *std::reinterpret_pointer_cast( GraphView::ptr.get_mutable()); } @@ -56,7 +56,7 @@ std::unordered_set } IUndirectedGraphView const &UndirectedGraphView::get_ptr() const { - return *std::dynamic_pointer_cast( + return *std::reinterpret_pointer_cast( GraphView::ptr.get()); } From 09d3152ef80177118cd1ae51111c723f4f7482c7 Mon Sep 17 00:00:00 2001 From: wmdi Date: Wed, 28 Feb 2024 15:10:01 -0500 Subject: [PATCH 12/37] fix variant issue --- lib/utils/include/utils/containers.decl.h | 2 ++ lib/utils/include/utils/containers.h | 10 +++++++++ lib/utils/include/utils/variant.h | 10 ++++----- lib/utils/src/graph/algorithms.cc | 25 ++++++++--------------- lib/utils/src/graph/serialparallel.cc | 4 ++-- lib/utils/src/graph/views.cc | 11 +++++----- 6 files changed, 33 insertions(+), 29 deletions(-) diff --git a/lib/utils/include/utils/containers.decl.h b/lib/utils/include/utils/containers.decl.h index 430da61ff9..fd35afe3fc 100644 --- a/lib/utils/include/utils/containers.decl.h +++ b/lib/utils/include/utils/containers.decl.h @@ -293,6 +293,8 @@ T reversed(T const &t); template std::vector value_all(std::vector> const &v); +template +std::unordered_set value_all(std::unordered_set> const &v); template std::vector subvec(std::vector const &v, diff --git a/lib/utils/include/utils/containers.h b/lib/utils/include/utils/containers.h index 1d0151c38a..99c29564fb 100644 --- a/lib/utils/include/utils/containers.h +++ b/lib/utils/include/utils/containers.h @@ -675,6 +675,16 @@ std::vector value_all(std::vector> const &v) { }); } +template +std::unordered_set value_all(std::unordered_set> const &v) { + return transform(v, 
[](optional const &element) { + return unwrap(element, [] { + throw mk_runtime_error( + "Encountered element without value in call to value_all"); + }); + }); +} + template std::vector subvec(std::vector const &v, optional const &maybe_start, diff --git a/lib/utils/include/utils/variant.h b/lib/utils/include/utils/variant.h index bb78719c9e..132b7e66f4 100644 --- a/lib/utils/include/utils/variant.h +++ b/lib/utils/include/utils/variant.h @@ -58,7 +58,7 @@ struct elements_satisfy> : elements_satisfy_impl {}; template -struct is_in_variant; +struct is_in_variant : std::false_type {}; template struct is_in_variant> : std::true_type {}; template @@ -182,7 +182,7 @@ auto widen(Container const &c) -> decltype(transform( template < typename VariantOut, typename VariantIn, - typename = std::enable_if::value>> + typename = std::enable_if_t::value>> optional narrow(VariantIn const &v) { return visit(VariantNarrowFunctor{}, v); } @@ -191,7 +191,7 @@ template < typename VariantOut, typename Container, typename VariantIn = typename Container::value_type, - typename = std::enable_if::value>> + typename = std::enable_if_t::value>> auto narrow(Container const &c) -> decltype(transform( c, std::declval(VariantIn const &)>>())) { @@ -201,7 +201,7 @@ auto narrow(Container const &c) -> decltype(transform( template ::value>> + typename = std::enable_if_t::value>> auto narrow(Container const &c) { return transform(c, [](VariantIn const &e) { return get(e); }); } @@ -210,7 +210,7 @@ template , VariantIn>::value>> optional> narrow(VariantIn const &v) { return visit(VariantNarrowFunctor>{}, v); diff --git a/lib/utils/src/graph/algorithms.cc b/lib/utils/src/graph/algorithms.cc index 3b9877f71b..1667ddfce8 100644 --- a/lib/utils/src/graph/algorithms.cc +++ b/lib/utils/src/graph/algorithms.cc @@ -9,6 +9,7 @@ #include "utils/graph/traversal.h" #include "utils/graph/undirected.h" #include "utils/graph/views.h" +#include "utils/variant.h" #include #include #include @@ -256,7 +257,7 @@ 
std::unordered_set get_node_edges(UndirectedGraphView const &g, std::unordered_set get_outputs(MultiDiGraphView const &g) { return transform(get_edges(g), [&](MultiDiEdge const &e) -> MultiDiOutput { - return MultiDiOutput(e); + return static_cast(e); }); } @@ -333,37 +334,27 @@ std::unordered_map> std::unordered_set get_outgoing_edges(OpenMultiDiGraphView const &g, Node const &n) { - return transform(g.query_edges(OpenMultiDiEdgeQuery( + return value_all(narrow(g.query_edges(OpenMultiDiEdgeQuery( InputMultiDiEdgeQuery::none(), MultiDiEdgeQuery::all().with_src_nodes({n}), - OutputMultiDiEdgeQuery::all().with_src_nodes({n}))), - [](OpenMultiDiEdge const &e) { - return narrow(e).value(); - }); + OutputMultiDiEdgeQuery::all().with_src_nodes({n}))))); } std::unordered_set get_incoming_edges(OpenMultiDiGraphView const &g, Node const &n) { - return transform(g.query_edges(OpenMultiDiEdgeQuery( + return value_all(narrow(g.query_edges(OpenMultiDiEdgeQuery( InputMultiDiEdgeQuery::all().with_dst_nodes({n}), MultiDiEdgeQuery::all().with_dst_nodes({n}), - OutputMultiDiEdgeQuery::none())), - [](OpenMultiDiEdge const &e) { - return narrow(e).value(); - }); + OutputMultiDiEdgeQuery::none())))); } std::unordered_set get_open_outputs(OpenMultiDiGraphView const &g) { - return transform( - g.query_edges(OutputMultiDiEdgeQuery::all()), - [](OpenMultiDiEdge const &e) { return get(e); }); + return narrow(g.query_edges(OutputMultiDiEdgeQuery::all())); } std::unordered_set get_open_inputs(OpenMultiDiGraphView const &g) { - return transform( - g.query_edges(InputMultiDiEdgeQuery::all()), - [](OpenMultiDiEdge const &e) { return get(e); }); + return narrow(g.query_edges(InputMultiDiEdgeQuery::all())); } std::unordered_map> diff --git a/lib/utils/src/graph/serialparallel.cc b/lib/utils/src/graph/serialparallel.cc index 41ecf3c436..8b179d31de 100644 --- a/lib/utils/src/graph/serialparallel.cc +++ b/lib/utils/src/graph/serialparallel.cc @@ -178,11 +178,11 @@ struct ToFinalAST { variant 
operator()(SplitASTNode const &node) { if (node.type == SplitType::SERIAL) { return Serial{transform(node.children, [](SplitAST const &s) { - return narrow(to_final_ast(s)).value(); + return narrow>(to_final_ast(s)).value(); })}; } else { return Parallel{transform(node.children, [](SplitAST const &s) { - return narrow(to_final_ast(s)).value(); + return narrow>(to_final_ast(s)).value(); })}; } } diff --git a/lib/utils/src/graph/views.cc b/lib/utils/src/graph/views.cc index dc823f7da4..a1308cffbb 100644 --- a/lib/utils/src/graph/views.cc +++ b/lib/utils/src/graph/views.cc @@ -445,9 +445,10 @@ std::unordered_set OpenMultiDiSubgraphView::OpenMultiDiSubgraphView( OpenMultiDiGraphView const &g, std::unordered_set const &nodes) - : g(g), nodes(nodes), - inputs(transform(get_cut_set(g, nodes), to_inputmultidiedge)), - outputs(transform(get_cut_set(g, nodes), to_outputmultidiedge)) {} + : g(g), nodes(nodes) { + this->inputs = transform(get_cut_set(g, nodes), to_inputmultidiedge); + this->outputs = transform(get_cut_set(g, nodes), to_outputmultidiedge); + } std::unordered_set OpenMultiDiSubgraphView::query_edges(OpenMultiDiEdgeQuery const &q) const { @@ -470,7 +471,7 @@ std::unordered_set UpwardOpenMultiDiSubgraphView::UpwardOpenMultiDiSubgraphView( OpenMultiDiGraphView const &g, std::unordered_set const &nodes) : g(g), nodes(nodes) { - inputs = transform(get_cut_set(g, nodes), to_inputmultidiedge); + this->inputs = transform(get_cut_set(g, nodes), to_inputmultidiedge); } UpwardOpenMultiDiSubgraphView *UpwardOpenMultiDiSubgraphView::clone() const { @@ -496,7 +497,7 @@ std::unordered_set DownwardOpenMultiDiSubgraphView::DownwardOpenMultiDiSubgraphView( OpenMultiDiGraphView const &g, std::unordered_set const &nodes) : g(g), nodes(nodes) { - outputs = transform(get_cut_set(g, nodes), to_outputmultidiedge); + this->outputs = transform(get_cut_set(g, nodes), to_outputmultidiedge); } std::unordered_set From a150d3a90536ba6583276d52787b35fc15ba7f1d Mon Sep 17 00:00:00 2001 From: 
wmdi Date: Wed, 28 Feb 2024 15:15:00 -0500 Subject: [PATCH 13/37] fmt --- lib/utils/src/graph/algorithms.cc | 20 +++++++++++--------- lib/utils/src/graph/views.cc | 6 +++--- 2 files changed, 14 insertions(+), 12 deletions(-) diff --git a/lib/utils/src/graph/algorithms.cc b/lib/utils/src/graph/algorithms.cc index 1667ddfce8..777d3d55d2 100644 --- a/lib/utils/src/graph/algorithms.cc +++ b/lib/utils/src/graph/algorithms.cc @@ -334,23 +334,25 @@ std::unordered_map> std::unordered_set get_outgoing_edges(OpenMultiDiGraphView const &g, Node const &n) { - return value_all(narrow(g.query_edges(OpenMultiDiEdgeQuery( - InputMultiDiEdgeQuery::none(), - MultiDiEdgeQuery::all().with_src_nodes({n}), - OutputMultiDiEdgeQuery::all().with_src_nodes({n}))))); + return value_all( + narrow(g.query_edges(OpenMultiDiEdgeQuery( + InputMultiDiEdgeQuery::none(), + MultiDiEdgeQuery::all().with_src_nodes({n}), + OutputMultiDiEdgeQuery::all().with_src_nodes({n}))))); } std::unordered_set get_incoming_edges(OpenMultiDiGraphView const &g, Node const &n) { - return value_all(narrow(g.query_edges(OpenMultiDiEdgeQuery( - InputMultiDiEdgeQuery::all().with_dst_nodes({n}), - MultiDiEdgeQuery::all().with_dst_nodes({n}), - OutputMultiDiEdgeQuery::none())))); + return value_all(narrow(g.query_edges( + OpenMultiDiEdgeQuery(InputMultiDiEdgeQuery::all().with_dst_nodes({n}), + MultiDiEdgeQuery::all().with_dst_nodes({n}), + OutputMultiDiEdgeQuery::none())))); } std::unordered_set get_open_outputs(OpenMultiDiGraphView const &g) { - return narrow(g.query_edges(OutputMultiDiEdgeQuery::all())); + return narrow( + g.query_edges(OutputMultiDiEdgeQuery::all())); } std::unordered_set get_open_inputs(OpenMultiDiGraphView const &g) { diff --git a/lib/utils/src/graph/views.cc b/lib/utils/src/graph/views.cc index a1308cffbb..af15b0d6aa 100644 --- a/lib/utils/src/graph/views.cc +++ b/lib/utils/src/graph/views.cc @@ -446,9 +446,9 @@ std::unordered_set OpenMultiDiSubgraphView::OpenMultiDiSubgraphView( OpenMultiDiGraphView 
const &g, std::unordered_set const &nodes) : g(g), nodes(nodes) { - this->inputs = transform(get_cut_set(g, nodes), to_inputmultidiedge); - this->outputs = transform(get_cut_set(g, nodes), to_outputmultidiedge); - } + this->inputs = transform(get_cut_set(g, nodes), to_inputmultidiedge); + this->outputs = transform(get_cut_set(g, nodes), to_outputmultidiedge); +} std::unordered_set OpenMultiDiSubgraphView::query_edges(OpenMultiDiEdgeQuery const &q) const { From 2eb3fdfd2ab9efc7ff217c715303e69cd6da5f9a Mon Sep 17 00:00:00 2001 From: wmdi Date: Mon, 11 Mar 2024 18:50:35 -0400 Subject: [PATCH 14/37] fix --- lib/compiler/test/src/test_generator.h | 313 +++++++++--------- lib/compiler/test/src/test_machine_mapping.cc | 34 +- lib/compiler/test/src/test_unity_algorithm.cc | 39 +-- 3 files changed, 194 insertions(+), 192 deletions(-) diff --git a/lib/compiler/test/src/test_generator.h b/lib/compiler/test/src/test_generator.h index 6566c8c2de..c14743347a 100644 --- a/lib/compiler/test/src/test_generator.h +++ b/lib/compiler/test/src/test_generator.h @@ -8,161 +8,162 @@ using namespace FlexFlow; -/* - Generates computation graphs with trivial layers and tensors, which are used - for tests focusing on graph structures. -*/ -ComputationGraph test_computataion_graph(MultiDiGraphView const &g) { - return materialize_output_labelled_multidigraph_view( - ViewMultiDiGraphAsOutputLabelled( - g, - [](Layer(Node const &)) { return Layer(NoopAttrs{}); }, - [](Tensor(MultiDiOutput const &)) { - return Tensor{0, DataType::FLOAT, nullopt, false, nullopt}; - })); -} - -/* - Generates parallel computation graphs with trivial layers and tensors, which - are used for tests focusing on graph structures. 
-*/ -ParallelComputationGraph - test_parallel_computation_graph(MultiDiGraphView const &g) { - return materialize_output_labelled_multidigraph_view( - ViewMultiDiGraphAsOutputLabelled( - g, - [](Operator(Node const &)) { return ParallelTensor(NoopAttrs{}); }, - [](Operator(MultiDiOutput const &)) { - return ParallelTensor(ParallelTensorDims(TensorDims({})), - DataType::FLOAT); - })); -} - -rc::Gen small_integer_generator() { - return rc::gen::inRange(1, 4); -} - -namespace rc { - -Gen serialParallelMultiDiGraph() { - return gen::map(gen::arbitrary(), - multidigraph_from_sp_decomposition); -} - -template <> -struct Arbitrary { - static Gen arbitrary() { - return gen::map(gen::cast(serialParallelMultiDiGraph()), - test_computataion_graph); - } -}; - -template <> -struct Arbitrary { - static Gen arbitrary() { - return gen::map(gen::cast(serialParallelMultiDiGraph()), - test_parallel_computation_graph); - } -}; - -template <> -struct Arbitrary> { - static Gen> arbitrary() { - return gen::mapcat(gen::arbitrary(), [](bool is_node) { - return is_node - ? gen::cast>(gen::arbitrary()) - : gen::cast>(gen::arbitrary()); - }); - } -}; - -template <> -struct Arbitrary> { - static Gen> arbitrary() { - return gen::mapcat(gen::arbitrary(), [](bool is_node) { - return is_node - ? gen::cast>(gen::arbitrary()) - : gen::cast>( - gen::arbitrary()); - }); - } -}; - -template <> -struct Arbitrary { - static Gen arbitrary() { - return gen::build( - gen::set(&Serial::children, - gen::container>>( - gen::arbitrary>()))); - } -}; - -template <> -struct Arbitrary { - static Gen arbitrary() { - return gen::build( - gen::set(&Parallel::children, - gen::container>>( - gen::arbitrary>()))); - } -}; - -template <> -struct Arbitrary { - static Gen arbitrary() { - return gen::mapcat(gen::arbitrary(), [](bool is_serial) { - return is_serial ? 
gen::construct( - gen::arbitrary()) - : gen::construct( - gen::arbitrary()); - }); - } -}; - -template -struct Arbitrary { - static Gen< - std::enable_if, Tag>::value>::type> - arbitrary() { - return gen::construct(gen::arbitrary()); - } -}; - -template <> -struct Arbitrary { - static Gen arbitrary() { - return gen::apply(make_1d_machine_view, - gen::arbitrary, - gen::arbitrary, - small_integer_generator()); - } -} - -template <> -struct Arbitrary { - static Gen arbitrary() { - return gen::build( - gen::set(&MachineMapping::machine_views, - gen::container>( - gen::arbitrary(), gen::arbitrary()))); - } -} - -template <> -struct Arbitrary { - static Gen arbitrary() { - return gen::build( - gen::set(&MachineSpecification::num_nodes, gen::inRange(1, 64)), - gen::set(&MachineSpecification::num_cpus_per_node, gen::inRange(1, 64)), - gen::set(&MachineSpecification::num_gpus_per_node, gen::inRange(1, 16)), - gen::set(&MachineSpecification::inter_node_bandwidth, - gen::nonZero()), - gen::set(&MachineSpecification::intra_node_bandwidth, - gen::nonZero())); - } -} - -} // namespace rc +// Rapidcheck does not work for now +// /* +// Generates computation graphs with trivial layers and tensors, which are used +// for tests focusing on graph structures. +// */ +// ComputationGraph test_computataion_graph(MultiDiGraphView const &g) { +// return materialize_output_labelled_multidigraph_view( +// ViewMultiDiGraphAsOutputLabelled( +// g, +// [](Layer(Node const &)) { return Layer(NoopAttrs{}); }, +// [](Tensor(MultiDiOutput const &)) { +// return Tensor{0, DataType::FLOAT, nullopt, false, nullopt}; +// })); +// } + +// /* +// Generates parallel computation graphs with trivial layers and tensors, which +// are used for tests focusing on graph structures. 
+// */ +// ParallelComputationGraph +// test_parallel_computation_graph(MultiDiGraphView const &g) { +// return materialize_output_labelled_multidigraph_view( +// ViewMultiDiGraphAsOutputLabelled( +// g, +// [](Operator(Node const &)) { return ParallelTensor(NoopAttrs{}); }, +// [](Operator(MultiDiOutput const &)) { +// return ParallelTensor(ParallelTensorDims(TensorDims({})), +// DataType::FLOAT); +// })); +// } + +// rc::Gen small_integer_generator() { +// return rc::gen::inRange(1, 4); +// } + +// namespace rc { + +// Gen serialParallelMultiDiGraph() { +// return gen::map(gen::arbitrary(), +// multidigraph_from_sp_decomposition); +// } + +// template <> +// struct Arbitrary { +// static Gen arbitrary() { +// return gen::map(gen::cast(serialParallelMultiDiGraph()), +// test_computataion_graph); +// } +// }; + +// template <> +// struct Arbitrary { +// static Gen arbitrary() { +// return gen::map(gen::cast(serialParallelMultiDiGraph()), +// test_parallel_computation_graph); +// } +// }; + +// template <> +// struct Arbitrary> { +// static Gen> arbitrary() { +// return gen::mapcat(gen::arbitrary(), [](bool is_node) { +// return is_node +// ? gen::cast>(gen::arbitrary()) +// : gen::cast>(gen::arbitrary()); +// }); +// } +// }; + +// template <> +// struct Arbitrary> { +// static Gen> arbitrary() { +// return gen::mapcat(gen::arbitrary(), [](bool is_node) { +// return is_node +// ? 
gen::cast>(gen::arbitrary()) +// : gen::cast>( +// gen::arbitrary()); +// }); +// } +// }; + +// template <> +// struct Arbitrary { +// static Gen arbitrary() { +// return gen::build( +// gen::set(&Serial::children, +// gen::container>>( +// gen::arbitrary>()))); +// } +// }; + +// template <> +// struct Arbitrary { +// static Gen arbitrary() { +// return gen::build( +// gen::set(&Parallel::children, +// gen::container>>( +// gen::arbitrary>()))); +// } +// }; + +// template <> +// struct Arbitrary { +// static Gen arbitrary() { +// return gen::mapcat(gen::arbitrary(), [](bool is_serial) { +// return is_serial ? gen::construct( +// gen::arbitrary()) +// : gen::construct( +// gen::arbitrary()); +// }); +// } +// }; + +// template +// struct Arbitrary { +// static Gen< +// std::enable_if, Tag>::value>::type> +// arbitrary() { +// return gen::construct(gen::arbitrary()); +// } +// }; + +// template <> +// struct Arbitrary { +// static Gen arbitrary() { +// return gen::apply(make_1d_machine_view, +// gen::arbitrary, +// gen::arbitrary, +// small_integer_generator()); +// } +// } + +// template <> +// struct Arbitrary { +// static Gen arbitrary() { +// return gen::build( +// gen::set(&MachineMapping::machine_views, +// gen::container>( +// gen::arbitrary(), gen::arbitrary()))); +// } +// } + +// template <> +// struct Arbitrary { +// static Gen arbitrary() { +// return gen::build( +// gen::set(&MachineSpecification::num_nodes, gen::inRange(1, 64)), +// gen::set(&MachineSpecification::num_cpus_per_node, gen::inRange(1, 64)), +// gen::set(&MachineSpecification::num_gpus_per_node, gen::inRange(1, 16)), +// gen::set(&MachineSpecification::inter_node_bandwidth, +// gen::nonZero()), +// gen::set(&MachineSpecification::intra_node_bandwidth, +// gen::nonZero())); +// } +// } + +// } // namespace rc #endif diff --git a/lib/compiler/test/src/test_machine_mapping.cc b/lib/compiler/test/src/test_machine_mapping.cc index 779f8134d9..b2abc6929d 100644 --- 
a/lib/compiler/test/src/test_machine_mapping.cc +++ b/lib/compiler/test/src/test_machine_mapping.cc @@ -1,21 +1,21 @@ -#include "doctest/doctest.h" -#include "test_generator.h" +// #include "doctest/doctest.h" +// #include "test_generator.h" -TEST_CASE("MachineMapping::combine") { - rc::check([](MachineMapping const &m0, MachineMapping const &m1) { - RC_PRE(MachineMapping::nodes_are_disjoint(m0, m1)); +// TEST_CASE("MachineMapping::combine") { +// rc::check([](MachineMapping const &m0, MachineMapping const &m1) { +// RC_PRE(MachineMapping::nodes_are_disjoint(m0, m1)); - MachineMapping comb = MachineMapping::combine(m0, m1); +// MachineMapping comb = MachineMapping::combine(m0, m1); - RC_ASSERT(comb.machine_views.size() == - m0.machine_views.size() + m1.machine_views.size()); - RC_ASSERT(is_submap(comb.machine_views, m0.machine_views)); - RC_ASSERT(is_submap(comb.machine_views, m1.machine_views)); - }); -} +// RC_ASSERT(comb.machine_views.size() == +// m0.machine_views.size() + m1.machine_views.size()); +// RC_ASSERT(is_submap(comb.machine_views, m0.machine_views)); +// RC_ASSERT(is_submap(comb.machine_views, m1.machine_views)); +// }); +// } -TEST_CASE("OptimalCostResult::infinity") { - rc::check([](OptimalCostResult const &c) { - RC_ASSERT(c.runtime <= OptimalCostResult::infinity().runtime); - }); -} +// TEST_CASE("OptimalCostResult::infinity") { +// rc::check([](OptimalCostResult const &c) { +// RC_ASSERT(c.runtime <= OptimalCostResult::infinity().runtime); +// }); +// } diff --git a/lib/compiler/test/src/test_unity_algorithm.cc b/lib/compiler/test/src/test_unity_algorithm.cc index 6a0131dd77..cceecae831 100644 --- a/lib/compiler/test/src/test_unity_algorithm.cc +++ b/lib/compiler/test/src/test_unity_algorithm.cc @@ -2,22 +2,23 @@ #include "test_cost_estimator.h" #include "test_generator.h" -TEST_CASE("graph_optimize") { - rc::check([](ComputationGraph const &g, - float alpha, - int budget, - float threshold, - int max_num_ops) { - Strategy s = graph_optimize( - 
g, - TestCostEstimator{}, - MachineSpecification{1, 1, 4, 0.1, 0.2}, - [](Operator const &, MachineSpecification const &) { - return std::unordered_set{make_1d_machine_view(0, 1, 1)}; - }, - OptimizerConfig{alpha, budget, threshold, max_num_ops}); - RC_ASSERT(get_nodes(s.pcg).size() > 0); - RC_ASSERT(s.machine_mapping.runtime > 0); - RC_ASSERT(keys(s.machine_mapping.machine_views) == get_nodes(s.pcg)); - }); -} +// Rapidcheck does not work for now +// TEST_CASE("graph_optimize") { +// rc::check([](ComputationGraph const &g, +// float alpha, +// int budget, +// float threshold, +// int max_num_ops) { +// Strategy s = graph_optimize( +// g, +// TestCostEstimator{}, +// MachineSpecification{1, 1, 4, 0.1, 0.2}, +// [](Operator const &, MachineSpecification const &) { +// return std::unordered_set{make_1d_machine_view(0, 1, 1)}; +// }, +// OptimizerConfig{alpha, budget, threshold, max_num_ops}); +// RC_ASSERT(get_nodes(s.pcg).size() > 0); +// RC_ASSERT(s.machine_mapping.runtime > 0); +// RC_ASSERT(keys(s.machine_mapping.machine_views) == get_nodes(s.pcg)); +// }); +// } From 7598a923848588234262a61a83a4fa8bd0377f33 Mon Sep 17 00:00:00 2001 From: wmdi Date: Mon, 11 Mar 2024 18:55:34 -0400 Subject: [PATCH 15/37] fmt --- lib/compiler/test/src/test_generator.h | 29 +++++++++++-------- lib/compiler/test/src/test_unity_algorithm.cc | 3 +- 2 files changed, 19 insertions(+), 13 deletions(-) diff --git a/lib/compiler/test/src/test_generator.h b/lib/compiler/test/src/test_generator.h index c14743347a..d6b8222968 100644 --- a/lib/compiler/test/src/test_generator.h +++ b/lib/compiler/test/src/test_generator.h @@ -10,8 +10,8 @@ using namespace FlexFlow; // Rapidcheck does not work for now // /* -// Generates computation graphs with trivial layers and tensors, which are used -// for tests focusing on graph structures. +// Generates computation graphs with trivial layers and tensors, which are +// used for tests focusing on graph structures. 
// */ // ComputationGraph test_computataion_graph(MultiDiGraphView const &g) { // return materialize_output_labelled_multidigraph_view( @@ -24,8 +24,8 @@ using namespace FlexFlow; // } // /* -// Generates parallel computation graphs with trivial layers and tensors, which -// are used for tests focusing on graph structures. +// Generates parallel computation graphs with trivial layers and tensors, +// which are used for tests focusing on graph structures. // */ // ParallelComputationGraph // test_parallel_computation_graph(MultiDiGraphView const &g) { @@ -53,7 +53,8 @@ using namespace FlexFlow; // template <> // struct Arbitrary { // static Gen arbitrary() { -// return gen::map(gen::cast(serialParallelMultiDiGraph()), +// return +// gen::map(gen::cast(serialParallelMultiDiGraph()), // test_computataion_graph); // } // }; @@ -61,7 +62,8 @@ using namespace FlexFlow; // template <> // struct Arbitrary { // static Gen arbitrary() { -// return gen::map(gen::cast(serialParallelMultiDiGraph()), +// return +// gen::map(gen::cast(serialParallelMultiDiGraph()), // test_parallel_computation_graph); // } // }; @@ -72,7 +74,8 @@ using namespace FlexFlow; // return gen::mapcat(gen::arbitrary(), [](bool is_node) { // return is_node // ? 
gen::cast>(gen::arbitrary()) -// : gen::cast>(gen::arbitrary()); +// : gen::cast>(gen::arbitrary()); // }); // } // }; @@ -124,8 +127,8 @@ using namespace FlexFlow; // template // struct Arbitrary { // static Gen< -// std::enable_if, Tag>::value>::type> -// arbitrary() { +// std::enable_if, +// Tag>::value>::type> arbitrary() { // return gen::construct(gen::arbitrary()); // } // }; @@ -146,7 +149,8 @@ using namespace FlexFlow; // return gen::build( // gen::set(&MachineMapping::machine_views, // gen::container>( -// gen::arbitrary(), gen::arbitrary()))); +// gen::arbitrary(), +// gen::arbitrary()))); // } // } @@ -155,8 +159,9 @@ using namespace FlexFlow; // static Gen arbitrary() { // return gen::build( // gen::set(&MachineSpecification::num_nodes, gen::inRange(1, 64)), -// gen::set(&MachineSpecification::num_cpus_per_node, gen::inRange(1, 64)), -// gen::set(&MachineSpecification::num_gpus_per_node, gen::inRange(1, 16)), +// gen::set(&MachineSpecification::num_cpus_per_node, gen::inRange(1, +// 64)), gen::set(&MachineSpecification::num_gpus_per_node, +// gen::inRange(1, 16)), // gen::set(&MachineSpecification::inter_node_bandwidth, // gen::nonZero()), // gen::set(&MachineSpecification::intra_node_bandwidth, diff --git a/lib/compiler/test/src/test_unity_algorithm.cc b/lib/compiler/test/src/test_unity_algorithm.cc index cceecae831..c39b3ef14f 100644 --- a/lib/compiler/test/src/test_unity_algorithm.cc +++ b/lib/compiler/test/src/test_unity_algorithm.cc @@ -14,7 +14,8 @@ // TestCostEstimator{}, // MachineSpecification{1, 1, 4, 0.1, 0.2}, // [](Operator const &, MachineSpecification const &) { -// return std::unordered_set{make_1d_machine_view(0, 1, 1)}; +// return std::unordered_set{make_1d_machine_view(0, 1, +// 1)}; // }, // OptimizerConfig{alpha, budget, threshold, max_num_ops}); // RC_ASSERT(get_nodes(s.pcg).size() > 0); From 05c8336ef7109f112c67ec838540fb8a1b06dfb3 Mon Sep 17 00:00:00 2001 From: wmdi Date: Wed, 13 Mar 2024 20:56:15 -0400 Subject: [PATCH 16/37] fix 
--- lib/compiler/src/graph_utils.cc | 2 +- lib/compiler/src/machine_mapping.cc | 23 ++----------------- lib/op-attrs/src/attention.cc | 7 ++++++ lib/op-attrs/src/embedding.cc | 8 ++++++- .../src/parallel_dim_mapping_record_solver.cc | 8 +++++++ lib/pcg/src/strided_rectangle.cc | 4 ++++ .../include/utils/graph/labelled/open_views.h | 6 ++--- .../utils/graph/labelled/output_labelled.h | 18 +++++++-------- lib/utils/include/utils/graph/views.h | 16 ++++++------- lib/utils/src/graph/open_graphs.cc | 2 +- lib/utils/src/graph/serialparallel.cc | 2 +- 11 files changed, 50 insertions(+), 46 deletions(-) diff --git a/lib/compiler/src/graph_utils.cc b/lib/compiler/src/graph_utils.cc index 069ae4a41f..3c6e44216b 100644 --- a/lib/compiler/src/graph_utils.cc +++ b/lib/compiler/src/graph_utils.cc @@ -54,7 +54,7 @@ std::unordered_map } } - assert(result.size() == get_edges(pcg).size()); + assert(result.size() == get_edges(pcg.value()).size()); return result; } diff --git a/lib/compiler/src/machine_mapping.cc b/lib/compiler/src/machine_mapping.cc index 5ce988b951..b48e200c15 100644 --- a/lib/compiler/src/machine_mapping.cc +++ b/lib/compiler/src/machine_mapping.cc @@ -94,6 +94,7 @@ float estimate_cost(SubParallelComputationGraphView const &g, MachineMapping const &device_mapping, std::unordered_map const &frontier_machine_views) { + // TODO: Consider parallelism float cost = 0; for (Node const &node : get_nodes(g)) { std::unordered_set incoming_edges = @@ -106,26 +107,6 @@ float estimate_cost(SubParallelComputationGraphView const &g, cost += estimator.estimate_cost( g.at(node).attrs, inputs, device_mapping.machine_views.at(node)); } - - for (OpenMultiDiEdge const &edge : get_edges(g)) { - if (holds_alternative(edge)) { - cost += estimator.estimate_cost( - g.at(edge).get_shape(), - frontier_machine_views.at(edge), - device_mapping.machine_views.at(get(edge).dst)); - } else if (holds_alternative(edge)) { - cost += estimator.estimate_cost( - g.at(edge).get_shape(), - 
device_mapping.machine_views.at(get(edge).src), - frontier_machine_views.at(edge)); - } else { - assert(holds_alternative(edge)); - cost += estimator.estimate_cost( - g.at(edge).get_shape(), - device_mapping.machine_views.at(get(edge).src), - device_mapping.machine_views.at(get(edge).dst)); - } - } return cost; } @@ -308,7 +289,7 @@ struct MachineMappingSearcher { &frontier_machine_views) { if (contains_key(given_machine_views, node)) { assert(contains(allowed_machine_views(g.at(node), resource), - source_machine_view.value())); + given_machine_views.at(node))); MachineMapping mv_map{given_machine_views}; return {estimate_cost(g, cost_estimator, mv_map, frontier_machine_views), mv_map}; diff --git a/lib/op-attrs/src/attention.cc b/lib/op-attrs/src/attention.cc index 4b6c53897c..2c1500a477 100644 --- a/lib/op-attrs/src/attention.cc +++ b/lib/op-attrs/src/attention.cc @@ -91,7 +91,14 @@ TensorShape get_output_shape(MultiHeadAttentionAttrs const &attrs, static_cast(value_shape)); return get_tensor_shape_unsafe(parallel_shape); } +TensorShape get_output_shape(MultiHeadAttentionAttrs const &, + MultiHeadAttentionInputs const &) { + NOT_IMPLEMENTED(); +} +int get_oSize(ParallelTensorShape const &) { + NOT_IMPLEMENTED(); +} } // namespace FlexFlow // Tensor FFModel::multihead_attention(const Tensor query, diff --git a/lib/op-attrs/src/embedding.cc b/lib/op-attrs/src/embedding.cc index 02cbfaa031..56014fcc67 100644 --- a/lib/op-attrs/src/embedding.cc +++ b/lib/op-attrs/src/embedding.cc @@ -1,3 +1,9 @@ #include "op-attrs/ops/embedding.h" -namespace FlexFlow {} // namespace FlexFlow +namespace FlexFlow { + +TensorShape get_weights_shape(EmbeddingAttrs const &, TensorShape const &) { + NOT_IMPLEMENTED(); +} + +} // namespace FlexFlow diff --git a/lib/op-attrs/src/parallel_dim_mapping_record_solver.cc b/lib/op-attrs/src/parallel_dim_mapping_record_solver.cc index 68686393f5..500119241d 100644 --- a/lib/op-attrs/src/parallel_dim_mapping_record_solver.cc +++ 
b/lib/op-attrs/src/parallel_dim_mapping_record_solver.cc @@ -351,4 +351,12 @@ void construct_output_parallel_dims( /* return solution; */ /* } */ +ParallelDimMappingSolution solve_parallel_dim_mappings( + std::vector const &mappings, + std::vector const &input, + int numWeights, + int numOutputs) { + NOT_IMPLEMENTED(); +} + } // namespace FlexFlow diff --git a/lib/pcg/src/strided_rectangle.cc b/lib/pcg/src/strided_rectangle.cc index 7f612b743b..27ef9a7f5b 100644 --- a/lib/pcg/src/strided_rectangle.cc +++ b/lib/pcg/src/strided_rectangle.cc @@ -30,4 +30,8 @@ side_size_t StridedRectangleSide::get_size() const { NOT_IMPLEMENTED(); } +size_t StridedRectangle::num_dims() const { + NOT_IMPLEMENTED(); +} + } // namespace FlexFlow diff --git a/lib/utils/include/utils/graph/labelled/open_views.h b/lib/utils/include/utils/graph/labelled/open_views.h index a24c2b940b..494d8d9f9d 100644 --- a/lib/utils/include/utils/graph/labelled/open_views.h +++ b/lib/utils/include/utils/graph/labelled/open_views.h @@ -48,8 +48,8 @@ struct OutputLabelledOpenMultiDiSubgraphView } private: - OutputLabelledOpenMultiDiGraphView const &g; - std::unordered_set const &nodes; + OutputLabelledOpenMultiDiGraphView g; + std::unordered_set nodes; }; template @@ -86,7 +86,7 @@ struct ViewOutputLabelledAsOutputLabelledOpen } private: - OutputLabelledMultiDiGraphView const &g; + OutputLabelledMultiDiGraphView g; }; template diff --git a/lib/utils/include/utils/graph/labelled/output_labelled.h b/lib/utils/include/utils/graph/labelled/output_labelled.h index 882fca8df0..9c65db4daa 100644 --- a/lib/utils/include/utils/graph/labelled/output_labelled.h +++ b/lib/utils/include/utils/graph/labelled/output_labelled.h @@ -31,20 +31,19 @@ struct OutputLabelledMultiDiGraphView OutputLabelledMultiDiGraphView & operator=(OutputLabelledMultiDiGraphView const &) = default; - virtual NodeLabel const &at(Node const &n) const { + NodeLabel const &at(Node const &n) const { return get_ptr().at(n); } - virtual OutputLabel const 
&at(MultiDiOutput const &o) const { + OutputLabel const &at(MultiDiOutput const &o) const { return get_ptr().at(o); } - virtual std::unordered_set query_nodes(NodeQuery const &q) const { + std::unordered_set query_nodes(NodeQuery const &q) const { return get_ptr().query_nodes(q); } - virtual std::unordered_set - query_edges(MultiDiEdgeQuery const &q) const { + std::unordered_set query_edges(MultiDiEdgeQuery const &q) const { return get_ptr().query_edges(q); } @@ -93,7 +92,7 @@ struct OutputLabelledMultiDiGraph return nl.get_mutable()->get_label(n); } - NodeLabel const &at(Node const &n) const override { + NodeLabel const &at(Node const &n) const { return nl->get_label(n); } @@ -113,16 +112,15 @@ struct OutputLabelledMultiDiGraph return ol.get_mutable()->get_label(o); } - OutputLabel const &at(MultiDiOutput const &o) const override { + OutputLabel const &at(MultiDiOutput const &o) const { return ol->get_label(o); } - std::unordered_set query_nodes(NodeQuery const &q) const override { + std::unordered_set query_nodes(NodeQuery const &q) const { return get_ptr().query_nodes(q); } - std::unordered_set - query_edges(MultiDiEdgeQuery const &q) const override { + std::unordered_set query_edges(MultiDiEdgeQuery const &q) const { return get_ptr().query_edges(q); } diff --git a/lib/utils/include/utils/graph/views.h b/lib/utils/include/utils/graph/views.h index 776a72e6d5..43d813bf8c 100644 --- a/lib/utils/include/utils/graph/views.h +++ b/lib/utils/include/utils/graph/views.h @@ -256,8 +256,8 @@ struct OpenMultiDiSubgraphView : public IOpenMultiDiGraphView { OpenMultiDiSubgraphView *clone() const override; private: - OpenMultiDiGraphView const &g; - std::unordered_set const &nodes; + OpenMultiDiGraphView g; + std::unordered_set nodes; std::unordered_set inputs; std::unordered_set outputs; }; @@ -274,8 +274,8 @@ struct UpwardOpenMultiDiSubgraphView : public IOpenMultiDiGraphView { UpwardOpenMultiDiSubgraphView *clone() const override; private: - OpenMultiDiGraphView const &g; 
- std::unordered_set const &nodes; + OpenMultiDiGraphView g; + std::unordered_set nodes; std::unordered_set inputs; }; @@ -291,8 +291,8 @@ struct DownwardOpenMultiDiSubgraphView : public IOpenMultiDiGraphView { DownwardOpenMultiDiSubgraphView *clone() const override; private: - OpenMultiDiGraphView const &g; - std::unordered_set const &nodes; + OpenMultiDiGraphView g; + std::unordered_set nodes; std::unordered_set outputs; }; @@ -308,8 +308,8 @@ struct ClosedMultiDiSubgraphView : public IOpenMultiDiGraphView { ClosedMultiDiSubgraphView *clone() const override; private: - OpenMultiDiGraphView const &g; - std::unordered_set const &nodes; + OpenMultiDiGraphView g; + std::unordered_set nodes; }; UndirectedEdge to_undirected_edge(DirectedEdge const &); diff --git a/lib/utils/src/graph/open_graphs.cc b/lib/utils/src/graph/open_graphs.cc index c32ff6ded5..e0bc94ca8c 100644 --- a/lib/utils/src/graph/open_graphs.cc +++ b/lib/utils/src/graph/open_graphs.cc @@ -22,7 +22,7 @@ std::unordered_set } IOpenMultiDiGraphView const &OpenMultiDiGraphView::get_ptr() const { - return *std::reinterpret_pointer_cast( + return *std::dynamic_pointer_cast( GraphView::ptr.get()); } diff --git a/lib/utils/src/graph/serialparallel.cc b/lib/utils/src/graph/serialparallel.cc index 8b179d31de..3461e27ddf 100644 --- a/lib/utils/src/graph/serialparallel.cc +++ b/lib/utils/src/graph/serialparallel.cc @@ -19,7 +19,7 @@ Node find_sink_node(DiGraphView const &g) { optional find_bottleneck_node(DiGraphView const &g) { std::unordered_set sources = get_sources(g); - std::unordered_set sinks = get_sources(g); + std::unordered_set sinks = get_sinks(g); optional maybe_bottleneck = get_imm_post_dominator(g, sources); if (maybe_bottleneck.has_value()) { From 9345400aab5fbe5cc20fd144df63194f145da84e Mon Sep 17 00:00:00 2001 From: wmdi Date: Mon, 18 Mar 2024 15:51:34 -0400 Subject: [PATCH 17/37] add more unit tests --- lib/compiler/test/CMakeLists.txt | 2 +- .../test/src/test_labelled_open_graph.cc | 190 
++++++++----- lib/compiler/test/src/test_optimal_cost.cc | 4 +- lib/substitutions/src/substitution.cc | 7 +- .../test/src/test_substitution.cc | 18 +- .../utils/graph/labelled/labelled_open.decl.h | 124 --------- .../utils/graph/labelled/labelled_open.h | 173 ------------ .../graph/labelled/labelled_open_interfaces.h | 62 ----- .../utils/graph/labelled/node_labelled.h | 48 +--- .../graph/labelled/node_labelled_interfaces.h | 36 +++ .../utils/graph/labelled/node_labelled_open.h | 53 ++-- .../utils/graph/labelled/output_labelled.h | 73 ++--- .../labelled/output_labelled_interfaces.h | 15 +- .../graph/labelled/output_labelled_open.h | 88 ++----- .../output_labelled_open_interfaces.h | 34 +++ .../utils/graph/labelled/standard_labelled.h | 59 +---- .../labelled/unordered_labelled_graphs.h | 249 ++++++++++++------ .../include/utils/graph/labelled/views.h | 8 +- .../include/utils/graph/labelled_graphs.h | 1 + lib/utils/src/graph/open_graphs.cc | 2 +- lib/utils/test/CMakeLists.txt | 24 +- lib/utils/test/src/test_cow_ptr.cc | 60 +++++ 22 files changed, 563 insertions(+), 767 deletions(-) delete mode 100644 lib/utils/include/utils/graph/labelled/labelled_open.decl.h delete mode 100644 lib/utils/include/utils/graph/labelled/labelled_open.h delete mode 100644 lib/utils/include/utils/graph/labelled/labelled_open_interfaces.h create mode 100644 lib/utils/include/utils/graph/labelled/node_labelled_interfaces.h create mode 100644 lib/utils/include/utils/graph/labelled/output_labelled_open_interfaces.h create mode 100644 lib/utils/test/src/test_cow_ptr.cc diff --git a/lib/compiler/test/CMakeLists.txt b/lib/compiler/test/CMakeLists.txt index cbd7e233c0..3d35fdabfd 100644 --- a/lib/compiler/test/CMakeLists.txt +++ b/lib/compiler/test/CMakeLists.txt @@ -2,7 +2,7 @@ ff_add_test_executable( NAME compiler-test SRC_PATTERNS - src/*.cc + src/test_labelled_open_graph.cc PRIVATE_INCLUDE src/ DEPS diff --git a/lib/compiler/test/src/test_labelled_open_graph.cc 
b/lib/compiler/test/src/test_labelled_open_graph.cc index a360d86ee7..a3b6319528 100644 --- a/lib/compiler/test/src/test_labelled_open_graph.cc +++ b/lib/compiler/test/src/test_labelled_open_graph.cc @@ -4,73 +4,141 @@ using namespace FlexFlow; -TEST_CASE("get_subgraph(OpenMultiDiGraphView)") { - auto g = OpenMultiDiGraph::create(); +// TEST_CASE("get_subgraph(OpenMultiDiGraphView)") { +// auto g = OpenMultiDiGraph::create(); - Node n0 = g.add_node(); - Node n1 = g.add_node(); - Node n2 = g.add_node(); - Node n3 = g.add_node(); - Node n4 = g.add_node(); +// Node n0 = g.add_node(); +// Node n1 = g.add_node(); +// Node n2 = g.add_node(); +// Node n3 = g.add_node(); +// Node n4 = g.add_node(); + +// NodePort p0 = g.add_node_port(); +// NodePort p1 = g.add_node_port(); +// NodePort p2 = g.add_node_port(); +// NodePort p3 = g.add_node_port(); +// NodePort p4 = g.add_node_port(); +// NodePort p5 = g.add_node_port(); +// NodePort p6 = g.add_node_port(); +// NodePort p7 = g.add_node_port(); +// NodePort p8 = g.add_node_port(); +// NodePort p9 = g.add_node_port(); + +// MultiDiEdge e0{n1, p1, n0, p0}; +// MultiDiEdge e1{n2, p2, n0, p0}; +// MultiDiEdge e2{n3, p5, n1, p3}; +// MultiDiEdge e3{n3, p6, n2, p4}; +// MultiDiEdge e4{n4, p8, n3, p7}; +// OutputMultiDiEdge e5{n4, p9, std::make_pair(p9.value(), p9.value())}; + +// g.add_edge(e0); +// g.add_edge(e1); +// g.add_edge(e2); +// g.add_edge(e3); +// g.add_edge(e4); +// g.add_edge(e5); + +// std::unordered_set node_set0{n3, n4}; + +// auto subgraph0 = get_subgraph(g, node_set0); +// auto subgraph1 = get_subgraph(g, node_set0); +// auto subgraph2 = get_subgraph(g, +// node_set0); auto subgraph3 = get_subgraph(g, +// node_set0); + +// CHECK(get_nodes(subgraph0) == node_set0); +// CHECK(get_nodes(subgraph1) == node_set0); +// CHECK(get_nodes(subgraph2) == node_set0); +// CHECK(get_nodes(subgraph3) == node_set0); + +// std::unordered_set input_set{split_edge(e2).second, +// split_edge(e3).second}; +// std::unordered_set 
output_set{e5}; + +// CHECK(bool(get_open_inputs(subgraph0) == input_set)); +// CHECK(bool(get_open_inputs(subgraph1) == input_set)); +// CHECK(bool(get_open_inputs(subgraph2).empty())); +// CHECK(bool(get_open_inputs(subgraph3).empty())); + +// CHECK(bool(get_open_outputs(subgraph0) == output_set)); +// CHECK(bool(get_open_outputs(subgraph1).empty())); +// CHECK(bool(get_open_outputs(subgraph2) == output_set)); +// CHECK(bool(get_open_outputs(subgraph3).empty())); + +// CHECK(bool(get_edges(subgraph0) == +// std::unordered_set{ +// split_edge(e2).second, split_edge(e3).second, e4, e5})); +// CHECK(bool(get_edges(subgraph1) == +// std::unordered_set{ +// split_edge(e2).second, split_edge(e3).second, e4})); +// CHECK(bool(get_edges(subgraph2) == +// std::unordered_set{e4, e5})); +// CHECK(bool(get_edges(subgraph3) == +// std::unordered_set{e4})); + +// CHECK(get_closed_sources(subgraph2) == std::unordered_set{n3}); +// } + +// TEST_CASE("view OutputLabelledMultiDiGraph as open") { +// OutputLabelledMultiDiGraph g = +// OutputLabelledMultiDiGraph::create>(); + +// Node n0 = g.add_node(0); +// Node n1 = g.add_node(1); + +// NodePort p0 = g.add_node_port(); +// NodePort p1 = g.add_node_port(); + +// MultiDiEdge e0{n1, p1, n0, p0}; + +// g.add_edge(e0); +// g.add_output(e0, 2); + +// CHECK(get_edges(g).size() == 1); + +// OutputLabelledOpenMultiDiGraphView open_graph = +// view_output_labelled_as_output_labelled_open(g); + +// CHECK(open_graph.at(n0) == 0); +// CHECK(open_graph.at(n1) == 1); +// CHECK(open_graph.at(e0) == 2); + +// // CHECK(get_edges(open_graph).size() == 1); +// } + +TEST_CASE("OutputLabelledOpenMultiDiGraph") { + OutputLabelledOpenMultiDiGraph g = + OutputLabelledOpenMultiDiGraph::create< + UnorderedOutputLabelledOpenMultiDiGraph>(); + + Node n0 = g.add_node(0); + Node n1 = g.add_node(1); NodePort p0 = g.add_node_port(); NodePort p1 = g.add_node_port(); - NodePort p2 = g.add_node_port(); - NodePort p3 = g.add_node_port(); - NodePort p4 = 
g.add_node_port(); - NodePort p5 = g.add_node_port(); - NodePort p6 = g.add_node_port(); - NodePort p7 = g.add_node_port(); - NodePort p8 = g.add_node_port(); - NodePort p9 = g.add_node_port(); MultiDiEdge e0{n1, p1, n0, p0}; - MultiDiEdge e1{n2, p2, n0, p0}; - MultiDiEdge e2{n3, p5, n1, p3}; - MultiDiEdge e3{n3, p6, n2, p4}; - MultiDiEdge e4{n4, p8, n3, p7}; - OutputMultiDiEdge e5{n4, p9, std::make_pair(p9.value(), p9.value())}; g.add_edge(e0); - g.add_edge(e1); - g.add_edge(e2); - g.add_edge(e3); - g.add_edge(e4); - g.add_edge(e5); - - std::unordered_set node_set0{n3, n4}; - - auto subgraph0 = get_subgraph(g, node_set0); - auto subgraph1 = get_subgraph(g, node_set0); - auto subgraph2 = get_subgraph(g, node_set0); - auto subgraph3 = get_subgraph(g, node_set0); - - CHECK(get_nodes(subgraph0) == node_set0); - CHECK(get_nodes(subgraph1) == node_set0); - CHECK(get_nodes(subgraph2) == node_set0); - CHECK(get_nodes(subgraph3) == node_set0); - - std::unordered_set input_set{split_edge(e2).second, - split_edge(e3).second}; - std::unordered_set output_set{e5}; - - CHECK(bool(get_open_inputs(subgraph0) == input_set)); - CHECK(bool(get_open_inputs(subgraph1) == input_set)); - CHECK(bool(get_open_inputs(subgraph2).empty())); - CHECK(bool(get_open_inputs(subgraph3).empty())); - - CHECK(bool(get_open_outputs(subgraph0) == output_set)); - CHECK(bool(get_open_outputs(subgraph1).empty())); - CHECK(bool(get_open_outputs(subgraph2) == output_set)); - CHECK(bool(get_open_outputs(subgraph3).empty())); - - CHECK(bool(get_edges(subgraph0) == - std::unordered_set{ - split_edge(e2).second, split_edge(e3).second, e4, e5})); - CHECK(bool(get_edges(subgraph1) == - std::unordered_set{ - split_edge(e2).second, split_edge(e3).second, e4})); - CHECK(bool(get_edges(subgraph2) == - std::unordered_set{e4, e5})); - CHECK(bool(get_edges(subgraph3) == std::unordered_set{e4})); + g.add_label(e0, 2); + + CHECK(g.query_edges(OpenMultiDiEdgeQuery::all()).size() == 1); + CHECK(get_edges(g).size() == 1); } 
+ +// TEST_CASE("OpenMultiDiGraph") { +// OpenMultiDiGraph g = OpenMultiDiGraph::create(); + +// Node n0 = g.add_node(); +// Node n1 = g.add_node(); + +// NodePort p0 = g.add_node_port(); +// NodePort p1 = g.add_node_port(); + +// MultiDiEdge e0{n1, p1, n0, p0}; + +// g.add_edge(e0); + +// CHECK(get_edges(g).size() == 1); +// } diff --git a/lib/compiler/test/src/test_optimal_cost.cc b/lib/compiler/test/src/test_optimal_cost.cc index c5f74ff392..9d90285870 100644 --- a/lib/compiler/test/src/test_optimal_cost.cc +++ b/lib/compiler/test/src/test_optimal_cost.cc @@ -30,9 +30,7 @@ allowed machine views, trivial cost estimator and random machine specification. TEST_CASE("optimal_cost_0") { auto pcg = OutputLabelledMultiDiGraph::template create< - AdjacencyMultiDiGraph, - UnorderedLabelling, - UnorderedLabelling>(); + UnorderedOutputLabelledMultiDiGraph>(); Node n0 = pcg.add_node(Operator{InputAttrs{}, "input"}); Node n1 = pcg.add_node(Operator{ diff --git a/lib/substitutions/src/substitution.cc b/lib/substitutions/src/substitution.cc index f846171b62..da9f303ab8 100644 --- a/lib/substitutions/src/substitution.cc +++ b/lib/substitutions/src/substitution.cc @@ -413,11 +413,8 @@ SubParallelComputationGraph Substitution const &substitution, MultiDiGraphPatternMatch const &match) { SubParallelComputationGraph new_pcg = - OutputLabelledOpenMultiDiGraph::create< - AdjacencyOpenMultiDiGraph, - UnorderedLabelling, - UnorderedLabelling, - UnorderedLabelling>(); + OutputLabelledOpenMultiDiGraph::template create< + UnorderedOutputLabelledOpenMultiDiGraph>(); bidict node_mapping; // Refactor it with global nodes for (Node const &node : get_nodes(pcg)) { if (!contains_r(match.node_assignment, node)) { diff --git a/lib/substitutions/test/src/test_substitution.cc b/lib/substitutions/test/src/test_substitution.cc index a33e9127cc..a8f5283eda 100644 --- a/lib/substitutions/test/src/test_substitution.cc +++ b/lib/substitutions/test/src/test_substitution.cc @@ -19,12 +19,10 @@ 
TEST_CASE("apply_substitution") { ParallelTensorPattern tensor_pattern_empty{ std::vector{}}; - auto ig = - OutputLabelledOpenMultiDiGraph:: - create, - UnorderedLabelling, - UnorderedLabelling>(); + auto ig = OutputLabelledOpenMultiDiGraph:: + create>(); Node n0 = ig.add_node(operator_pattern_n0); NodePort p0 = ig.add_node_port(); InputMultiDiEdge e0{n0, p0, std::make_pair(p0.value(), p0.value())}; @@ -60,8 +58,7 @@ TEST_CASE("apply_substitution") { {OperatorAttributeKey::PARALLEL_DEGREE, AttrConstant{2}}}}; auto og = NodeLabelledOpenMultiDiGraph::create< - AdjacencyOpenMultiDiGraph, - UnorderedLabelling>(); + UnorderedNodeLabelledOpenMultiDiGraph>(); Node n1 = og.add_node(op_ass_n1); Node n2 = og.add_node(op_ass_n2); Node n3 = og.add_node(op_ass_n3); @@ -88,10 +85,7 @@ TEST_CASE("apply_substitution") { SubParallelComputationGraph pcg = OutputLabelledOpenMultiDiGraph::create< - AdjacencyOpenMultiDiGraph, - UnorderedLabelling, - UnorderedLabelling, - UnorderedLabelling>(); + UnorderedOutputLabelledOpenMultiDiGraph>(); Node n4 = pcg.add_node(Operator{InputAttrs{}, "input"}); Node n5 = pcg.add_node(Operator{ diff --git a/lib/utils/include/utils/graph/labelled/labelled_open.decl.h b/lib/utils/include/utils/graph/labelled/labelled_open.decl.h deleted file mode 100644 index cdd22b7847..0000000000 --- a/lib/utils/include/utils/graph/labelled/labelled_open.decl.h +++ /dev/null @@ -1,124 +0,0 @@ -#ifndef _FLEXFLOW_UTILS_INCLUDE_UTILS_GRAPH_LABELLED_LABELLED_OPEN_DECL_H -#define _FLEXFLOW_UTILS_INCLUDE_UTILS_GRAPH_LABELLED_LABELLED_OPEN_DECL_H - -#include "labelled_open_interfaces.h" -#include "node_labelled.h" -#include "utils/graph/open_graphs.h" - -namespace FlexFlow { - -template -struct LabelledOpenMultiDiGraphView { -private: - using Interface = ILabelledOpenMultiDiGraphView; - -public: - LabelledOpenMultiDiGraphView() = delete; - - operator OpenMultiDiGraphView() const; - // operator MultiDiGraphView() const; - - NodeLabel const &at(Node const &n) const; - EdgeLabel 
const &at(MultiDiEdge const &e) const; - InputLabel const &at(InputMultiDiEdge const &e) const; - OutputLabel const &at(OutputMultiDiEdge const &e) const; - - template - static typename std::enable_if::value, - LabelledOpenMultiDiGraphView>::type - create(); - -private: - std::shared_ptr ptr; -}; -CHECK_WELL_BEHAVED_VALUE_TYPE_NO_EQ( - LabelledOpenMultiDiGraphView); - -template -struct LabelledOpenMultiDiGraph { -private: - using Interface = - ILabelledOpenMultiDiGraph; - -public: - LabelledOpenMultiDiGraph() = delete; - LabelledOpenMultiDiGraph(LabelledOpenMultiDiGraph const &other) = default; - LabelledOpenMultiDiGraph & - operator=(LabelledOpenMultiDiGraph const &other) = default; - - operator LabelledOpenMultiDiGraphView() const; - - operator OpenMultiDiGraphView() const; - - friend void swap(LabelledOpenMultiDiGraph &lhs, - LabelledOpenMultiDiGraph &rhs) { - using std::swap; - - swap(lhs.ptr, rhs.ptr); - } - - Node add_node(NodeLabel const &l); - NodeLabel &at(Node const &n); - - NodePort add_node_port(); - - NodeLabel const &at(Node const &n) const; - - void add_node_unsafe(Node const &n, NodeLabel const &l); - - std::unordered_set query_nodes(NodeQuery const &q) const; - std::unordered_set - query_edges(OpenMultiDiEdgeQuery const &q) const; - - void add_edge( - MultiDiEdge const &e); // We should allow adding edges without labels. For - // example, we may want to first construct a PCG - // and infer its tensor shapes later. 
- void add_edge(InputMultiDiEdge const &e); - void add_edge(OutputMultiDiEdge const &e); - - void add_label(MultiDiEdge const &e, EdgeLabel const &l); - void add_label(InputMultiDiEdge const &e, EdgeLabel const &l); - void add_label(OutputMultiDiEdge const &e, EdgeLabel const &l); - - void add_edge(MultiDiEdge const &e, EdgeLabel const &l); - EdgeLabel &at(MultiDiEdge const &e); - EdgeLabel const &at(MultiDiEdge const &e) const; - - void add_edge(InputMultiDiEdge const &e, InputLabel const &l); - InputLabel &at(InputMultiDiEdge const &e); - InputLabel const &at(InputMultiDiEdge const &e) const; - - void add_edge(OutputMultiDiEdge const &, OutputLabel const &); - OutputLabel &at(OutputMultiDiEdge const &); - OutputLabel const &at(OutputMultiDiEdge const &) const; - - template - static typename std::enable_if::value, - LabelledOpenMultiDiGraph>::type - create(); - -private: - LabelledOpenMultiDiGraph(cow_ptr_t ptr); - -private: - cow_ptr_t ptr; -}; -CHECK_WELL_BEHAVED_VALUE_TYPE_NO_EQ( - LabelledOpenMultiDiGraph); - -} // namespace FlexFlow - -#endif diff --git a/lib/utils/include/utils/graph/labelled/labelled_open.h b/lib/utils/include/utils/graph/labelled/labelled_open.h deleted file mode 100644 index 58fd5416f7..0000000000 --- a/lib/utils/include/utils/graph/labelled/labelled_open.h +++ /dev/null @@ -1,173 +0,0 @@ -#ifndef _FLEXFLOW_UTILS_INCLUDE_UTILS_GRAPH_LABELLED_LABELLED_OPEN_H -#define _FLEXFLOW_UTILS_INCLUDE_UTILS_GRAPH_LABELLED_LABELLED_OPEN_H - -#include "labelled_open.decl.h" -#include "labelled_open_interfaces.h" -#include "node_labelled.h" -#include "utils/graph/open_graph_interfaces.h" -#include "utils/graph/open_graphs.h" - -namespace FlexFlow { - -// LabelledOpenMultiDiGraphView -template -LabelledOpenMultiDiGraphView::operator OpenMultiDiGraphView() - const { - return GraphInternal::create_open_multidigraph_view(this->ptr); -} - -// template -// LabelledOpenMultiDiGraphView::operator MultiDiGraphView() const { -// return 
GraphInternal::create_multidigraphview(this->ptr); -// } - -template -NodeLabel const & - LabelledOpenMultiDiGraphView::at(Node const &n) const { - return this->ptr->at(n); -} - -template -EdgeLabel const &LabelledOpenMultiDiGraphView::at( - MultiDiEdge const &e) const { - return this->ptr->at(e); -} - -template -InputLabel const &LabelledOpenMultiDiGraphView::at( - InputMultiDiEdge const &e) const { - return this->ptr->at(e); -} - -template -OutputLabel const &LabelledOpenMultiDiGraphView::at( - OutputMultiDiEdge const &e) const { - return this->ptr->at(e); -} - -template -template -enable_if_t::Interface, - BaseImpl>::value, - LabelledOpenMultiDiGraphView> - LabelledOpenMultiDiGraphView::create() { - return LabelledOpenMultiDiGraphView(std::make_shared()); -} - -// LabelledOpenMultiDiGraph -template -LabelledOpenMultiDiGraph:: - operator LabelledOpenMultiDiGraphView() const { - return GraphInternal::create_labelled_open_multidigraph_view( - this->ptr); -} - -template -LabelledOpenMultiDiGraph::operator OpenMultiDiGraphView() const { - return GraphInternal::create_open_multidigraph_view(this->ptr.get()); -} - -template -Node LabelledOpenMultiDiGraph::add_node( - NodeLabel const &l) { - return this->ptr.get_mutable()->add_node(l); -} - -template -NodeLabel &LabelledOpenMultiDiGraph::at(Node const &n) { - return this->ptr->at(n); -} - -template -NodeLabel const & - LabelledOpenMultiDiGraph::at(Node const &n) const { - return this->ptr->ILabelledMultiDiGraph::at(n); -} - -template -void LabelledOpenMultiDiGraph::add_node_unsafe( - Node const &n, NodeLabel const &l) { - this->ptr->add_node_unsafe(n, l); -} - -template -std::unordered_set LabelledOpenMultiDiGraph::query_nodes( - NodeQuery const &q) const { - return this->ptr->query_nodes(q); -} - -template -std::unordered_set - LabelledOpenMultiDiGraph::query_edges( - OpenMultiDiEdgeQuery const &q) const { - return this->ptr->query_edges(q); -} - -template -void LabelledOpenMultiDiGraph::add_edge( - MultiDiEdge const 
&e, EdgeLabel const &l) { - return this->ptr->add_edge(e, l); -} - -template -EdgeLabel & - LabelledOpenMultiDiGraph::at(MultiDiEdge const &e) { - return this->ptr->at(e); -} - -template -EdgeLabel const &LabelledOpenMultiDiGraph::at( - MultiDiEdge const &e) const { - return this->ptr->ILabelledMultiDiGraph::at(e); -} - -template -void LabelledOpenMultiDiGraph::add_edge( - InputMultiDiEdge const &e, InputLabel const &l) { - return this->ptr->add_edge(e, l); -} - -template -InputLabel &LabelledOpenMultiDiGraph::at( - InputMultiDiEdge const &e) { - return this->ptr->at(e); -} - -template -InputLabel const &LabelledOpenMultiDiGraph::at( - InputMultiDiEdge const &e) const { - return this->ptr->at(e); -} - -template -void LabelledOpenMultiDiGraph::add_edge( - OutputMultiDiEdge const &e, OutputLabel const &l) { - return this->ptr->add_edge(e, l); -} - -template -OutputLabel &LabelledOpenMultiDiGraph::at( - OutputMultiDiEdge const &e) { - return this->ptr->at(e); -} - -template -OutputLabel const &LabelledOpenMultiDiGraph::at( - OutputMultiDiEdge const &e) const { - return this->ptr->at(e); -} - -template -template -enable_if_t< - std::is_base_of::Interface, - BaseImpl>::value, - LabelledOpenMultiDiGraph> - LabelledOpenMultiDiGraph::create() { - return LabelledOpenMultiDiGraph(make_cow_ptr()); -} - -} // namespace FlexFlow - -#endif diff --git a/lib/utils/include/utils/graph/labelled/labelled_open_interfaces.h b/lib/utils/include/utils/graph/labelled/labelled_open_interfaces.h deleted file mode 100644 index 2db654c615..0000000000 --- a/lib/utils/include/utils/graph/labelled/labelled_open_interfaces.h +++ /dev/null @@ -1,62 +0,0 @@ -#ifndef _FLEXFLOW_UTILS_INCLUDE_UTILS_GRAPH_LABELLED_LABELLED_OPEN_INTERFACES_H -#define _FLEXFLOW_UTILS_INCLUDE_UTILS_GRAPH_LABELLED_LABELLED_OPEN_INTERFACES_H - -#include "standard_labelled_interfaces.h" -#include "utils/containers.h" -#include "utils/graph/open_graph_interfaces.h" - -namespace FlexFlow { - -template -struct 
ILabelledOpenMultiDiGraphView - : public IOpenMultiDiGraphView, - public ILabelledMultiDiGraphView { -public: - std::unordered_set - query_edges(MultiDiEdgeQuery const &q) const final { - return map_over_unordered_set( - [](OpenMultiDiEdge const &e) { return get(e); }, - IOpenMultiDiGraphView::query_edges( - static_cast(q))); - } - - using ILabelledMultiDiGraphView::at; - virtual InputLabel const &at(InputMultiDiEdge const &e) const = 0; - virtual OutputLabel const &at(OutputMultiDiEdge const &e) const = 0; -}; -CHECK_RC_COPY_VIRTUAL_COMPLIANT( - ILabelledOpenMultiDiGraphView); - -template -struct ILabelledOpenMultiDiGraph - : public ILabelledMultiDiGraph, - public ILabelledOpenMultiDiGraphView { -public: - virtual ILabelledOpenMultiDiGraph *clone() const = 0; - - virtual void add_edge(InputMultiDiEdge const &e, InputLabel const &label) = 0; - virtual void add_edge(OutputMultiDiEdge const &e, - OutputLabel const &label) = 0; - - virtual InputLabel const &at(InputMultiDiEdge const &e) const = 0; - virtual InputLabel &at(InputMultiDiEdge const &e) = 0; - - virtual OutputLabel const &at(OutputMultiDiEdge const &e) const = 0; - virtual OutputLabel &at(OutputMultiDiEdge const &e) = 0; - - using ILabelledMultiDiGraph::add_node; -}; -CHECK_RC_COPY_VIRTUAL_COMPLIANT(ILabelledOpenMultiDiGraph); - -} // namespace FlexFlow - -#endif diff --git a/lib/utils/include/utils/graph/labelled/node_labelled.h b/lib/utils/include/utils/graph/labelled/node_labelled.h index 1ecd87226c..9d8874fb14 100644 --- a/lib/utils/include/utils/graph/labelled/node_labelled.h +++ b/lib/utils/include/utils/graph/labelled/node_labelled.h @@ -1,24 +1,11 @@ #ifndef _FLEXFLOW_UTILS_INCLUDE_UTILS_GRAPH_LABELLED_NODE_LABELLED_H #define _FLEXFLOW_UTILS_INCLUDE_UTILS_GRAPH_LABELLED_NODE_LABELLED_H -#include "label_interfaces.h" +#include "node_labelled_interfaces.h" #include "utils/graph/multidigraph.h" namespace FlexFlow { -template -struct INodeLabelledMultiDiGraphView : virtual public IMultiDiGraphView { - 
INodeLabelledMultiDiGraphView() = default; - INodeLabelledMultiDiGraphView(INodeLabelledMultiDiGraphView const &) = delete; - INodeLabelledMultiDiGraphView & - operator=(INodeLabelledMultiDiGraphView const &) = delete; - - virtual ~INodeLabelledMultiDiGraphView() {} - - virtual NodeLabel const &at(Node const &n) const = 0; -}; -CHECK_RC_COPY_VIRTUAL_COMPLIANT(INodeLabelledMultiDiGraphView); - template struct NodeLabelledMultiDiGraphView : virtual public MultiDiGraphView { private: @@ -65,7 +52,6 @@ struct NodeLabelledMultiDiGraph : virtual NodeLabelledMultiDiGraphView { private: using Interface = IMultiDiGraph; - using NodeLabelIf = ILabelling; public: NodeLabelledMultiDiGraph(NodeLabelledMultiDiGraph const &) = default; @@ -73,48 +59,42 @@ struct NodeLabelledMultiDiGraph operator=(NodeLabelledMultiDiGraph const &) = default; NodeLabel const &at(Node const &n) const { - return nl->get_label(n); + return this->get_ptr().at(n); } NodeLabel &at(Node const &n) { - return nl.get_mutable()->get_label(n); + return this->get_ptr().at(n); } std::unordered_set query_nodes(NodeQuery const &q) const { - return get_ptr().query_nodes(); + return this->get_ptr().query_nodes(); } std::unordered_set query_edges(MultiDiEdge const &q) const { - return get_ptr().query_edges(); + return this->get_ptr().query_edges(); } Node add_node(NodeLabel const &l) { - Node n = get_ptr().add_node(); - nl->add_label(n, l); - return n; + return this->get_ptr().add_node(l); } NodePort add_node_port() { - return get_ptr().add_node_port(); + return this->get_ptr().add_node_port(); } void add_edge(MultiDiEdge const &e) { - return get_ptr().add_edge(e); + return this->get_ptr().add_edge(e); } - template - static typename std::enable_if< - std::conjunction, - std::is_base_of>::value, - NodeLabelledMultiDiGraph>::type + template + static typename std::enable_if::value, + NodeLabelledMultiDiGraph>::type create() { - return NodeLabelledMultiDiGraph(make_cow_ptr(), - make_cow_ptr()); + return 
NodeLabelledMultiDiGraph(make_cow_ptr()); } protected: - NodeLabelledMultiDiGraph(cow_ptr_t ptr, cow_ptr_t nl) - : NodeLabelledMultiDiGraphView(ptr), nl(nl) {} + NodeLabelledMultiDiGraph(cow_ptr_t ptr) : GraphView(ptr) {} Interface &get_ptr() { return *std::reinterpret_pointer_cast( @@ -125,8 +105,6 @@ struct NodeLabelledMultiDiGraph return *std::reinterpret_pointer_cast( GraphView::ptr.get()); } - - cow_ptr_t nl; }; CHECK_WELL_BEHAVED_VALUE_TYPE_NO_EQ(NodeLabelledMultiDiGraph); diff --git a/lib/utils/include/utils/graph/labelled/node_labelled_interfaces.h b/lib/utils/include/utils/graph/labelled/node_labelled_interfaces.h new file mode 100644 index 0000000000..37fb4db715 --- /dev/null +++ b/lib/utils/include/utils/graph/labelled/node_labelled_interfaces.h @@ -0,0 +1,36 @@ +#ifndef _FLEXFLOW_UTILS_INCLUDE_UTILS_GRAPH_LABELLED_NODE_LABELLED_INTERFACES_H +#define _FLEXFLOW_UTILS_INCLUDE_UTILS_GRAPH_LABELLED_NODE_LABELLED_INTERFACES_H + +#include "utils/graph/multidigraph.h" + +namespace FlexFlow { + +template +struct INodeLabelledMultiDiGraphView : virtual public IMultiDiGraphView { + INodeLabelledMultiDiGraphView() = default; + INodeLabelledMultiDiGraphView(INodeLabelledMultiDiGraphView const &) = delete; + INodeLabelledMultiDiGraphView & + operator=(INodeLabelledMultiDiGraphView const &) = delete; + + virtual ~INodeLabelledMultiDiGraphView() {} + + virtual NodeLabel const &at(Node const &n) const = 0; +}; +CHECK_RC_COPY_VIRTUAL_COMPLIANT(INodeLabelledMultiDiGraphView); + +template +struct INodeLabelledMultiDiGraph + : virtual INodeLabelledMultiDiGraphView { + virtual NodeLabel &at(Node const &) = 0; + virtual Node add_node(NodeLabel const &l) = 0; + virtual NodePort add_node_port() = 0; + virtual void add_edge(MultiDiEdge const &) = 0; + + virtual INodeLabelledMultiDiGraph *clone() const = 0; + + using INodeLabelledMultiDiGraphView::at; +}; + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/labelled/node_labelled_open.h 
b/lib/utils/include/utils/graph/labelled/node_labelled_open.h index 2162ee0384..826a8387cb 100644 --- a/lib/utils/include/utils/graph/labelled/node_labelled_open.h +++ b/lib/utils/include/utils/graph/labelled/node_labelled_open.h @@ -60,64 +60,65 @@ struct NodeLabelledOpenMultiDiGraphView } }; +template +struct INodeLabelledOpenMultiDiGraph + : virtual INodeLabelledOpenMultiDiGraphView { + virtual Node add_node(NodeLabel const &) = 0; + virtual NodePort add_node_port() = 0; + virtual NodeLabel &at(Node const &) = 0; + virtual void add_edge(OpenMultiDiEdge const &e) = 0; + + using INodeLabelledOpenMultiDiGraphView::at; +}; +CHECK_RC_COPY_VIRTUAL_COMPLIANT(INodeLabelledOpenMultiDiGraphView); + template struct NodeLabelledOpenMultiDiGraph : virtual NodeLabelledOpenMultiDiGraphView { private: - using Interface = IOpenMultiDiGraph; - using INodeLabel = ILabelling; + using Interface = INodeLabelledOpenMultiDiGraph; public: - // NodeLabelledOpenMultiDiGraph() = delete; NodeLabelledOpenMultiDiGraph(NodeLabelledOpenMultiDiGraph const &) = default; NodeLabelledOpenMultiDiGraph & operator=(NodeLabelledOpenMultiDiGraph const &) = default; - NodeLabel const &at(Node const &n) const { - return nl->get_label(n); - } - NodeLabel &at(Node const &n) { - return nl->get_label(n); + return this->get_ptr().at(n); } std::unordered_set query_nodes(NodeQuery const &q) const { - return get_ptr().query_nodes(q); + return this->get_ptr().query_nodes(q); } std::unordered_set query_edges(OpenMultiDiEdge const &q) const { - return get_ptr().query_edges(q); + return this->get_ptr().query_edges(q); } Node add_node(NodeLabel const &l) { - Node n = get_ptr().add_node(); - nl.get_mutable()->add_label(n, l); - return n; + return this->get_ptr().add_node(l); } NodePort add_node_port() { - return get_ptr().add_node_port(); + return this->get_ptr().add_node_port(); } void add_edge(OpenMultiDiEdge const &e) { - return get_ptr().add_edge(e); + return this->get_ptr().add_edge(e); } - template - static 
typename std::enable_if< - std::conjunction, - std::is_base_of>::value, - NodeLabelledOpenMultiDiGraph>::type + using NodeLabelledOpenMultiDiGraphView::at; + + template + static typename std::enable_if::value, + NodeLabelledOpenMultiDiGraph>::type create() { - return NodeLabelledOpenMultiDiGraph(make_cow_ptr(), - make_cow_ptr()); + return NodeLabelledOpenMultiDiGraph(make_cow_ptr()); } private: - NodeLabelledOpenMultiDiGraph(cow_ptr_t ptr, - cow_ptr_t nl) - : GraphView(ptr), nl(nl) {} + NodeLabelledOpenMultiDiGraph(cow_ptr_t ptr) : GraphView(ptr) {} Interface &get_ptr() { return *std::reinterpret_pointer_cast( @@ -128,8 +129,6 @@ struct NodeLabelledOpenMultiDiGraph return *std::reinterpret_pointer_cast( GraphView::ptr.get()); } - - cow_ptr_t nl; }; } // namespace FlexFlow diff --git a/lib/utils/include/utils/graph/labelled/output_labelled.h b/lib/utils/include/utils/graph/labelled/output_labelled.h index 9c65db4daa..c6c521c38b 100644 --- a/lib/utils/include/utils/graph/labelled/output_labelled.h +++ b/lib/utils/include/utils/graph/labelled/output_labelled.h @@ -1,24 +1,11 @@ #ifndef _FLEXFLOW_UTILS_INCLUDE_UTILS_GRAPH_LABELLED_OUTPUT_LABELLED_H #define _FLEXFLOW_UTILS_INCLUDE_UTILS_GRAPH_LABELLED_OUTPUT_LABELLED_H -#include "standard_labelled.h" +#include "node_labelled.h" +#include "output_labelled_interfaces.h" namespace FlexFlow { -template -struct IOutputLabelledMultiDiGraphView - : public INodeLabelledMultiDiGraphView { - IOutputLabelledMultiDiGraphView() = default; - IOutputLabelledMultiDiGraphView(IOutputLabelledMultiDiGraphView const &) = - delete; - IOutputLabelledMultiDiGraphView & - operator=(IOutputLabelledMultiDiGraphView const &) = delete; - - virtual OutputLabel const &at(MultiDiOutput const &) const = 0; - using INodeLabelledMultiDiGraphView::at; -}; -CHECK_RC_COPY_VIRTUAL_COMPLIANT(IOutputLabelledMultiDiGraphView); - template struct OutputLabelledMultiDiGraphView : virtual public NodeLabelledMultiDiGraphView { @@ -32,19 +19,19 @@ struct 
OutputLabelledMultiDiGraphView operator=(OutputLabelledMultiDiGraphView const &) = default; NodeLabel const &at(Node const &n) const { - return get_ptr().at(n); + return this->get_ptr().at(n); } OutputLabel const &at(MultiDiOutput const &o) const { - return get_ptr().at(o); + return this->get_ptr().at(o); } std::unordered_set query_nodes(NodeQuery const &q) const { - return get_ptr().query_nodes(q); + return this->get_ptr().query_nodes(q); } std::unordered_set query_edges(MultiDiEdgeQuery const &q) const { - return get_ptr().query_edges(q); + return this->get_ptr().query_edges(q); } template @@ -69,9 +56,7 @@ template struct OutputLabelledMultiDiGraph : virtual OutputLabelledMultiDiGraphView { private: - using Interface = IMultiDiGraph; - using INodeLabel = ILabelling; - using IOutputLabel = ILabelling; + using Interface = IOutputLabelledMultiDiGraph; public: OutputLabelledMultiDiGraph(OutputLabelledMultiDiGraph const &other) = default; @@ -79,67 +64,58 @@ struct OutputLabelledMultiDiGraph operator=(OutputLabelledMultiDiGraph const &other) = default; Node add_node(NodeLabel const &l) { - Node n = get_ptr().add_node(); - nl.get_mutable()->add_label(n, l); - return n; + return this->get_ptr().add_node(l); } NodePort add_node_port() { - return get_ptr().add_node_port(); + return this->get_ptr().add_node_port(); } NodeLabel &at(Node const &n) { - return nl.get_mutable()->get_label(n); + return this->get_ptr().at(n); } NodeLabel const &at(Node const &n) const { - return nl->get_label(n); + return this->get_ptr().at(n); } void add_output(MultiDiOutput const &o, OutputLabel const &l) { - ol.get_mutable()->add_label(o, l); + this->get_ptr().add_output(o, l); }; void add_edge(MultiDiOutput const &o, MultiDiInput const &i) { - return get_ptr().add_edge(o, i); + this->get_ptr().add_edge(o, i); }; void add_edge(MultiDiEdge const &e) { - return get_ptr().add_edge(e); + this->get_ptr().add_edge(e); } OutputLabel &at(MultiDiOutput const &o) { - return 
ol.get_mutable()->get_label(o); + return this->get_ptr().at(o); } OutputLabel const &at(MultiDiOutput const &o) const { - return ol->get_label(o); + return this->get_ptr().at(o); } std::unordered_set query_nodes(NodeQuery const &q) const { - return get_ptr().query_nodes(q); + return this->get_ptr().query_nodes(q); } std::unordered_set query_edges(MultiDiEdgeQuery const &q) const { - return get_ptr().query_edges(q); + return this->get_ptr().query_edges(q); } - template - static typename std::enable_if< - std::conjunction, - std::is_base_of, - std::is_base_of>::value, - OutputLabelledMultiDiGraph>::type + template + static typename std::enable_if::value, + OutputLabelledMultiDiGraph>::type create() { - return OutputLabelledMultiDiGraph( - make_cow_ptr(), make_cow_ptr(), make_cow_ptr()); + return OutputLabelledMultiDiGraph(make_cow_ptr()); } private: - OutputLabelledMultiDiGraph(cow_ptr_t ptr, - cow_ptr_t nl, - cow_ptr_t ol) - : GraphView(ptr), nl(nl), ol(ol) {} + OutputLabelledMultiDiGraph(cow_ptr_t ptr) : GraphView(ptr) {} private: Interface &get_ptr() { @@ -151,9 +127,6 @@ struct OutputLabelledMultiDiGraph return *std::reinterpret_pointer_cast( GraphView::ptr.get()); } - - cow_ptr_t nl; - cow_ptr_t ol; }; template struct IOutputLabelledMultiDiGraphView : public INodeLabelledMultiDiGraphView { - virtual OutputLabel &at(MultiDiOutput const &) = 0; + virtual OutputLabel const &at(MultiDiOutput const &) const = 0; + + using INodeLabelledMultiDiGraphView::at; }; CHECK_RC_COPY_VIRTUAL_COMPLIANT(IOutputLabelledMultiDiGraphView); template struct IOutputLabelledMultiDiGraph - : public IOutputLabelledMultiDiGraphView { + : public IOutputLabelledMultiDiGraphView, + public INodeLabelledMultiDiGraph { public: virtual IOutputLabelledMultiDiGraph *clone() const = 0; virtual void add_output(MultiDiOutput const &output, OutputLabel const &label) = 0; - virtual void add_edge(MultiDiOutput const &output, - MultiDiInput const &input) = 0; - virtual NodePort add_node_ports() = 0; + 
virtual NodePort add_node_port() = 0; virtual NodeLabel &at(Node const &) = 0; virtual NodeLabel const &at(Node const &) const = 0; + virtual OutputLabel &at(MultiDiOutput const &) = 0; virtual OutputLabel const &at(MultiDiOutput const &) const = 0; }; CHECK_RC_COPY_VIRTUAL_COMPLIANT(IOutputLabelledMultiDiGraph); diff --git a/lib/utils/include/utils/graph/labelled/output_labelled_open.h b/lib/utils/include/utils/graph/labelled/output_labelled_open.h index 23dd9c190c..24235bee4c 100644 --- a/lib/utils/include/utils/graph/labelled/output_labelled_open.h +++ b/lib/utils/include/utils/graph/labelled/output_labelled_open.h @@ -2,19 +2,10 @@ #define _FLEXFLOW_UTILS_GRAPH_LABELLED_OUTPUT_LABELLED_OPEN #include "node_labelled_open.h" -#include "utils/graph/adjacency_openmultidigraph.h" +#include "output_labelled_open_interfaces.h" namespace FlexFlow { -template -struct IOutputLabelledOpenMultiDiGraphView - : virtual INodeLabelledOpenMultiDiGraphView { - virtual EdgeLabel const &at(InputMultiDiEdge const &) const = 0; - virtual EdgeLabel const &at(MultiDiOutput const &) const = 0; - - using INodeLabelledOpenMultiDiGraphView::at; -}; - template struct OutputLabelledOpenMultiDiGraphView : virtual NodeLabelledOpenMultiDiGraphView, @@ -29,15 +20,15 @@ struct OutputLabelledOpenMultiDiGraphView operator=(OutputLabelledOpenMultiDiGraphView const &) = default; NodeLabel const &at(Node const &n) const { - return get_ptr().at(n); + return this->get_ptr().at(n); } EdgeLabel const &at(InputMultiDiEdge const &i) const { - return get_ptr().at(i); + return this->get_ptr().at(i); } EdgeLabel const &at(MultiDiOutput const &o) const { - return get_ptr().at(o); + return this->get_ptr().at(o); } template @@ -51,12 +42,12 @@ struct OutputLabelledOpenMultiDiGraphView } std::unordered_set query_nodes(NodeQuery const &q) const { - return get_ptr().query_nodes(q); + return this->get_ptr().query_nodes(q); } std::unordered_set query_edges(OpenMultiDiEdgeQuery const &q) const { - return 
get_ptr().query_edges(q); + return this->get_ptr().query_edges(q); } template @@ -82,10 +73,7 @@ template struct OutputLabelledOpenMultiDiGraph : virtual OutputLabelledOpenMultiDiGraphView { private: - using Interface = IOpenMultiDiGraph; - using INodeLabel = ILabelling; - using IInputLabel = ILabelling; - using IOutputLabel = ILabelling; + using Interface = IOutputLabelledOpenMultiDiGraph; public: OutputLabelledOpenMultiDiGraph() = delete; @@ -95,14 +83,7 @@ struct OutputLabelledOpenMultiDiGraph operator=(OutputLabelledOpenMultiDiGraph const &) = default; Node add_node(NodeLabel const &l) { - Node n = get_ptr().add_node(); - this->node_labelling.get_mutable()->add_label(n, l); - return n; - } - - void add_node_unsafe(Node const &n, NodeLabel const &l) { - this->get_ptr().add_node_unsafe(n); - this->node_labelling.get_mutable()->add_label(n, l); + return this->get_ptr().add_node(l); } NodePort add_node_port() { @@ -110,19 +91,15 @@ struct OutputLabelledOpenMultiDiGraph } NodeLabel &at(Node const &n) { - return this->node_labelling.get_mutable()->get_label(n); - } - - NodeLabel const &at(Node const &n) const { - return this->node_labelling->get_label(n); + return this->get_ptr().at(n); } void add_label(MultiDiOutput const &o, EdgeLabel const &l) { - this->output_labelling.get_mutable()->add_label(o, l); + this->get_ptr().add_label(o, l); }; void add_label(InputMultiDiEdge const &e, EdgeLabel const &l) { - this->input_labelling.get_mutable()->add_label(e, l); + this->get_ptr().add_label(e, l); } void add_edge(OpenMultiDiEdge const &e) { @@ -130,18 +107,11 @@ struct OutputLabelledOpenMultiDiGraph } EdgeLabel &at(MultiDiOutput const &o) { - return this->output_labelling.get_mutable()->get_label(o); - } - EdgeLabel const &at(MultiDiOutput const &o) const { - return this->output_labelling->get_label(o); + return this->get_ptr().at(o); } EdgeLabel &at(InputMultiDiEdge const &e) { - return this->input_labelling.get_mutable()->get_label(e); - } - - EdgeLabel const 
&at(InputMultiDiEdge const &e) const { - return this->input_labelling->get_label(e); + return this->get_ptr().at(e); } template @@ -155,34 +125,24 @@ struct OutputLabelledOpenMultiDiGraph } std::unordered_set query_nodes(NodeQuery const &q) const { - return get_ptr().query_nodes(q); + return this->get_ptr().query_nodes(q); } std::unordered_set query_edges(OpenMultiDiEdgeQuery const &q) const { - return get_ptr().query_edges(q); + return this->get_ptr().query_edges(q); } - template - static typename std::enable_if< - std::conjunction, - std::is_base_of, - std::is_base_of, - std::is_base_of>::value, - OutputLabelledOpenMultiDiGraph>::type + template + static typename std::enable_if::value, + OutputLabelledOpenMultiDiGraph>::type create() { - return OutputLabelledOpenMultiDiGraph(make_cow_ptr(), - make_cow_ptr(), - make_cow_ptr(), - make_cow_ptr()); + return OutputLabelledOpenMultiDiGraph(make_cow_ptr()); } + using OutputLabelledOpenMultiDiGraphView::at; + private: - OutputLabelledOpenMultiDiGraph(cow_ptr_t ptr, - cow_ptr_t nl, - cow_ptr_t il, - cow_ptr_t ol) - : GraphView(ptr), node_labelling(nl), input_labelling(il), - output_labelling(ol) {} + OutputLabelledOpenMultiDiGraph(cow_ptr_t ptr) : GraphView(ptr) {} Interface &get_ptr() { return *std::reinterpret_pointer_cast( @@ -193,10 +153,6 @@ struct OutputLabelledOpenMultiDiGraph return *std::reinterpret_pointer_cast( GraphView::ptr.get()); } - - cow_ptr_t node_labelling; - cow_ptr_t input_labelling; - cow_ptr_t output_labelling; }; template diff --git a/lib/utils/include/utils/graph/labelled/output_labelled_open_interfaces.h b/lib/utils/include/utils/graph/labelled/output_labelled_open_interfaces.h new file mode 100644 index 0000000000..501805fe2a --- /dev/null +++ b/lib/utils/include/utils/graph/labelled/output_labelled_open_interfaces.h @@ -0,0 +1,34 @@ +#ifndef _FLEXFLOW_UTILS_GRAPH_LABELLED_OUTPUT_LABELLED_OPEN_INTERFACES +#define _FLEXFLOW_UTILS_GRAPH_LABELLED_OUTPUT_LABELLED_OPEN_INTERFACES + +#include 
"node_labelled_open.h" + +namespace FlexFlow { + +template +struct IOutputLabelledOpenMultiDiGraphView + : virtual INodeLabelledOpenMultiDiGraphView { + virtual EdgeLabel const &at(InputMultiDiEdge const &) const = 0; + virtual EdgeLabel const &at(MultiDiOutput const &) const = 0; + + using INodeLabelledOpenMultiDiGraphView::at; +}; + +template +struct IOutputLabelledOpenMultiDiGraph + : virtual public IOutputLabelledOpenMultiDiGraphView { + virtual EdgeLabel &at(InputMultiDiEdge const &) = 0; + virtual EdgeLabel &at(MultiDiOutput const &) = 0; + virtual Node add_node(NodeLabel const &) = 0; + virtual NodePort add_node_port() = 0; + virtual NodeLabel &at(Node const &) = 0; + virtual void add_label(MultiDiOutput const &o, EdgeLabel const &l) = 0; + virtual void add_label(InputMultiDiEdge const &e, EdgeLabel const &l) = 0; + virtual void add_edge(OpenMultiDiEdge const &e) = 0; + + using IOutputLabelledOpenMultiDiGraphView::at; +}; + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/labelled/standard_labelled.h b/lib/utils/include/utils/graph/labelled/standard_labelled.h index 3c69d62ae9..e1c8e91634 100644 --- a/lib/utils/include/utils/graph/labelled/standard_labelled.h +++ b/lib/utils/include/utils/graph/labelled/standard_labelled.h @@ -2,23 +2,10 @@ #define _FLEXFLOW_UTILS_INCLUDE_UTILS_GRAPH_LABELLED_STANDARD_LABELLED_H #include "node_labelled.h" +#include "standard_labelled_interfaces.h" namespace FlexFlow { -template -struct ILabelledMultiDiGraphView - : public INodeLabelledMultiDiGraphView { - ILabelledMultiDiGraphView() = default; - ILabelledMultiDiGraphView(ILabelledMultiDiGraphView const &) = delete; - ILabelledMultiDiGraphView & - operator=(ILabelledMultiDiGraphView const &) = delete; - - virtual ~ILabelledMultiDiGraphView() = default; - - virtual EdgeLabel const &at(MultiDiEdge const &) const = 0; -}; -CHECK_RC_COPY_VIRTUAL_COMPLIANT(ILabelledMultiDiGraphView); - template struct LabelledMultiDiGraphView : virtual public 
NodeLabelledMultiDiGraphView { @@ -70,19 +57,14 @@ template struct LabelledMultiDiGraph : virtual LabelledMultiDiGraphView { private: - using Interface = IMultiDiGraph; - using INodeLabel = ILabelling; - using IEdgeLabel = ILabelling; + using Interface = ILabelledMultiDiGraph; public: - // LabelledMultiDiGraph() = delete; LabelledMultiDiGraph(LabelledMultiDiGraph const &other) = default; LabelledMultiDiGraph &operator=(LabelledMultiDiGraph const &other) = default; Node add_node(NodeLabel const &l) { - Node n = MultiDiGraph::add_node(); - nl->add_label(n, l); - return n; + return this->get_ptr().add_node(); } NodePort add_node_port() { @@ -90,46 +72,36 @@ struct LabelledMultiDiGraph } NodeLabel &at(Node const &n) { - return nl.get_mutable()->get_label(n); - } - - NodeLabel const &at(Node const &n) const { - return nl->get_label(n); + return this->get_ptr().at(n); } void add_edge(MultiDiEdge const &e, EdgeLabel const &l) { return this->get_ptr().add_edge(e, l); } + EdgeLabel &at(MultiDiEdge const &e) { - return el.get_mutable()->get_label(e); - } - EdgeLabel const &at(MultiDiEdge const &e) const { - return el->get_label(e); + return this->get_ptr().at(e); } std::unordered_set query_nodes(NodeQuery const &q) const { return this->get_ptr().query_nodes(q); } + std::unordered_set query_edges(MultiDiEdgeQuery const &q) const { return this->get_ptr().query_edges(q); } - template - static typename std::enable_if< - std::conjunction, - std::is_base_of, - std::is_base_of>::value, - LabelledMultiDiGraph>::type + using LabelledMultiDiGraphView::at; + + template + static typename std::enable_if::value, + LabelledMultiDiGraph>::type create() { - return LabelledMultiDiGraph( - make_cow_ptr(), make_cow_ptr(), make_cow_ptr()); + return LabelledMultiDiGraph(make_cow_ptr()); } private: - LabelledMultiDiGraph(cow_ptr_t ptr, - cow_ptr_t nl, - cow_ptr_t el) - : LabelledMultiDiGraphView(ptr), nl(nl), el(el) {} + LabelledMultiDiGraph(cow_ptr_t ptr) : GraphView(ptr) {} Interface &get_ptr() 
{ return *std::reinterpret_pointer_cast( @@ -140,9 +112,6 @@ struct LabelledMultiDiGraph return *std::reinterpret_pointer_cast( GraphView::ptr.get()); } - - cow_ptr_t nl; - cow_ptr_t el; }; CHECK_WELL_BEHAVED_VALUE_TYPE_NO_EQ(LabelledMultiDiGraph); diff --git a/lib/utils/include/utils/graph/labelled/unordered_labelled_graphs.h b/lib/utils/include/utils/graph/labelled/unordered_labelled_graphs.h index f7af522b3c..fe396e5989 100644 --- a/lib/utils/include/utils/graph/labelled/unordered_labelled_graphs.h +++ b/lib/utils/include/utils/graph/labelled/unordered_labelled_graphs.h @@ -1,138 +1,227 @@ #ifndef _FLEXFLOW_UTILS_INCLUDE_UTILS_GRAPH_UNORDERED_LABELLED_GRAPHS_H #define _FLEXFLOW_UTILS_INCLUDE_UTILS_GRAPH_UNORDERED_LABELLED_GRAPHS_H -#include "labelled_open_interfaces.h" -#include "node_labelled_interfaces.h" -#include "output_labelled_interfaces.h" -#include "standard_labelled_interfaces.h" -#include "utils/graph/open_graphs.h" +#include "output_labelled_open_interfaces.h" +#include "unordered_label.h" +#include "utils/graph/adjacency_openmultidigraph.h" namespace FlexFlow { template -struct UnorderedNodeLabelledMultiDiGraph - : public INodeLabelledMultiDiGraph, - protected MultiDiGraph { -public: - UnorderedNodeLabelledMultiDiGraph() = delete; +struct UnorderedNodeLabelledOpenMultiDiGraph + : public INodeLabelledOpenMultiDiGraph { - Node add_node(NodeLabel const &label) override { - Node n = MultiDiGraph::add_node(); - node_map.insert({n, label}); - return n; + UnorderedNodeLabelledOpenMultiDiGraph() + : g(OpenMultiDiGraph::create()) {} + + Node add_node(NodeLabel const &l) override { + Node node = g.add_node(); + this->node_labelling.add_label(node, l); + return node; } - NodeLabel &at(Node const &n) override { - return this->node_map.at(n); + NodePort add_node_port() override { + return this->g.add_node_port(); } NodeLabel const &at(Node const &n) const override { - return this->node_map.at(n); + return this->node_labelling.get_label(n); } - using 
MultiDiGraph::query_edges; - using MultiDiGraph::query_nodes; + NodeLabel &at(Node const &n) override { + return this->node_labelling.get_label(n); + } -private: - std::unordered_map node_map; -}; + void add_edge(OpenMultiDiEdge const &e) override { + this->g.add_edge(e); + } -template -struct UnorderedLabelledMultiDiGraph - : public ILabelledMultiDiGraph, - public UnorderedNodeLabelledMultiDiGraph { - void add_edge(MultiDiEdge const &e, EdgeLabel const &label) override { - MultiDiGraph::add_edge(e); - edge_map.insert({e, label}); + std::unordered_set query_nodes(NodeQuery const &q) const override { + return g.query_nodes(q); } - EdgeLabel &at(MultiDiEdge const &n) override { - return this->edge_map.at(n); + std::unordered_set + query_edges(OpenMultiDiEdgeQuery const &q) const override { + return g.query_edges(q); } - EdgeLabel const &at(MultiDiEdge const &n) const override { - return this->edge_map.at(n); + using INodeLabelledOpenMultiDiGraph::query_edges; + + UnorderedNodeLabelledOpenMultiDiGraph *clone() const override { + return new UnorderedNodeLabelledOpenMultiDiGraph(g, + node_labelling); } private: - std::unordered_map edge_map; -}; + UnorderedNodeLabelledOpenMultiDiGraph( + OpenMultiDiGraph const &g, + UnorderedLabelling const &node_labelling) + : g(g), node_labelling(node_labelling) {} -MultiDiOutput get_output(MultiDiEdge const &e); + OpenMultiDiGraph g; + UnorderedLabelling node_labelling; +}; +CHECK_NOT_ABSTRACT(UnorderedNodeLabelledOpenMultiDiGraph); template struct UnorderedOutputLabelledMultiDiGraph - : public IOutputLabelledMultiDiGraph, - public UnorderedNodeLabelledMultiDiGraph { -public: + : public IOutputLabelledMultiDiGraph { + + UnorderedOutputLabelledMultiDiGraph() + : g(MultiDiGraph::create()) {} + + OutputLabel const &at(MultiDiOutput const &i) const override { + return this->output_labelling.get_label(i); + } + + OutputLabel &at(MultiDiOutput const &i) override { + return this->output_labelling.get_label(i); + } + + Node 
add_node(NodeLabel const &l) override { + Node node = g.add_node(); + this->node_labelling.add_label(node, l); + return node; + } + + NodePort add_node_port() override { + return this->g.add_node_port(); + } + + NodeLabel const &at(Node const &n) const override { + return this->node_labelling.get_label(n); + } + + NodeLabel &at(Node const &n) override { + return this->node_labelling.get_label(n); + } + + void add_edge(MultiDiEdge const &e) override { + this->g.add_edge(e); + } + void add_output(MultiDiOutput const &output, OutputLabel const &label) override { - this->output_map.insert({output, label}); + this->output_labelling.add_label(output, label); } - void add_edge(MultiDiEdge const &e) override { - MultiDiOutput output = get_output(e); - if (!contains_key(this->output_map, output)) { - throw mk_runtime_error("Could not find output {}", output); - } - this->add_edge(e); + std::unordered_set query_nodes(NodeQuery const &q) const override { + return g.query_nodes(q); } - void add_edge(MultiDiOutput const &output, - MultiDiInput const &input) override { - this->add_edge(MultiDiEdge{output.node, input.node, output.idx, input.idx}); + std::unordered_set + query_edges(MultiDiEdgeQuery const &q) const override { + return g.query_edges(q); + } + + using IOutputLabelledMultiDiGraph::query_edges; + + UnorderedOutputLabelledMultiDiGraph *clone() const override { + return new UnorderedOutputLabelledMultiDiGraph( + g, node_labelling, output_labelling); } private: - std::unordered_map output_map; + UnorderedOutputLabelledMultiDiGraph( + MultiDiGraph const &g, + UnorderedLabelling const &node_labelling, + UnorderedLabelling const &output_labelling) + : g(g), node_labelling(node_labelling), + output_labelling(output_labelling) {} + + MultiDiGraph g; + UnorderedLabelling node_labelling; + UnorderedLabelling output_labelling; }; +CHECK_NOT_ABSTRACT(UnorderedOutputLabelledMultiDiGraph); -template -struct UnorderedLabelledOpenMultiDiGraph - : public ILabelledOpenMultiDiGraph, - 
public UnorderedLabelledMultiDiGraph { -public: - void add_edge(InputMultiDiEdge const &e, InputLabel const &label) { - this->add_edge(e); - this->input_map.insert({e, label}); +template +struct UnorderedOutputLabelledOpenMultiDiGraph + : public IOutputLabelledOpenMultiDiGraph { + + UnorderedOutputLabelledOpenMultiDiGraph() + : g(OpenMultiDiGraph::create()) {} + + EdgeLabel const &at(InputMultiDiEdge const &i) const override { + return this->input_labelling.get_label(i); } - void add_edge(OutputMultiDiEdge const &e, OutputLabel const &label) { - this->add_edge(e); - this->output_map.insert({e, label}); + EdgeLabel &at(InputMultiDiEdge const &i) override { + return this->input_labelling.get_label(i); } - InputLabel const &at(InputMultiDiEdge const &e) const { - return this->input_map.at(e); + EdgeLabel const &at(MultiDiOutput const &i) const override { + return this->output_labelling.get_label(i); } - InputLabel &at(InputMultiDiEdge const &e) { - return this->input_map.at(e); + EdgeLabel &at(MultiDiOutput const &i) override { + return this->output_labelling.get_label(i); } - OutputLabel const &at(OutputMultiDiEdge const &e) const { - return this->output_map.at(e); + Node add_node(NodeLabel const &l) override { + Node node = g.add_node(); + this->node_labelling.add_label(node, l); + return node; } - OutputLabel &at(DownwardOpenMultiDiEdge const &e) { - return this->output_map.at(e); + NodePort add_node_port() override { + return this->g.add_node_port(); } - UnorderedLabelledOpenMultiDiGraph() { - NOT_IMPLEMENTED(); + NodeLabel const &at(Node const &n) const override { + return this->node_labelling.get_label(n); + } + + NodeLabel &at(Node const &n) override { + return this->node_labelling.get_label(n); + } + + void add_label(MultiDiOutput const &o, EdgeLabel const &l) override { + this->output_labelling.add_label(o, l); + } + + void add_label(InputMultiDiEdge const &i, EdgeLabel const &l) override { + this->input_labelling.add_label(i, l); + } + + void 
add_edge(OpenMultiDiEdge const &e) override { + this->g.add_edge(e); + } + + std::unordered_set query_nodes(NodeQuery const &q) const override { + return this->g.query_nodes(q); + } + + std::unordered_set + query_edges(OpenMultiDiEdgeQuery const &q) const override { + return this->g.query_edges(q); + } + + using IOutputLabelledOpenMultiDiGraph::query_edges; + + UnorderedOutputLabelledOpenMultiDiGraph *clone() const override { + return new UnorderedOutputLabelledOpenMultiDiGraph( + g, node_labelling, input_labelling, output_labelling); } private: - OpenMultiDiGraph base_graph; - std::unordered_map input_map; - std::unordered_map output_map; + UnorderedOutputLabelledOpenMultiDiGraph( + OpenMultiDiGraph const &g, + UnorderedLabelling const &node_labelling, + UnorderedLabelling const &input_labelling, + UnorderedLabelling const &output_labelling) + : g(g), node_labelling(node_labelling), input_labelling(input_labelling), + output_labelling(output_labelling) {} + + OpenMultiDiGraph g; + UnorderedLabelling node_labelling; + UnorderedLabelling input_labelling; + UnorderedLabelling output_labelling; }; +CHECK_NOT_ABSTRACT( + UnorderedOutputLabelledOpenMultiDiGraph); } // namespace FlexFlow diff --git a/lib/utils/include/utils/graph/labelled/views.h b/lib/utils/include/utils/graph/labelled/views.h index 9c39dbf107..e31afad916 100644 --- a/lib/utils/include/utils/graph/labelled/views.h +++ b/lib/utils/include/utils/graph/labelled/views.h @@ -90,13 +90,13 @@ Impl materialize_output_labelled_multidigraph_view( } template + typename OutputLabelImpl> OutputLabelledOpenMultiDiGraph - materialize_output_labelled_open_multidigraph_view( + materialize_output_labelled_multidigraph_view( OutputLabelledOpenMultiDiGraphView const &g) { OutputLabelledOpenMultiDiGraph result = OutputLabelledOpenMultiDiGraph::template create< diff --git a/lib/utils/include/utils/graph/labelled_graphs.h b/lib/utils/include/utils/graph/labelled_graphs.h index 5c4b29038a..9cf5f0d97e 100644 --- 
a/lib/utils/include/utils/graph/labelled_graphs.h +++ b/lib/utils/include/utils/graph/labelled_graphs.h @@ -10,6 +10,7 @@ #include "labelled/output_labelled_open.h" #include "labelled/standard_labelled.h" #include "labelled/unordered_label.h" +#include "labelled/unordered_labelled_graphs.h" #include "labelled/views.h" #endif diff --git a/lib/utils/src/graph/open_graphs.cc b/lib/utils/src/graph/open_graphs.cc index e0bc94ca8c..c32ff6ded5 100644 --- a/lib/utils/src/graph/open_graphs.cc +++ b/lib/utils/src/graph/open_graphs.cc @@ -22,7 +22,7 @@ std::unordered_set } IOpenMultiDiGraphView const &OpenMultiDiGraphView::get_ptr() const { - return *std::dynamic_pointer_cast( + return *std::reinterpret_pointer_cast( GraphView::ptr.get()); } diff --git a/lib/utils/test/CMakeLists.txt b/lib/utils/test/CMakeLists.txt index be4b33129b..97253b4ab7 100644 --- a/lib/utils/test/CMakeLists.txt +++ b/lib/utils/test/CMakeLists.txt @@ -1,14 +1,14 @@ -# ff_add_test_executable( -# NAME -# utils-test -# SRC_PATTERNS -# src/*.cc -# PRIVATE_INCLUDE -# src/ -# DEPS -# utils -# doctest -# utils-test-common -# ) +ff_add_test_executable( + NAME + utils-test + SRC_PATTERNS + src/test_cow_ptr.cc + PRIVATE_INCLUDE + src/ + DEPS + utils + doctest + utils-test-common +) add_subdirectory(common) diff --git a/lib/utils/test/src/test_cow_ptr.cc b/lib/utils/test/src/test_cow_ptr.cc new file mode 100644 index 0000000000..ce8516f21b --- /dev/null +++ b/lib/utils/test/src/test_cow_ptr.cc @@ -0,0 +1,60 @@ +#include "test/utils/doctest.h" +#include "utils/graph/cow_ptr_t.h" +#include +#include +#include + +using namespace FlexFlow; + +struct TestObject { + TestObject(int x) : x(x) {} + int x; + virtual TestObject *clone() const { + return new TestObject(x); + } +}; + +struct TestObjectDerived : public TestObject { + TestObjectDerived(int x, int y) : TestObject(x), y(y) {} + int y; + TestObjectDerived *clone() const override { + return new TestObjectDerived(x, y); + } +}; + +TEST_CASE("cow_ptr_t constructor") 
{ + std::shared_ptr sp = std::make_shared(1); + cow_ptr_t p1(sp); + cow_ptr_t p2(std::make_shared(3)); + cow_ptr_t p3(TestObject(2)); + cow_ptr_t p4(p3); + cow_ptr_t p5 = p1; + CHECK(p1->x == 1); + CHECK(p2->x == 3); + CHECK(p3->x == 2); + CHECK(p4->x == p3->x); + CHECK(p5->x == p1->x); +} + +TEST_CASE("cow_ptr_t copy") { + cow_ptr_t p1(std::make_shared(1)); + cow_ptr_t p2(std::make_shared(2)); + p1 = p2; + CHECK(p1->x == p2->x); +} + +TEST_CASE("cow_ptr_t cast") { + cow_ptr_t p1(std::make_shared(1, 2)); + cow_ptr_t p2(p1); + CHECK(p2->x == 1); +} + +TEST_CASE("cow_ptr_t get_mutable") { + cow_ptr_t p1(std::make_shared(1)); + cow_ptr_t p2(p1); + p1.get_mutable()->x = 3; + CHECK(p1->x == 3); + CHECK(p2->x == 1); + p2.get_mutable()->x = 2; + CHECK(p1->x == 3); +} From c0015df306fca409d9b6b08edfdee548edae3a3c Mon Sep 17 00:00:00 2001 From: wmdi Date: Mon, 18 Mar 2024 15:52:02 -0400 Subject: [PATCH 18/37] fmt --- .../include/utils/graph/labelled/node_labelled_interfaces.h | 2 +- lib/utils/test/src/test_cow_ptr.cc | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/lib/utils/include/utils/graph/labelled/node_labelled_interfaces.h b/lib/utils/include/utils/graph/labelled/node_labelled_interfaces.h index 37fb4db715..c371a9a3bd 100644 --- a/lib/utils/include/utils/graph/labelled/node_labelled_interfaces.h +++ b/lib/utils/include/utils/graph/labelled/node_labelled_interfaces.h @@ -20,7 +20,7 @@ CHECK_RC_COPY_VIRTUAL_COMPLIANT(INodeLabelledMultiDiGraphView); template struct INodeLabelledMultiDiGraph - : virtual INodeLabelledMultiDiGraphView { + : virtual INodeLabelledMultiDiGraphView { virtual NodeLabel &at(Node const &) = 0; virtual Node add_node(NodeLabel const &l) = 0; virtual NodePort add_node_port() = 0; diff --git a/lib/utils/test/src/test_cow_ptr.cc b/lib/utils/test/src/test_cow_ptr.cc index ce8516f21b..62406bddec 100644 --- a/lib/utils/test/src/test_cow_ptr.cc +++ b/lib/utils/test/src/test_cow_ptr.cc @@ -16,7 +16,7 @@ struct TestObject { struct 
TestObjectDerived : public TestObject { TestObjectDerived(int x, int y) : TestObject(x), y(y) {} - int y; + int y; TestObjectDerived *clone() const override { return new TestObjectDerived(x, y); } From 102f5fb2ed3c0440ecb8288d0aa04789ea16f2b8 Mon Sep 17 00:00:00 2001 From: Colin Unger Date: Fri, 22 Mar 2024 13:54:22 -0700 Subject: [PATCH 19/37] Fix post-merge --- .flake/patches/doctest-template-test.patch | 50 ++++++ .flake/pkgs/fmt.nix | 73 ++++++++ .flake/pkgs/rapidcheck.nix | 48 ++++++ .github/workflows/helpers/build_libs.sh | 9 + .../helpers/{build_cuda.sh => cmake_cuda.sh} | 17 +- .github/workflows/helpers/test_libs.sh | 14 ++ .github/workflows/per-lib-check.yml | 38 ++--- CMakeLists.txt | 2 +- cmake/doctest.cmake | 9 - cmake/doctestlib.cmake | 11 ++ cmake/flexflow-utils.cmake | 4 +- cmake/fmt.cmake | 3 +- cmake/nccl.cmake | 1 + cmake/rapidcheck.cmake | 6 +- cmake/spdlog.cmake | 6 +- flake.nix | 79 ++++++--- lib/compiler/CMakeLists.txt | 3 +- lib/compiler/include/compiler/compiler.h | 4 +- .../include/compiler/machine_mapping.h | 2 +- lib/compiler/src/graph_utils.cc | 4 +- lib/compiler/src/machine_mapping.cc | 8 +- .../test/src/test_labelled_open_graph.cc | 2 + .../include/op-attrs/operator_attrs.h | 5 +- lib/pcg/include/pcg/device_id.h | 1 + lib/pcg/include/pcg/optimizer.h | 22 +-- .../include/substitutions/attribute_expr.h | 2 +- .../include/substitutions/get_attribute.h | 52 +++--- .../include/substitutions/operator_pattern.h | 6 +- .../include/substitutions/output_graph.h | 2 +- .../substitutions/parallel_tensor_pattern.h | 4 +- lib/substitutions/src/graph_pattern.cc | 88 +++++----- lib/substitutions/src/graph_pattern_match.cc | 24 +-- lib/substitutions/src/operator_attributes.cc | 110 ++++++------ lib/substitutions/src/substitution.cc | 156 +++++++++--------- lib/utils/include/utils/containers.decl.h | 8 +- lib/utils/include/utils/containers.h | 4 +- lib/utils/include/utils/dot_file.h | 7 +- .../graph/labelled/output_labelled_open.h | 4 +- 
lib/utils/include/utils/variant.h | 2 +- lib/utils/src/graph/open_edge.cc | 6 +- lib/utils/src/graph/serialparallel.cc | 6 +- lib/utils/test/src/test_variant.cc | 42 ++--- 42 files changed, 586 insertions(+), 358 deletions(-) create mode 100644 .flake/patches/doctest-template-test.patch create mode 100644 .flake/pkgs/fmt.nix create mode 100644 .flake/pkgs/rapidcheck.nix create mode 100755 .github/workflows/helpers/build_libs.sh rename .github/workflows/helpers/{build_cuda.sh => cmake_cuda.sh} (67%) create mode 100755 .github/workflows/helpers/test_libs.sh delete mode 100644 cmake/doctest.cmake create mode 100644 cmake/doctestlib.cmake diff --git a/.flake/patches/doctest-template-test.patch b/.flake/patches/doctest-template-test.patch new file mode 100644 index 0000000000..ca4d0d9a18 --- /dev/null +++ b/.flake/patches/doctest-template-test.patch @@ -0,0 +1,50 @@ +diff --git a/scripts/cmake/doctestAddTests.cmake b/scripts/cmake/doctestAddTests.cmake +index 3b25485..d3ba906 100644 +--- a/scripts/cmake/doctestAddTests.cmake ++++ b/scripts/cmake/doctestAddTests.cmake +@@ -56,12 +56,14 @@ foreach(line ${output}) + if("${line}" STREQUAL "===============================================================================" OR "${line}" MATCHES [==[^\[doctest\] ]==]) + continue() + endif() +- set(test ${line}) ++ set(unescaped_test ${line}) ++ # use escape commas to handle properly test cases with commas inside the name ++ string(REPLACE "," "\\," escaped_test ${unescaped_test}) + set(labels "") + if(${add_labels}) + # get test suite that test belongs to + execute_process( +- COMMAND ${TEST_EXECUTOR} "${TEST_EXECUTABLE}" --test-case=${test} --list-test-suites ++ COMMAND ${TEST_EXECUTOR} "${TEST_EXECUTABLE}" --test-case=${escaped_test} --list-test-suites + OUTPUT_VARIABLE labeloutput + RESULT_VARIABLE labelresult + WORKING_DIRECTORY "${TEST_WORKING_DIR}" +@@ -85,24 +87,22 @@ foreach(line ${output}) + + if(NOT "${junit_output_dir}" STREQUAL "") + # turn testname into a valid 
filename by replacing all special characters with "-" +- string(REGEX REPLACE "[/\\:\"|<>]" "-" test_filename "${test}") ++ string(REGEX REPLACE "[/\\:\"|<>]" "-" test_filename "${unescaped_test}") + set(TEST_JUNIT_OUTPUT_PARAM "--reporters=junit" "--out=${junit_output_dir}/${prefix}${test_filename}${suffix}.xml") + else() + unset(TEST_JUNIT_OUTPUT_PARAM) + endif() +- # use escape commas to handle properly test cases with commas inside the name +- string(REPLACE "," "\\," test_name ${test}) + # ...and add to script + add_command(add_test +- "${prefix}${test}${suffix}" ++ "${prefix}${unescaped_test}${suffix}" + ${TEST_EXECUTOR} + "${TEST_EXECUTABLE}" +- "--test-case=${test_name}" ++ "--test-case=${escaped_test}" + "${TEST_JUNIT_OUTPUT_PARAM}" + ${extra_args} + ) + add_command(set_tests_properties +- "${prefix}${test}${suffix}" ++ "${prefix}${unescaped_test}${suffix}" + PROPERTIES + WORKING_DIRECTORY "${TEST_WORKING_DIR}" + ${properties} diff --git a/.flake/pkgs/fmt.nix b/.flake/pkgs/fmt.nix new file mode 100644 index 0000000000..e2677bdea2 --- /dev/null +++ b/.flake/pkgs/fmt.nix @@ -0,0 +1,73 @@ +{ lib +, stdenv +, fetchFromGitHub, fetchpatch +, cmake +, enableShared ? !stdenv.hostPlatform.isStatic + +# tests +, mpd +, openimageio +, fcitx5 +, spdlog +}: + +let + generic = { version, sha256, patches ? [ ] }: + stdenv.mkDerivation { + pname = "fmt"; + inherit version; + + outputs = [ "out" "dev" ]; + + src = fetchFromGitHub { + owner = "fmtlib"; + repo = "fmt"; + rev = version; + inherit sha256; + }; + + inherit patches; + + nativeBuildInputs = [ cmake ]; + + cmakeFlags = [ + "-DBUILD_SHARED_LIBS=${if enableShared then "ON" else "OFF"}" + ]; + + doCheck = true; + + passthru.tests = { + inherit mpd openimageio fcitx5 spdlog; + }; + + meta = with lib; { + description = "Small, safe and fast formatting library"; + longDescription = '' + fmt (formerly cppformat) is an open-source formatting library. It can be + used as a fast and safe alternative to printf and IOStreams. 
+ ''; + homepage = "https://fmt.dev/"; + changelog = "https://github.com/fmtlib/fmt/blob/${version}/ChangeLog.rst"; + downloadPage = "https://github.com/fmtlib/fmt/"; + maintainers = [ maintainers.jdehaas ]; + license = licenses.mit; + platforms = platforms.all; + }; + }; +in +{ + fmt_8 = generic { + version = "8.1.1"; + sha256 = "sha256-leb2800CwdZMJRWF5b1Y9ocK0jXpOX/nwo95icDf308="; + }; + + fmt_9 = generic { + version = "9.1.0"; + sha256 = "sha256-rP6ymyRc7LnKxUXwPpzhHOQvpJkpnRFOt2ctvUNlYI0="; + }; + + fmt_10 = generic { + version = "10.1.1"; + sha256 = "sha256-H9+1lEaHM12nzXSmo9m8S6527t+97e6necayyjCPm1A="; + }; +} diff --git a/.flake/pkgs/rapidcheck.nix b/.flake/pkgs/rapidcheck.nix new file mode 100644 index 0000000000..3ff63207b2 --- /dev/null +++ b/.flake/pkgs/rapidcheck.nix @@ -0,0 +1,48 @@ +{ lib +, stdenv +, fetchFromGitHub +, cmake +, unstableGitUpdater +, testers +}: + +stdenv.mkDerivation (finalAttrs: { + pname = "rapidcheck"; + version = "unstable-2023-12-14"; + + src = fetchFromGitHub { + owner = "emil-e"; + repo = "rapidcheck"; + rev = "ff6af6fc683159deb51c543b065eba14dfcf329b"; + hash = "sha256-Ixz5RpY0n8Un/Pv4XoTfbs40+70iyMbkQUjDqoLaWOg="; + }; + + nativeBuildInputs = [ cmake ]; + + cmakeFlags = [ + (lib.cmakeBool "BUILD_SHARED_LIBS" (!stdenv.hostPlatform.isStatic)) + (lib.cmakeBool "RC_INSTALL_ALL_EXTRAS" true) + ]; + + passthru = { + updateScript = unstableGitUpdater { }; + tests.pkg-config = testers.testMetaPkgConfig finalAttrs.finalPackage; + }; + + meta = with lib; { + description = "A C++ framework for property based testing inspired by QuickCheck"; + inherit (finalAttrs.src.meta) homepage; + maintainers = with maintainers; [ ]; + license = licenses.bsd2; + pkgConfigModules = [ + "rapidcheck" + # Extras + "rapidcheck_boost" + "rapidcheck_boost_test" + "rapidcheck_catch" + "rapidcheck_doctest" + "rapidcheck_gtest" + ]; + platforms = platforms.all; + }; +}) diff --git a/.github/workflows/helpers/build_libs.sh 
b/.github/workflows/helpers/build_libs.sh new file mode 100755 index 0000000000..cc4e25cc0b --- /dev/null +++ b/.github/workflows/helpers/build_libs.sh @@ -0,0 +1,9 @@ +#! /usr/bin/env bash + +set -euo pipefail + +DIR="$(realpath -- "$(dirname "${BASH_SOURCE[0]}")")" +REPO="$(realpath -- "$DIR/../../../")" + +cd "$REPO/build-ci" +make -j $(( $(nproc) < 2 ? 1 : $(nproc)-1 )) "$@" diff --git a/.github/workflows/helpers/build_cuda.sh b/.github/workflows/helpers/cmake_cuda.sh similarity index 67% rename from .github/workflows/helpers/build_cuda.sh rename to .github/workflows/helpers/cmake_cuda.sh index 3524f885a7..e549859a5a 100755 --- a/.github/workflows/helpers/build_cuda.sh +++ b/.github/workflows/helpers/cmake_cuda.sh @@ -8,22 +8,21 @@ REPO="$(realpath -- "$DIR/../../../")" export FF_GPU_BACKEND="cuda" export FF_CUDA_ARCH=70 -cd "$REPO" -mkdir build -cd build + +if [[ -d "$REPO/build-ci" ]]; then + rm -rf "$REPO/build-ci" +fi +mkdir "$REPO/build-ci" +cd "$REPO/build-ci" #if [[ "${FF_GPU_BACKEND}" == "cuda" ]]; then # export FF_BUILD_ALL_EXAMPLES=ON # export FF_BUILD_UNIT_TESTS=ON #fi +IFS=" " read -r -a FLAGS <<< "$CMAKE_FLAGS" ../config/config.linux \ - -DCMAKE_CXX_COMPILER="clang++" \ - -DCMAKE_C_COMPILER="clang" \ -DCMAKE_C_COMPILER_LAUNCHER=ccache \ -DCMAKE_CXX_COMPILER_LAUNCHER=ccache \ -DCMAKE_CUDA_COMPILER_LAUNCHER=ccache \ - -DFF_USE_EXTERNAL_LEGION=ON \ - -DFF_USE_EXTERNAL_JSON=ON \ - -DFF_USE_EXTERNAL_FMT=ON \ - -DFF_USE_EXTERNAL_SPDLOG=ON + "${FLAGS[@]}" # vim: set tabstop=2 shiftwidth=2 expandtab: diff --git a/.github/workflows/helpers/test_libs.sh b/.github/workflows/helpers/test_libs.sh new file mode 100755 index 0000000000..7662a7e601 --- /dev/null +++ b/.github/workflows/helpers/test_libs.sh @@ -0,0 +1,14 @@ +#! 
/usr/bin/env bash + +set -euo pipefail +set -x + +DIR="$(realpath -- "$(dirname "${BASH_SOURCE[0]}")")" +REPO="$(realpath -- "$DIR/../../../")" + +TEST_LIBS=("${@/%/-tests}") +REGEX="^$(IFS='|'; echo "${TEST_LIBS[*]}")\$" + +cd "$REPO/build-ci" +make -j $(( $(nproc) < 2 ? 1 : $(nproc)-1 )) "${TEST_LIBS[@]}" +ctest --progress --output-on-failure -L "$REGEX" diff --git a/.github/workflows/per-lib-check.yml b/.github/workflows/per-lib-check.yml index 4685983ce0..f1d069f252 100644 --- a/.github/workflows/per-lib-check.yml +++ b/.github/workflows/per-lib-check.yml @@ -20,6 +20,9 @@ jobs: with: submodules: recursive + - name: Add helpers directory to path + run: echo "${PWD}/.github/workflows/helpers" >> $GITHUB_PATH + - name: Install nix uses: cachix/install-nix-action@v25 with: @@ -51,49 +54,36 @@ jobs: - name: Run cmake run: | - .github/workflows/helpers/build_${{ matrix.gpu_backend }}.sh + cmake_${{ matrix.gpu_backend }}.sh - name: Build utils run: | - cd build - make -j $(( $(nproc) < 2 ? 1 : $(nproc)-1 )) utils + build_libs.sh utils - name: Build op-attrs run: | - cd build - make -j $(( $(nproc) < 2 ? 1 : $(nproc)-1 )) op-attrs + build_libs.sh op-attrs - name: Build pcg run: | - cd build - make -j $(( $(nproc) < 2 ? 1 : $(nproc)-1 )) pcg + build_libs.sh pcg - name: Build kernels run: | - cd build - make -j $(( $(nproc) < 2 ? 1 : $(nproc)-1 )) kernels + build_libs.sh kernels - name: Build substitutions run: | - cd build - make -j $(( $(nproc) < 2 ? 1 : $(nproc)-1 )) substitutions + build_libs.sh substitutions - name: Build compiler run: | - cd build - make -j $(( $(nproc) < 2 ? 1 : $(nproc)-1 )) compiler - - - name: Build substitutions-test - run: | - cd build - make -j $(( $(nproc) < 2 ? 1 : $(nproc)-1 )) substitutions-test + build_libs.sh compiler - - name: Build compiler-test + - name: Test substitutions run: | - cd build - make -j $(( $(nproc) < 2 ? 
1 : $(nproc)-1 )) compiler-test + test_libs.sh substitutions - - name: Unit tests + - name: Test compiler run: | - cd build - ctest + test_libs.sh compiler diff --git a/CMakeLists.txt b/CMakeLists.txt index e04aa622c2..032bf1ac55 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -84,7 +84,7 @@ include(nccl) include(json) include(expected) include(spdlog) -include(doctest) +include(doctestlib) # named doctestlib to avoid a name collision with doctest.cmake in rapidcheck include(visit_struct) include(CTest) include(fmt) diff --git a/cmake/doctest.cmake b/cmake/doctest.cmake deleted file mode 100644 index b2d5243574..0000000000 --- a/cmake/doctest.cmake +++ /dev/null @@ -1,9 +0,0 @@ -include(aliasing) - -add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/deps/doctest) -include(${CMAKE_CURRENT_SOURCE_DIR}/deps/doctest/scripts/cmake/doctest.cmake) - -add_library(doctest-ff INTERFACE) -target_compile_definitions(doctest-ff INTERFACE DOCTEST_CONFIG_NO_EXCEPTIONS_BUT_WITH_ALL_ASSERTS) -target_link_libraries(doctest-ff INTERFACE doctest::doctest) -alias_library(doctest doctest-ff) diff --git a/cmake/doctestlib.cmake b/cmake/doctestlib.cmake new file mode 100644 index 0000000000..5f29d94fd0 --- /dev/null +++ b/cmake/doctestlib.cmake @@ -0,0 +1,11 @@ +include(aliasing) + +if (FF_USE_EXTERNAL_DOCTEST) + find_package(doctest REQUIRED) + include(doctest) # import doctest_discover_tests +else() + add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/deps/doctest) + include(${CMAKE_CURRENT_SOURCE_DIR}/deps/doctest/scripts/cmake/doctest.cmake) +endif() + +alias_library(doctest doctest::doctest) diff --git a/cmake/flexflow-utils.cmake b/cmake/flexflow-utils.cmake index d41573acab..4cf5450942 100644 --- a/cmake/flexflow-utils.cmake +++ b/cmake/flexflow-utils.cmake @@ -118,7 +118,9 @@ function(ff_add_test_executable) ${FF_TEST_EXEC_NAME} ${FF_TEST_EXEC_DEPS}) + target_compile_definitions(${FF_TEST_EXEC_NAME} PRIVATE FF_TEST_SUITE="${FF_TEST_EXEC_NAME}") + define_ff_vars(${FF_TEST_EXEC_NAME}) 
ff_set_cxx_properties(${FF_TEST_EXEC_NAME}) - doctest_discover_tests(${FF_TEST_EXEC_NAME}) + doctest_discover_tests(${FF_TEST_EXEC_NAME} ADD_LABELS 1) endfunction() diff --git a/cmake/fmt.cmake b/cmake/fmt.cmake index 283caad69d..470de6a847 100644 --- a/cmake/fmt.cmake +++ b/cmake/fmt.cmake @@ -4,6 +4,5 @@ if (FF_USE_EXTERNAL_FMT) find_package(fmt REQUIRED) else() add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/deps/fmt) - - alias_library(fmt fmt::fmt) endif() +alias_library(fmt fmt::fmt) diff --git a/cmake/nccl.cmake b/cmake/nccl.cmake index e89bee04c6..755fe00f1b 100644 --- a/cmake/nccl.cmake +++ b/cmake/nccl.cmake @@ -8,6 +8,7 @@ else() message(STATUS "Building NCCL from source") list(TRANSFORM CUDA_GENCODE PREPEND "NVCC_GENCODE=" OUTPUT_VARIABLE NCCL_BUILD_NVCC_GENCODE) + include(ExternalProject) ExternalProject_Add(nccl_source_build SOURCE_DIR ${PROJECT_SOURCE_DIR}/deps/${NCCL_NAME} PREFIX ${CMAKE_BINARY_DIR}/deps/${NCCL_NAME} diff --git a/cmake/rapidcheck.cmake b/cmake/rapidcheck.cmake index 1ff64bd974..bf8f058e63 100644 --- a/cmake/rapidcheck.cmake +++ b/cmake/rapidcheck.cmake @@ -1 +1,5 @@ -add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/deps/rapidcheck) +if (FF_USE_EXTERNAL_RAPIDCHECK) + find_package(rapidcheck REQUIRED) +else() + add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/deps/rapidcheck) +endif() diff --git a/cmake/spdlog.cmake b/cmake/spdlog.cmake index cd18944460..02021fd51e 100644 --- a/cmake/spdlog.cmake +++ b/cmake/spdlog.cmake @@ -4,6 +4,8 @@ if (FF_USE_EXTERNAL_SPDLOG) find_package(spdlog REQUIRED) else() add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/deps/spdlog) - - alias_library(spdlog spdlog::spdlog) endif() + +add_library(spdlog INTERFACE) +target_link_libraries(spdlog INTERFACE spdlog::spdlog) +target_compile_definitions(spdlog INTERFACE SPDLOG_FMT_EXTERNAL) diff --git a/flake.nix b/flake.nix index 3d357ca86c..540d0f9a94 100644 --- a/flake.nix +++ b/flake.nix @@ -13,7 +13,6 @@ ]; }; - # Nixpkgs / NixOS version to use. 
inputs = { nixpkgs.url = "nixpkgs/nixos-23.11"; flake-utils.url = "github:numtide/flake-utils"; @@ -25,51 +24,84 @@ inherit system; config.allowUnfree = true; }; + lib = pkgs.lib; mkShell = pkgs.mkShell.override { - stdenv = pkgs.llvmPackages.libcxxStdenv; + stdenv = pkgs.cudaPackages.backendStdenv; }; in - { - packages = { - legion = pkgs.callPackage ./.flake/pkgs/legion.nix { }; + { + packages = { + legion = pkgs.callPackage ./.flake/pkgs/legion.nix { }; + rapidcheckFull = pkgs.symlinkJoin { + name = "rapidcheckFull"; + paths = (with pkgs; [ rapidcheck.out rapidcheck.dev ]); }; + doctest = pkgs.doctest.overrideAttrs ( old: rec { + version = "2.4.9"; + src = pkgs.fetchFromGitHub { + owner = "doctest"; + repo = "doctest"; + rev = "v${version}"; + sha256 = "sha256-ugmkeX2PN4xzxAZpWgswl4zd2u125Q/ADSKzqTfnd94="; + }; + patches = [ + ./.flake/patches/doctest-template-test.patch + ]; + }); + }; - devShells = rec { - ci = mkShell { - buildInputs = (with pkgs; [ - llvmPackages_17.clang - cmakeCurses - gcc10Stdenv - gcc10 - ccache - cudatoolkit + devShells = rec { + ci = mkShell { + CMAKE_FLAGS = lib.strings.concatStringsSep " " [ + "-DFF_USE_EXTERNAL_LEGION=ON" + "-DFF_USE_EXTERNAL_NCCL=ON" + "-DFF_USE_EXTERNAL_JSON=ON" + "-DFF_USE_EXTERNAL_FMT=ON" + "-DFF_USE_EXTERNAL_SPDLOG=ON" + "-DFF_USE_EXTERNAL_DOCTEST=ON" + "-DFF_USE_EXTERNAL_RAPIDCHECK=ON" + "-DFF_USE_EXTERNAL_RANGEV3=ON" + "-DFF_USE_EXTERNAL_BOOST_PREPROCESSOR=ON" + "-DFF_USE_EXTERNAL_TYPE_INDEX=ON" + ]; + + buildInputs = builtins.concatLists [ + (with pkgs; [ zlib - pkg-config - python3 - self.packages.${system}.legion + boost nlohmann_json spdlog range-v3 - rapidcheck - doctest fmt + cmakeCurses + ccache + pkg-config + python3 + cudatoolkit cudaPackages.cuda_nvcc cudaPackages.cudnn cudaPackages.nccl cudaPackages.libcublas cudaPackages.cuda_cudart - ]) ++ (with pkgs.python3Packages; [ - ]); + ]) + (with self.packages.${system}; [ + legion + rapidcheckFull + doctest + ]) + ]; }; default = mkShell { inputsFrom = [ 
ci ]; - + inherit (ci) CMAKE_FLAGS; + buildInputs = builtins.concatLists [ (with pkgs; [ - clang-tools_17 + ccls gh-markdown-preview + shellcheck plantuml gdb ruff @@ -96,4 +128,3 @@ } ); } -# vim: set tabstop=2 shiftwidth=2 expandtab: diff --git a/lib/compiler/CMakeLists.txt b/lib/compiler/CMakeLists.txt index 6610834eed..a2933efa50 100644 --- a/lib/compiler/CMakeLists.txt +++ b/lib/compiler/CMakeLists.txt @@ -11,11 +11,10 @@ ff_add_library( op-attrs utils json - optional pcg spdlog substitutions ) add_subdirectory(ffi) -add_subdirectory(test) \ No newline at end of file +add_subdirectory(test) diff --git a/lib/compiler/include/compiler/compiler.h b/lib/compiler/include/compiler/compiler.h index 3a75e3a9bf..a4f7b0ecd3 100644 --- a/lib/compiler/include/compiler/compiler.h +++ b/lib/compiler/include/compiler/compiler.h @@ -12,8 +12,8 @@ enum class SearchAlgorithm { DATA_PARALLEL, }; -using SearchAlgorithmConfig = variant<>; -using SearchSolution = variant<>; +using SearchAlgorithmConfig = std::variant<>; +using SearchSolution = std::variant<>; struct SearchResult { ParallelComputationGraph pcg; diff --git a/lib/compiler/include/compiler/machine_mapping.h b/lib/compiler/include/compiler/machine_mapping.h index 185f2706ef..8b21b9522f 100644 --- a/lib/compiler/include/compiler/machine_mapping.h +++ b/lib/compiler/include/compiler/machine_mapping.h @@ -53,7 +53,7 @@ class OptimalCostCache { public: OptimalCostCache() = default; - optional load(OptimalCostState const &) const; + std::optional load(OptimalCostState const &) const; void save(OptimalCostState const &, OptimalCostResult const &); private: diff --git a/lib/compiler/src/graph_utils.cc b/lib/compiler/src/graph_utils.cc index 3c6e44216b..5b76beb8c0 100644 --- a/lib/compiler/src/graph_utils.cc +++ b/lib/compiler/src/graph_utils.cc @@ -125,14 +125,14 @@ std::unordered_set get_nodes(SerialParallelDecomposition const &sp) { std::unordered_set get_nodes(Serial const &serial) { return set_union( - 
transform(serial.children, [](variant const child) { + transform(serial.children, [](std::variant const child) { return visit(GetNodes{}, child); })); } std::unordered_set get_nodes(Parallel const ¶llel) { return set_union( - transform(parallel.children, [](variant const child) { + transform(parallel.children, [](std::variant const child) { return visit(GetNodes{}, child); })); } diff --git a/lib/compiler/src/machine_mapping.cc b/lib/compiler/src/machine_mapping.cc index b48e200c15..2b08e9fe23 100644 --- a/lib/compiler/src/machine_mapping.cc +++ b/lib/compiler/src/machine_mapping.cc @@ -43,13 +43,13 @@ bool OptimalCostRuntimeCmp::operator()(OptimalCostResult const &lhs, return lhs.runtime < rhs.runtime; } -optional +std::optional OptimalCostCache::load(OptimalCostState const &state) const { if (contains_key(cache, state)) { OptimalCostResult result = cache.at(state); - return make_optional(result); + return std::make_optional(result); } - return nullopt; + return std::nullopt; } void OptimalCostCache::save(OptimalCostState const &state, @@ -152,7 +152,7 @@ struct MachineMappingSearcher { OptimalCostResult operator()(T const &t) { OptimalCostState state{ t, resource, given_machine_views, frontier_machine_views}; - optional cached_result = + std::optional cached_result = searcher->cached_subgraph_costs.load(state); if (cached_result) { diff --git a/lib/compiler/test/src/test_labelled_open_graph.cc b/lib/compiler/test/src/test_labelled_open_graph.cc index a3b6319528..dfe1f6301c 100644 --- a/lib/compiler/test/src/test_labelled_open_graph.cc +++ b/lib/compiler/test/src/test_labelled_open_graph.cc @@ -4,6 +4,7 @@ using namespace FlexFlow; +TEST_SUITE(FF_TEST_SUITE) { // TEST_CASE("get_subgraph(OpenMultiDiGraphView)") { // auto g = OpenMultiDiGraph::create(); @@ -142,3 +143,4 @@ TEST_CASE("OutputLabelledOpenMultiDiGraph") { // CHECK(get_edges(g).size() == 1); // } +} diff --git a/lib/op-attrs/include/op-attrs/operator_attrs.h 
b/lib/op-attrs/include/op-attrs/operator_attrs.h index 9da787cbf8..678a049c3b 100644 --- a/lib/op-attrs/include/op-attrs/operator_attrs.h +++ b/lib/op-attrs/include/op-attrs/operator_attrs.h @@ -32,6 +32,7 @@ #include "ops/topk.h" #include "ops/transpose.h" #include "utils/variant.h" +#include namespace FlexFlow { @@ -85,8 +86,8 @@ static_assert(is_valid_opattr::value, ""); static_assert(is_valid_opattr::value, ""); static_assert(is_valid_opattr::value, ""); -using ParallelOperatorAttrs = std:: - variant; +using ParallelOperatorAttrs = + std::variant; using ComputationGraphAttrs = variant_join>; diff --git a/lib/pcg/include/pcg/device_id.h b/lib/pcg/include/pcg/device_id.h index 50c2558e39..b118d69259 100644 --- a/lib/pcg/include/pcg/device_id.h +++ b/lib/pcg/include/pcg/device_id.h @@ -3,6 +3,7 @@ #include "device_type.h" #include "utils/strong_typedef.h" +#include namespace FlexFlow { diff --git a/lib/pcg/include/pcg/optimizer.h b/lib/pcg/include/pcg/optimizer.h index df5bddf729..0bb3fab974 100644 --- a/lib/pcg/include/pcg/optimizer.h +++ b/lib/pcg/include/pcg/optimizer.h @@ -7,21 +7,21 @@ namespace FlexFlow { struct SGDOptimizer { - req lr; - req momentum; - req nesterov; + double lr; + double momentum; + bool nesterov; req weight_decay; }; FF_VISITABLE_STRUCT(SGDOptimizer, lr, momentum, nesterov, weight_decay); struct AdamOptimizer { - req alpha; - req beta1; - req beta2; - req weight_decay; - req epsilon; - req alpha_t; - req beta_t; + double alpha; + double beta1; + double beta2; + double weight_decay; + double epsilon; + double alpha_t; + double beta_t; req beta2_t; }; FF_VISITABLE_STRUCT(AdamOptimizer, @@ -34,7 +34,7 @@ FF_VISITABLE_STRUCT(AdamOptimizer, beta_t, beta2_t); -using Optimizer = variant; +using Optimizer = std::variant; } // namespace FlexFlow diff --git a/lib/substitutions/include/substitutions/attribute_expr.h b/lib/substitutions/include/substitutions/attribute_expr.h index d6902d1274..0afd48b431 100644 --- 
a/lib/substitutions/include/substitutions/attribute_expr.h +++ b/lib/substitutions/include/substitutions/attribute_expr.h @@ -19,7 +19,7 @@ struct ListSize { }; template -using AttributeExpr = variant, ListSize>; +using AttributeExpr = std::variant, ListSize>; template struct AttributeConstraint { diff --git a/lib/substitutions/include/substitutions/get_attribute.h b/lib/substitutions/include/substitutions/get_attribute.h index 50c4108a67..7088730c53 100644 --- a/lib/substitutions/include/substitutions/get_attribute.h +++ b/lib/substitutions/include/substitutions/get_attribute.h @@ -7,57 +7,57 @@ namespace FlexFlow { -optional get_attribute(PCGOperatorAttrs const &, +std::optional get_attribute(PCGOperatorAttrs const &, OperatorAttributeKey); -optional get_attribute(BatchMatmulAttrs const &p, +std::optional get_attribute(BatchMatmulAttrs const &p, OperatorAttributeKey); -optional get_attribute(CastAttrs const &p, +std::optional get_attribute(CastAttrs const &p, OperatorAttributeKey); -optional get_attribute(CombineAttrs const &p, +std::optional get_attribute(CombineAttrs const &p, OperatorAttributeKey); -optional get_attribute(ConcatAttrs const &p, +std::optional get_attribute(ConcatAttrs const &p, OperatorAttributeKey); -optional get_attribute(Conv2DAttrs const &p, +std::optional get_attribute(Conv2DAttrs const &p, OperatorAttributeKey); -optional get_attribute(ElementBinaryAttrs const &p, +std::optional get_attribute(ElementBinaryAttrs const &p, OperatorAttributeKey); -optional get_attribute(ElementUnaryAttrs const &p, +std::optional get_attribute(ElementUnaryAttrs const &p, OperatorAttributeKey); -optional get_attribute(DropoutAttrs const &p, +std::optional get_attribute(DropoutAttrs const &p, OperatorAttributeKey); -optional get_attribute(ElementScalarUnaryAttrs const &p, +std::optional get_attribute(ElementScalarUnaryAttrs const &p, OperatorAttributeKey); -optional get_attribute(EmbeddingAttrs const &p, +std::optional get_attribute(EmbeddingAttrs const &p, 
OperatorAttributeKey); -optional get_attribute(FlatAttrs const &p, +std::optional get_attribute(FlatAttrs const &p, OperatorAttributeKey); -optional get_attribute(GatherAttrs const &p, +std::optional get_attribute(GatherAttrs const &p, OperatorAttributeKey); -optional get_attribute(LayerNormAttrs const &p, +std::optional get_attribute(LayerNormAttrs const &p, OperatorAttributeKey); -optional get_attribute(LinearAttrs const &p, +std::optional get_attribute(LinearAttrs const &p, OperatorAttributeKey); -optional get_attribute(MultiHeadAttentionAttrs const &p, +std::optional get_attribute(MultiHeadAttentionAttrs const &p, OperatorAttributeKey); -optional get_attribute(Pool2DAttrs const &p, +std::optional get_attribute(Pool2DAttrs const &p, OperatorAttributeKey); -optional get_attribute(ReduceAttrs const &p, +std::optional get_attribute(ReduceAttrs const &p, OperatorAttributeKey); -optional get_attribute(ReductionAttrs const &p, +std::optional get_attribute(ReductionAttrs const &p, OperatorAttributeKey); -optional get_attribute(RepartitionAttrs const &p, +std::optional get_attribute(RepartitionAttrs const &p, OperatorAttributeKey); -optional get_attribute(ReplicateAttrs const &p, +std::optional get_attribute(ReplicateAttrs const &p, OperatorAttributeKey); -optional get_attribute(ReshapeAttrs const &p, +std::optional get_attribute(ReshapeAttrs const &p, OperatorAttributeKey); -optional get_attribute(SplitAttrs const &p, +std::optional get_attribute(SplitAttrs const &p, OperatorAttributeKey); -optional get_attribute(SoftmaxAttrs const &p, +std::optional get_attribute(SoftmaxAttrs const &p, OperatorAttributeKey); -optional get_attribute(TopKAttrs const &p, +std::optional get_attribute(TopKAttrs const &p, OperatorAttributeKey); -optional get_attribute(TransposeAttrs const &p, +std::optional get_attribute(TransposeAttrs const &p, OperatorAttributeKey); // optional get_attribute(FusedParallelOpAttrs const &p, // OperatorAttributeKey); diff --git 
a/lib/substitutions/include/substitutions/operator_pattern.h b/lib/substitutions/include/substitutions/operator_pattern.h index 9392a7876e..35544f3003 100644 --- a/lib/substitutions/include/substitutions/operator_pattern.h +++ b/lib/substitutions/include/substitutions/operator_pattern.h @@ -70,7 +70,7 @@ enum class OperatorAttributeKey { NUM_INPUTS }; -using OperatorAttributeValue = variant, @@ -81,7 +81,7 @@ using OperatorAttributeValue = variant, - optional, + std::optional, PoolOp, TensorShape, DataType>; @@ -97,7 +97,7 @@ using OperatorAttributeConstraint = using OperatorPattern = AttributePattern; -optional +std::optional evaluate_attribute_expr(Operator const &attrs, AttributeExpr const &expr); diff --git a/lib/substitutions/include/substitutions/output_graph.h b/lib/substitutions/include/substitutions/output_graph.h index b9cf1f53f3..4ed90aed06 100644 --- a/lib/substitutions/include/substitutions/output_graph.h +++ b/lib/substitutions/include/substitutions/output_graph.h @@ -15,7 +15,7 @@ struct AttrConstant { OperatorAttributeValue value; }; -using OperatorAttributeExpr = variant; +using OperatorAttributeExpr = std::variant; // NOTE(@wmdi): Not sure if it aligns with other design. Or alternatively we can // define the assignment for each operator type. 
diff --git a/lib/substitutions/include/substitutions/parallel_tensor_pattern.h b/lib/substitutions/include/substitutions/parallel_tensor_pattern.h index d07a1da23b..741554142f 100644 --- a/lib/substitutions/include/substitutions/parallel_tensor_pattern.h +++ b/lib/substitutions/include/substitutions/parallel_tensor_pattern.h @@ -8,7 +8,7 @@ namespace FlexFlow { enum class TensorAttributeKey { DIM_SIZES, DIM_DEGREES }; -using TensorAttributeValue = variant>; +using TensorAttributeValue = std::variant>; using TensorAttributeConstraint = AttributeConstraint; @@ -16,7 +16,7 @@ using TensorAttributeConstraint = using ParallelTensorPattern = AttributePattern; -optional +std::optional evaluate_attribute_expr(ParallelTensor const &tensor_shape, AttributeExpr const &expr); diff --git a/lib/substitutions/src/graph_pattern.cc b/lib/substitutions/src/graph_pattern.cc index 1dba5c4af8..6f933dd300 100644 --- a/lib/substitutions/src/graph_pattern.cc +++ b/lib/substitutions/src/graph_pattern.cc @@ -9,51 +9,51 @@ namespace FlexFlow { -optional +std::optional evaluate_list_index_access(int index, - optional const &v) { + std::optional const &v) { if (!v.has_value() || - !holds_alternative>(v.value()) || - !holds_alternative>(v.value())) { - return nullopt; + !std::holds_alternative>(v.value()) || + !std::holds_alternative>(v.value())) { + return std::nullopt; } if (index >= MAX_TENSOR_DIM) { - return nullopt; + return std::nullopt; } - if (holds_alternative>(v.value())) { + if (std::holds_alternative>(v.value())) { return get>(v.value()).at(index); } else { return get>(v.value()).at(index); } } -optional +std::optional evaluate_list_index_access(int const &index, - optional const &v) { - if (!v.has_value() || !holds_alternative>(v.value())) { - return nullopt; + std::optional const &v) { + if (!v.has_value() || !std::holds_alternative>(v.value())) { + return std::nullopt; } auto vec = get>(v.value()); if (index >= vec.size()) { - return nullopt; + return std::nullopt; } return 
vec.at(index); } -optional - evaluate_list_size(optional const &v) { +std::optional + evaluate_list_size(std::optional const &v) { return MAX_TENSOR_DIM; } -optional - evaluate_list_size(optional const &v) { - if (!v.has_value() || !holds_alternative>(v.value())) { - return nullopt; +std::optional + evaluate_list_size(std::optional const &v) { + if (!v.has_value() || !std::holds_alternative>(v.value())) { + return std::nullopt; } return (int)get>(v.value()).size(); @@ -62,20 +62,20 @@ optional struct EvaluateOperatorAttributeExpr { EvaluateOperatorAttributeExpr(Operator const &attrs) : attrs(attrs) {} - optional operator()(OperatorAttributeKey const &key) { + std::optional operator()(OperatorAttributeKey const &key) { return get_attribute(this->attrs.attrs, key); } - optional + std::optional operator()(ListIndexAccess const &index_access) { - optional v = + std::optional v = get_attribute(this->attrs.attrs, index_access.attribute_key); return evaluate_list_index_access(index_access.index, v); } - optional + std::optional operator()(ListSize const &list_size) { - optional v = + std::optional v = get_attribute(this->attrs.attrs, list_size.attribute_key); return evaluate_list_size(v); } @@ -84,7 +84,7 @@ struct EvaluateOperatorAttributeExpr { Operator attrs; }; -optional +std::optional evaluate_tensor_attribute_expr(ParallelTensor const &, AttributeExpr const &); @@ -93,11 +93,11 @@ struct EvaluateTensorAttributeExpr { : tensor_shape(tensor_shape) {} template - optional evaluate(T const &t) { + std::optional evaluate(T const &t) { return this->operator()(t); } - optional operator()(TensorAttributeKey key) { + std::optional operator()(TensorAttributeKey key) { switch (key) { case TensorAttributeKey::DIM_SIZES: { std::vector result; @@ -118,14 +118,14 @@ struct EvaluateTensorAttributeExpr { } } - optional + std::optional operator()(ListIndexAccess const &index_access) { - optional v = + std::optional v = this->evaluate(index_access.attribute_key); return 
evaluate_list_index_access(index_access.index, v); } - optional + std::optional operator()(ListSize const &list_size) { return evaluate_list_size(this->evaluate(list_size.attribute_key)); } @@ -134,29 +134,29 @@ struct EvaluateTensorAttributeExpr { ParallelTensor tensor_shape; }; -optional +std::optional evaluate_attribute_expr(ParallelTensor const &tensor_shape, AttributeExpr const &expr) { return visit(EvaluateTensorAttributeExpr(tensor_shape), expr); } -optional +std::optional evaluate_attribute_expr(Operator const &attrs, AttributeExpr const &expr) { return visit(EvaluateOperatorAttributeExpr(attrs), expr); } template -optional satisfies(ConstraintType constraint_type, +std::optional satisfies(ConstraintType constraint_type, V const &constraint_value, - optional const &maybe_attribute_value) { + std::optional const &maybe_attribute_value) { if (!maybe_attribute_value.has_value()) { - return nullopt; + return std::nullopt; } V attr_val = maybe_attribute_value.value(); if (attr_val.index() != constraint_value.index()) { - return nullopt; + return std::nullopt; } if (constraint_type == ConstraintType::EQUAL) { @@ -166,14 +166,14 @@ optional satisfies(ConstraintType constraint_type, } } -optional satisfies(ParallelTensor const &tensor_shape, +std::optional satisfies(ParallelTensor const &tensor_shape, TensorAttributeConstraint const &constraint) { auto value = evaluate_attribute_expr(tensor_shape, constraint.attribute_expr); return satisfies( constraint.constraint_type, constraint.attribute_value, value); } -optional satisfies(Operator const ¶ms, +std::optional satisfies(Operator const ¶ms, OperatorAttributeConstraint const &constraint) { auto value = evaluate_attribute_expr(params, constraint.attribute_expr); OperatorAttributeValue v = value.value(); @@ -182,12 +182,12 @@ optional satisfies(Operator const ¶ms, } template -optional optional_all_of(Container const &container, +std::optional optional_all_of(Container const &container, Function const &func) { for 
(auto const &element : container) { - optional condition = func(element); + std::optional condition = func(element); if (!condition.has_value()) { - return nullopt; + return std::nullopt; } if (!condition.value()) { @@ -197,7 +197,7 @@ optional optional_all_of(Container const &container, return true; } -optional satisfies(Operator const ¶ms, +std::optional satisfies(Operator const ¶ms, OperatorPattern const &pattern) { return optional_all_of(pattern.attribute_constraints, [&](OperatorAttributeConstraint const &c) { @@ -205,7 +205,7 @@ optional satisfies(Operator const ¶ms, }); } -optional satisfies(ParallelTensor const ¶ms, +std::optional satisfies(ParallelTensor const ¶ms, ParallelTensorPattern const &pattern) { return optional_all_of( pattern.attribute_constraints, @@ -229,7 +229,7 @@ bool assignment_satisfies(SubParallelComputationGraph const &pcg, for (auto const &kv : patternMatch.node_assignment) { Node patternNode = kv.first; Node pcgNode = kv.second; - optional constraintResult = + std::optional constraintResult = satisfies(pcg.at(pcgNode), pattern.value().at(patternNode)); result &= constraintResult.value_or(false); } @@ -237,7 +237,7 @@ bool assignment_satisfies(SubParallelComputationGraph const &pcg, for (auto const &kv : patternMatch.edge_assignment) { OpenMultiDiEdge patternEdge = kv.first; OpenMultiDiEdge pcgEdge = kv.second; - optional constraintResult = + std::optional constraintResult = satisfies(pcg.at(pcgEdge), pattern.value().at(patternEdge)); result &= constraintResult.value_or(false); } diff --git a/lib/substitutions/src/graph_pattern_match.cc b/lib/substitutions/src/graph_pattern_match.cc index 7114c2d8ce..f9c6b9a773 100644 --- a/lib/substitutions/src/graph_pattern_match.cc +++ b/lib/substitutions/src/graph_pattern_match.cc @@ -56,7 +56,7 @@ MatchSplit apply_split(OpenMultiDiGraphView const &pattern, } else { assert(is_standard_edge(pattern_edge)); assert(is_standard_edge(graph_edge)); - auto standard_edge = mpark::get(pattern_edge); + auto 
standard_edge = std::get(pattern_edge); auto divided = edge_splits.at_l(standard_edge); auto divided_graph_edge = split_edge(get(graph_edge)); handle_edge(divided.first, divided_graph_edge.first); @@ -98,7 +98,7 @@ bool pattern_matches(OpenMultiDiGraphView const &pattern, } UpwardOpenMultiDiEdge matched_edge = narrow(graph_matched_edge).value(); - InputMultiDiEdge input_edge = mpark::get(e); + InputMultiDiEdge input_edge = std::get(e); if (match.node_assignment.at_l(input_edge.dst) != get_dst_node(matched_edge)) { return false; @@ -109,7 +109,7 @@ bool pattern_matches(OpenMultiDiGraphView const &pattern, } DownwardOpenMultiDiEdge matched_edge = narrow(graph_matched_edge).value(); - OutputMultiDiEdge output_edge = mpark::get(e); + OutputMultiDiEdge output_edge = std::get(e); if (match.node_assignment.at_l(output_edge.src) != get_src_node(matched_edge)) { return false; @@ -148,7 +148,7 @@ bool src_compare(T const &lhs, T const &rhs) { return get_src_idx(lhs) < get_src_idx(rhs); } -optional +std::optional get_candidate_singleton_match(OpenMultiDiGraphView const &pattern, OpenMultiDiGraphView const &graph, Node const &graph_node) { @@ -170,11 +170,11 @@ optional get_outgoing_edges(pattern, pattern_node); if (!pattern_incoming.empty() && pattern_incoming.size() != incoming.size()) { - return nullopt; + return std::nullopt; } if (!pattern_outgoing.empty() && pattern_outgoing.size() != outgoing.size()) { - return nullopt; + return std::nullopt; } std::vector incoming_ordered = @@ -198,7 +198,7 @@ optional node_port_mapping.emplace(graph_port, pattern_port); } else { if (pattern_port != node_port_mapping.at(graph_port)) { - return nullopt; + return std::nullopt; } } match.edge_assignment.equate(widen(pattern_edge), @@ -217,7 +217,7 @@ optional node_port_mapping.insert({graph_port, pattern_port}); } else { if (pattern_port != node_port_mapping.at(graph_port)) { - return nullopt; + return std::nullopt; } } match.edge_assignment.equate(widen(pattern_edge), @@ -228,7 +228,7 @@ 
optional return match; } -optional unsplit_matches( +std::optional unsplit_matches( MultiDiGraphPatternMatch const &prefix, MultiDiGraphPatternMatch const &postfix, bidict> const @@ -248,7 +248,7 @@ optional unsplit_matches( if (output_graph_edge == input_graph_edge) { result.edge_assignment.equate(standard_edge, output_graph_edge); } else { - return nullopt; + return std::nullopt; } } @@ -272,7 +272,7 @@ std::vector std::vector matches; if (is_singleton_pattern(pattern)) { for (Node const &graph_node : get_nodes(graph)) { - optional candidate = + std::optional candidate = get_candidate_singleton_match(pattern, graph, graph_node); if (candidate.has_value() && pattern_matches( @@ -290,7 +290,7 @@ std::vector auto edge_splits = get_edge_splits(pattern, split); for (MultiDiGraphPatternMatch const &prefix_match : prefix_matches) { for (MultiDiGraphPatternMatch const &postfix_match : postfix_matches) { - optional unsplit = + std::optional unsplit = unsplit_matches(prefix_match, postfix_match, edge_splits); if (unsplit.has_value()) { matches.push_back(unsplit.value()); diff --git a/lib/substitutions/src/operator_attributes.cc b/lib/substitutions/src/operator_attributes.cc index 3922b091a7..76533507a3 100644 --- a/lib/substitutions/src/operator_attributes.cc +++ b/lib/substitutions/src/operator_attributes.cc @@ -3,25 +3,25 @@ namespace FlexFlow { -optional get_attribute(BatchMatmulAttrs const &p, +std::optional get_attribute(BatchMatmulAttrs const &p, OperatorAttributeKey key) { switch (key) { default: - return nullopt; + return std::nullopt; } } -optional get_attribute(CastAttrs const &p, +std::optional get_attribute(CastAttrs const &p, OperatorAttributeKey key) { switch (key) { case OperatorAttributeKey::DATA_TYPE: return p.dtype; default: - return nullopt; + return std::nullopt; } } -optional get_attribute(CombineAttrs const &p, +std::optional get_attribute(CombineAttrs const &p, OperatorAttributeKey key) { switch (key) { case OperatorAttributeKey::PARALLEL_OP_DIM: @@ 
-29,21 +29,21 @@ optional get_attribute(CombineAttrs const &p, case OperatorAttributeKey::PARALLEL_DIM: return p.combine_degree; default: - return nullopt; + return std::nullopt; } } -optional get_attribute(ConcatAttrs const &p, +std::optional get_attribute(ConcatAttrs const &p, OperatorAttributeKey key) { switch (key) { case OperatorAttributeKey::AXIS: return p.axis; default: - return nullopt; + return std::nullopt; } } -optional get_attribute(Conv2DAttrs const &p, +std::optional get_attribute(Conv2DAttrs const &p, OperatorAttributeKey key) { switch (key) { case OperatorAttributeKey::KERNEL_H: @@ -65,43 +65,43 @@ optional get_attribute(Conv2DAttrs const &p, case OperatorAttributeKey::USE_BIAS: return p.use_bias; default: - return nullopt; + return std::nullopt; } } -optional get_attribute(ElementBinaryAttrs const &p, +std::optional get_attribute(ElementBinaryAttrs const &p, OperatorAttributeKey key) { switch (key) { default: - return nullopt; + return std::nullopt; } } -optional get_attribute(ElementUnaryAttrs const &p, +std::optional get_attribute(ElementUnaryAttrs const &p, OperatorAttributeKey key) { switch (key) { default: - return nullopt; + return std::nullopt; } } -optional get_attribute(ElementScalarUnaryAttrs const &p, +std::optional get_attribute(ElementScalarUnaryAttrs const &p, OperatorAttributeKey key) { switch (key) { default: - return nullopt; + return std::nullopt; } } -optional get_attribute(DropoutAttrs const &p, +std::optional get_attribute(DropoutAttrs const &p, OperatorAttributeKey key) { switch (key) { default: - return nullopt; + return std::nullopt; } } -optional get_attribute(EmbeddingAttrs const &p, +std::optional get_attribute(EmbeddingAttrs const &p, OperatorAttributeKey key) { switch (key) { case OperatorAttributeKey::DATA_TYPE: @@ -113,37 +113,37 @@ optional get_attribute(EmbeddingAttrs const &p, case OperatorAttributeKey::OUT_CHANNELS: return p.out_channels; default: - return nullopt; + return std::nullopt; } } -optional 
get_attribute(FlatAttrs const &p, +std::optional get_attribute(FlatAttrs const &p, OperatorAttributeKey key) { switch (key) { default: - return nullopt; + return std::nullopt; } } -optional get_attribute(GatherAttrs const &p, +std::optional get_attribute(GatherAttrs const &p, OperatorAttributeKey key) { switch (key) { case OperatorAttributeKey::AXIS: return p.dim; default: - return nullopt; + return std::nullopt; } } -optional get_attribute(LayerNormAttrs const &p, +std::optional get_attribute(LayerNormAttrs const &p, OperatorAttributeKey key) { switch (key) { default: - return nullopt; + return std::nullopt; } } -optional get_attribute(LinearAttrs const &p, +std::optional get_attribute(LinearAttrs const &p, OperatorAttributeKey key) { switch (key) { case OperatorAttributeKey::OUT_CHANNELS: @@ -159,11 +159,11 @@ optional get_attribute(LinearAttrs const &p, case OperatorAttributeKey::REGULARIZER: return p.regularizer; default: - return nullopt; + return std::nullopt; } } -optional get_attribute(MultiHeadAttentionAttrs const &p, +std::optional get_attribute(MultiHeadAttentionAttrs const &p, OperatorAttributeKey key) { switch (key) { case OperatorAttributeKey::NUM_HEADS: @@ -171,11 +171,11 @@ optional get_attribute(MultiHeadAttentionAttrs const &p, case OperatorAttributeKey::USE_BIAS: return p.bias; default: - return nullopt; + return std::nullopt; } } -optional get_attribute(Pool2DAttrs const &p, +std::optional get_attribute(Pool2DAttrs const &p, OperatorAttributeKey key) { switch (key) { case OperatorAttributeKey::KERNEL_H: @@ -195,19 +195,19 @@ optional get_attribute(Pool2DAttrs const &p, case OperatorAttributeKey::ACTIVATION: return p.activation; default: - return nullopt; + return std::nullopt; } } -optional get_attribute(ReduceAttrs const &p, +std::optional get_attribute(ReduceAttrs const &p, OperatorAttributeKey key) { switch (key) { default: - return nullopt; + return std::nullopt; } } -optional get_attribute(ReductionAttrs const &p, +std::optional 
get_attribute(ReductionAttrs const &p, OperatorAttributeKey key) { switch (key) { case OperatorAttributeKey::PARALLEL_OP_DIM: @@ -215,11 +215,11 @@ optional get_attribute(ReductionAttrs const &p, case OperatorAttributeKey::PARALLEL_OP_DEGREE: return p.reduction_degree; default: - return nullopt; + return std::nullopt; } } -optional get_attribute(RepartitionAttrs const &p, +std::optional get_attribute(RepartitionAttrs const &p, OperatorAttributeKey key) { switch (key) { case OperatorAttributeKey::PARALLEL_OP_DIM: @@ -227,11 +227,11 @@ optional get_attribute(RepartitionAttrs const &p, case OperatorAttributeKey::PARALLEL_OP_DEGREE: return p.repartition_degree; default: - return nullopt; + return std::nullopt; } } -optional get_attribute(ReplicateAttrs const &p, +std::optional get_attribute(ReplicateAttrs const &p, OperatorAttributeKey key) { switch (key) { case OperatorAttributeKey::PARALLEL_OP_DIM: @@ -239,53 +239,53 @@ optional get_attribute(ReplicateAttrs const &p, case OperatorAttributeKey::PARALLEL_OP_DEGREE: return p.replicate_degree; default: - return nullopt; + return std::nullopt; } } -optional get_attribute(ReshapeAttrs const &p, +std::optional get_attribute(ReshapeAttrs const &p, OperatorAttributeKey key) { switch (key) { default: - return nullopt; + return std::nullopt; } } -optional get_attribute(SplitAttrs const &p, +std::optional get_attribute(SplitAttrs const &p, OperatorAttributeKey key) { switch (key) { case OperatorAttributeKey::AXIS: return p.axis; default: - return nullopt; + return std::nullopt; } } -optional get_attribute(SoftmaxAttrs const &p, +std::optional get_attribute(SoftmaxAttrs const &p, OperatorAttributeKey key) { switch (key) { case OperatorAttributeKey::AXIS: return p.dim; default: - return nullopt; + return std::nullopt; } } -optional get_attribute(TopKAttrs const &p, +std::optional get_attribute(TopKAttrs const &p, OperatorAttributeKey key) { switch (key) { default: - return nullopt; + return std::nullopt; } } -optional 
get_attribute(TransposeAttrs const &p, +std::optional get_attribute(TransposeAttrs const &p, OperatorAttributeKey key) { switch (key) { case OperatorAttributeKey::PERMUTATION: return p.perm; default: - return nullopt; + return std::nullopt; } } @@ -293,7 +293,7 @@ struct GetAttribute { GetAttribute(OperatorAttributeKey key) : key(key) {} template - optional operator()(T const &t) { + std::optional operator()(T const &t) { return get_attribute(t, this->key); } @@ -303,17 +303,17 @@ struct GetAttribute { struct GetOpType { template - optional operator()(T const &t) { + std::optional operator()(T const &t) { return get_op_type(t); } }; -optional get_attribute(PCGOperatorAttrs const &p, +std::optional get_attribute(PCGOperatorAttrs const &p, OperatorAttributeKey key) { if (key == OperatorAttributeKey::OP_TYPE) { - return visit(GetOpType{}, p); + return std::visit(GetOpType{}, p); } - return visit(GetAttribute(key), p); + return std::visit(GetAttribute(key), p); } } // namespace FlexFlow diff --git a/lib/substitutions/src/substitution.cc b/lib/substitutions/src/substitution.cc index 31659b88fc..4f6572948a 100644 --- a/lib/substitutions/src/substitution.cc +++ b/lib/substitutions/src/substitution.cc @@ -113,49 +113,49 @@ Operator get_operator_attrs(SubParallelComputationGraph const &graph, assignments.emplace(key, value); } assert(contains_key(assignments, OperatorAttributeKey::OP_TYPE)); - assert(holds_alternative( + assert(std::holds_alternative( assignments.at(OperatorAttributeKey::OP_TYPE))); OperatorType op_type = - get(assignments.at(OperatorAttributeKey::OP_TYPE)); + std::get(assignments.at(OperatorAttributeKey::OP_TYPE)); switch (op_type) { case Op::BATCHMATMUL: return Operator{ BatchMatmulAttrs{ - get(assignments.at(OperatorAttributeKey::A_SEQ_LENGTH_DIM)), - get(assignments.at(OperatorAttributeKey::B_SEQ_LENGTH_DIM))}, + std::get(assignments.at(OperatorAttributeKey::A_SEQ_LENGTH_DIM)), + std::get(assignments.at(OperatorAttributeKey::B_SEQ_LENGTH_DIM))}, 
std::nullopt}; case Op::BATCHNORM: return Operator{ - BatchNormAttrs{get(assignments.at(OperatorAttributeKey::RELU))}, + BatchNormAttrs{std::get(assignments.at(OperatorAttributeKey::RELU))}, std::nullopt}; case Op::CAST: - return Operator{CastAttrs{get( + return Operator{CastAttrs{std::get( assignments.at(OperatorAttributeKey::DATA_TYPE))}, std::nullopt}; case Op::CONCAT: return Operator{ ConcatAttrs{ - get(assignments.at(OperatorAttributeKey::AXIS)), - get(assignments.at(OperatorAttributeKey::NUM_INPUTS))}, + std::get(assignments.at(OperatorAttributeKey::AXIS)), + std::get(assignments.at(OperatorAttributeKey::NUM_INPUTS))}, std::nullopt}; case Op::CONV2D: return Operator{ Conv2DAttrs{ - get(assignments.at(OperatorAttributeKey::OUT_CHANNELS)), - get(assignments.at(OperatorAttributeKey::KERNEL_H)), - get(assignments.at(OperatorAttributeKey::KERNEL_W)), - get(assignments.at(OperatorAttributeKey::STRIDE_H)), - get(assignments.at(OperatorAttributeKey::STRIDE_W)), - get(assignments.at(OperatorAttributeKey::PADDING_H)), - get(assignments.at(OperatorAttributeKey::PADDING_W)), - get(assignments.at(OperatorAttributeKey::GROUPS)), - get(assignments.at(OperatorAttributeKey::ACTIVATION)), - get(assignments.at(OperatorAttributeKey::USE_BIAS))}, + std::get(assignments.at(OperatorAttributeKey::OUT_CHANNELS)), + std::get(assignments.at(OperatorAttributeKey::KERNEL_H)), + std::get(assignments.at(OperatorAttributeKey::KERNEL_W)), + std::get(assignments.at(OperatorAttributeKey::STRIDE_H)), + std::get(assignments.at(OperatorAttributeKey::STRIDE_W)), + std::get(assignments.at(OperatorAttributeKey::PADDING_H)), + std::get(assignments.at(OperatorAttributeKey::PADDING_W)), + std::get(assignments.at(OperatorAttributeKey::GROUPS)), + std::get(assignments.at(OperatorAttributeKey::ACTIVATION)), + std::get(assignments.at(OperatorAttributeKey::USE_BIAS))}, std::nullopt}; case Op::DROPOUT: return Operator{ - DropoutAttrs{get(assignments.at(OperatorAttributeKey::RATE)), - get( + 
DropoutAttrs{std::get(assignments.at(OperatorAttributeKey::RATE)), + std::get( assignments.at(OperatorAttributeKey::SEED))}, std::nullopt}; case Op::EW_ADD: @@ -170,10 +170,10 @@ Operator get_operator_attrs(SubParallelComputationGraph const &graph, return Operator{ ElementBinaryAttrs{ op_type, - get(assignments.at(OperatorAttributeKey::DATA_TYPE)), - get( + std::get(assignments.at(OperatorAttributeKey::DATA_TYPE)), + std::get( assignments.at(OperatorAttributeKey::SHOULD_BROADCAST_LHS)), - get( + std::get( assignments.at(OperatorAttributeKey::SHOULD_BROADCAST_RHS))}, std::nullopt}; case Op::SCALAR_ADD: @@ -184,7 +184,7 @@ Operator get_operator_attrs(SubParallelComputationGraph const &graph, return Operator{ ElementScalarUnaryAttrs{ op_type, - get(assignments.at(OperatorAttributeKey::SCALAR))}, + std::get(assignments.at(OperatorAttributeKey::SCALAR))}, std::nullopt}; case Op::EXP: case Op::IDENTITY: @@ -197,63 +197,63 @@ Operator get_operator_attrs(SubParallelComputationGraph const &graph, case Op::EMBEDDING: return Operator{ EmbeddingAttrs{ - get(assignments.at(OperatorAttributeKey::NUM_ENTRIES)), - get(assignments.at(OperatorAttributeKey::OUT_CHANNELS)), - get(assignments.at(OperatorAttributeKey::AGGR)), - get(assignments.at(OperatorAttributeKey::OP_TYPE))}, + std::get(assignments.at(OperatorAttributeKey::NUM_ENTRIES)), + std::get(assignments.at(OperatorAttributeKey::OUT_CHANNELS)), + std::get(assignments.at(OperatorAttributeKey::AGGR)), + std::get(assignments.at(OperatorAttributeKey::OP_TYPE))}, std::nullopt}; case Op::FLAT: return Operator{FlatAttrs{}, std::nullopt}; case Op::GATHER: return Operator{ - GatherAttrs{get(assignments.at(OperatorAttributeKey::DIM))}, + GatherAttrs{std::get(assignments.at(OperatorAttributeKey::DIM))}, std::nullopt}; case Op::INPUT: return Operator{InputAttrs{}, std::nullopt}; case Op::LAYERNORM: return Operator{ LayerNormAttrs{ - get>( + std::get>( assignments.at(OperatorAttributeKey::AXES)), - get( + std::get( 
assignments.at(OperatorAttributeKey::ELEMENTWISE_AFFINE)), - get(assignments.at(OperatorAttributeKey::EPSILON))}, + std::get(assignments.at(OperatorAttributeKey::EPSILON))}, std::nullopt}; case Op::LINEAR: return Operator{ LinearAttrs{ - get(assignments.at(OperatorAttributeKey::OUT_CHANNELS)), - get(assignments.at(OperatorAttributeKey::USE_BIAS)), - get(assignments.at(OperatorAttributeKey::DATA_TYPE)), - get(assignments.at(OperatorAttributeKey::ACTIVATION)), - get>( + std::get(assignments.at(OperatorAttributeKey::OUT_CHANNELS)), + std::get(assignments.at(OperatorAttributeKey::USE_BIAS)), + std::get(assignments.at(OperatorAttributeKey::DATA_TYPE)), + std::get(assignments.at(OperatorAttributeKey::ACTIVATION)), + std::get>( assignments.at(OperatorAttributeKey::REGULARIZER))}, std::nullopt}; case Op::MULTIHEAD_ATTENTION: return Operator{ MultiHeadAttentionAttrs{ - get(assignments.at(OperatorAttributeKey::EMBED_DIM)), - get(assignments.at(OperatorAttributeKey::NUM_HEADS)), - get(assignments.at(OperatorAttributeKey::NUM_HEADS)), - get(assignments.at(OperatorAttributeKey::VDIM)), - get(assignments.at(OperatorAttributeKey::DROPOUT)), - get(assignments.at(OperatorAttributeKey::BIAS)), - get(assignments.at(OperatorAttributeKey::ADD_BIAS_KV)), - get(assignments.at(OperatorAttributeKey::ADD_ZERO_ATTN))}, + std::get(assignments.at(OperatorAttributeKey::EMBED_DIM)), + std::get(assignments.at(OperatorAttributeKey::NUM_HEADS)), + std::get(assignments.at(OperatorAttributeKey::NUM_HEADS)), + std::get(assignments.at(OperatorAttributeKey::VDIM)), + std::get(assignments.at(OperatorAttributeKey::DROPOUT)), + std::get(assignments.at(OperatorAttributeKey::BIAS)), + std::get(assignments.at(OperatorAttributeKey::ADD_BIAS_KV)), + std::get(assignments.at(OperatorAttributeKey::ADD_ZERO_ATTN))}, std::nullopt}; case Op::NOOP: return Operator{NoopAttrs{}, std::nullopt}; case Op::POOL2D: return Operator{ Pool2DAttrs{ - get(assignments.at(OperatorAttributeKey::KERNEL_H)), - 
get(assignments.at(OperatorAttributeKey::KERNEL_W)), - get(assignments.at(OperatorAttributeKey::STRIDE_H)), - get(assignments.at(OperatorAttributeKey::STRIDE_W)), - get(assignments.at(OperatorAttributeKey::PADDING_H)), - get(assignments.at(OperatorAttributeKey::PADDING_W)), - get(assignments.at(OperatorAttributeKey::POOL_TYPE)), - get( + std::get(assignments.at(OperatorAttributeKey::KERNEL_H)), + std::get(assignments.at(OperatorAttributeKey::KERNEL_W)), + std::get(assignments.at(OperatorAttributeKey::STRIDE_H)), + std::get(assignments.at(OperatorAttributeKey::STRIDE_W)), + std::get(assignments.at(OperatorAttributeKey::PADDING_H)), + std::get(assignments.at(OperatorAttributeKey::PADDING_W)), + std::get(assignments.at(OperatorAttributeKey::POOL_TYPE)), + std::get( assignments.at(OperatorAttributeKey::ACTIVATION))}, std::nullopt}; case Op::REDUCE_ARGMAX: @@ -265,65 +265,65 @@ Operator get_operator_attrs(SubParallelComputationGraph const &graph, case Op::REDUCE_SUM: return Operator{ ReduceAttrs{ - get>( + std::get>( assignments.at(OperatorAttributeKey::AXES)), op_type, - get(assignments.at(OperatorAttributeKey::KEEP_DIMS))}, + std::get(assignments.at(OperatorAttributeKey::KEEP_DIMS))}, std::nullopt}; case Op::REVERSE: - return Operator{ReverseAttrs{get( + return Operator{ReverseAttrs{std::get( assignments.at(OperatorAttributeKey::AXIS))}, std::nullopt}; case Op::RESHAPE: - return Operator{ReshapeAttrs{get( + return Operator{ReshapeAttrs{std::get( assignments.at(OperatorAttributeKey::SHAPE))}, std::nullopt}; case Op::SPLIT: return Operator{ - SplitAttrs{get>( + SplitAttrs{std::get>( assignments.at(OperatorAttributeKey::SPLITS)), - get(assignments.at(OperatorAttributeKey::AXIS))}, + std::get(assignments.at(OperatorAttributeKey::AXIS))}, std::nullopt}; case Op::SOFTMAX: - return Operator{SoftmaxAttrs{get( + return Operator{SoftmaxAttrs{std::get( assignments.at(OperatorAttributeKey::DIM))}, std::nullopt}; case Op::TOPK: return Operator{ - 
TopKAttrs{get(assignments.at(OperatorAttributeKey::K)), - get(assignments.at(OperatorAttributeKey::SORTED))}, + TopKAttrs{std::get(assignments.at(OperatorAttributeKey::K)), + std::get(assignments.at(OperatorAttributeKey::SORTED))}, std::nullopt}; case Op::TRANSPOSE: return Operator{ - TransposeAttrs{get>( + TransposeAttrs{std::get>( assignments.at(OperatorAttributeKey::PERMUTATION))}, std::nullopt}; case Op::COMBINE: return Operator{ CombineAttrs{ - get(assignments.at(OperatorAttributeKey::PARALLEL_DIM)), - get(assignments.at(OperatorAttributeKey::PARALLEL_DEGREE))}, + std::get(assignments.at(OperatorAttributeKey::PARALLEL_DIM)), + std::get(assignments.at(OperatorAttributeKey::PARALLEL_DEGREE))}, std::nullopt}; case Op::REDUCTION: return Operator{ ReductionAttrs{ - get(assignments.at(OperatorAttributeKey::PARALLEL_DIM)), - get(assignments.at(OperatorAttributeKey::PARALLEL_DEGREE))}, + std::get(assignments.at(OperatorAttributeKey::PARALLEL_DIM)), + std::get(assignments.at(OperatorAttributeKey::PARALLEL_DEGREE))}, std::nullopt}; case Op::REPARTITION: return Operator{ RepartitionAttrs{ - get(assignments.at(OperatorAttributeKey::PARALLEL_DIM)), - get(assignments.at(OperatorAttributeKey::PARALLEL_DEGREE))}, + std::get(assignments.at(OperatorAttributeKey::PARALLEL_DIM)), + std::get(assignments.at(OperatorAttributeKey::PARALLEL_DEGREE))}, std::nullopt}; case Op::REPLICATE: return Operator{ ReplicateAttrs{ - get(assignments.at(OperatorAttributeKey::PARALLEL_DIM)), - get(assignments.at(OperatorAttributeKey::PARALLEL_DEGREE))}, + std::get(assignments.at(OperatorAttributeKey::PARALLEL_DIM)), + std::get(assignments.at(OperatorAttributeKey::PARALLEL_DEGREE))}, std::nullopt}; default: - mk_runtime_error("Unknown Operator"); + throw mk_runtime_error("Unknown Operator"); } } @@ -435,23 +435,23 @@ SubParallelComputationGraph } for (OpenMultiDiEdge const &output_edge : get_edges(substitution.output_graph_expr.value())) { - if (holds_alternative(output_edge)) { - InputMultiDiEdge e = 
get(output_edge); + if (std::holds_alternative(output_edge)) { + InputMultiDiEdge e = std::get(output_edge); OpenMultiDiEdge original_edge = match.edge_assignment.at_l(substitution.input_mapping.at_r(e)); visit(AddNewEdgeFunctor{pcg, new_pcg, match, node_mapping}, original_edge, output_edge); - } else if (holds_alternative(output_edge)) { - OutputMultiDiEdge e = get(output_edge); + } else if (std::holds_alternative(output_edge)) { + OutputMultiDiEdge e = std::get(output_edge); OpenMultiDiEdge original_edge = match.edge_assignment.at_l(substitution.output_mapping.at_r(e)); visit(AddNewEdgeFunctor{pcg, new_pcg, match, node_mapping}, original_edge, output_edge); } else { - assert(holds_alternative(output_edge)); - MultiDiEdge e = get(output_edge); + assert(std::holds_alternative(output_edge)); + MultiDiEdge e = std::get(output_edge); new_pcg.add_edge(MultiDiEdge{node_mapping.at_l(e.dst), new_pcg.add_node_port(), node_mapping.at_l(e.src), diff --git a/lib/utils/include/utils/containers.decl.h b/lib/utils/include/utils/containers.decl.h index ed47226297..40ac0a4a1c 100644 --- a/lib/utils/include/utils/containers.decl.h +++ b/lib/utils/include/utils/containers.decl.h @@ -2,11 +2,11 @@ #define _FLEXFLOW_UTILS_INCLUDE_UTILS_CONTAINERS_DECL_H #include "utils/bidict.h" -#include "utils/optional.decl" #include "utils/required_core.h" #include "utils/type_traits_core.h" #include #include +#include namespace FlexFlow { @@ -108,7 +108,7 @@ template std::vector values(C const &c); template -std::unordered_set> +std::unordered_set> items(C const &c); template @@ -291,10 +291,10 @@ template T reversed(T const &t); template -std::vector value_all(std::vector> const &v); +std::vector value_all(std::vector> const &v); template -std::unordered_set value_all(std::unordered_set> const &v); +std::unordered_set value_all(std::unordered_set> const &v); template std::vector subvec(std::vector const &v, diff --git a/lib/utils/include/utils/containers.h b/lib/utils/include/utils/containers.h 
index 750c43abee..1606eb0605 100644 --- a/lib/utils/include/utils/containers.h +++ b/lib/utils/include/utils/containers.h @@ -674,8 +674,8 @@ std::vector value_all(std::vector> const &v) { } template -std::unordered_set value_all(std::unordered_set> const &v) { - return transform(v, [](optional const &element) { +std::unordered_set value_all(std::unordered_set> const &v) { + return transform(v, [](std::optional const &element) { return unwrap(element, [] { throw mk_runtime_error( "Encountered element without value in call to value_all"); diff --git a/lib/utils/include/utils/dot_file.h b/lib/utils/include/utils/dot_file.h index 6cdc78f6d4..6cf06d12a7 100644 --- a/lib/utils/include/utils/dot_file.h +++ b/lib/utils/include/utils/dot_file.h @@ -10,6 +10,7 @@ #include #include #include +#include template class DotFile { @@ -28,16 +29,16 @@ class DotFile { return s.str(); } bool has_ostream() const { - return this->owned_fstream.has_value() || this->out.has_value(); + return this->owned_fstream.has_value() || this->out != nullptr; } std::ostream &get_ostream() { bool has_owned_stream = this->owned_fstream.has_value(); - bool has_stream_ref = this->out.has_value(); + bool has_stream_ref = (this->out != nullptr); assert(has_owned_stream != has_stream_ref); if (has_owned_stream) { return this->owned_fstream.value(); } else if (has_stream_ref) { - return this->out.value(); + return *this->out; } else { throw std::runtime_error("No ostream value set"); } diff --git a/lib/utils/include/utils/graph/labelled/output_labelled_open.h b/lib/utils/include/utils/graph/labelled/output_labelled_open.h index 1ccf881d97..027c3243b9 100644 --- a/lib/utils/include/utils/graph/labelled/output_labelled_open.h +++ b/lib/utils/include/utils/graph/labelled/output_labelled_open.h @@ -32,12 +32,12 @@ struct OutputLabelledOpenMultiDiGraphView } template - EdgeLabel const &at(variant const &e) const { + EdgeLabel const &at(std::variant const &e) const { return visit([&](auto const &e) -> auto const 
& { return this->at(e); }, e); } template - EdgeLabel &at(variant const &e) { + EdgeLabel &at(std::variant const &e) { return visit([&](auto const &e) -> auto & { return this->at(e); }, e); } diff --git a/lib/utils/include/utils/variant.h b/lib/utils/include/utils/variant.h index 420f9736d1..feb263335a 100644 --- a/lib/utils/include/utils/variant.h +++ b/lib/utils/include/utils/variant.h @@ -42,7 +42,7 @@ struct elements_satisfy> : elements_satisfy_impl {}; template -struct is_in_variant; +struct is_in_variant : std::false_type {}; template struct is_in_variant> : std::true_type {}; template diff --git a/lib/utils/src/graph/open_edge.cc b/lib/utils/src/graph/open_edge.cc index b12f87dd1c..1b571d5c6c 100644 --- a/lib/utils/src/graph/open_edge.cc +++ b/lib/utils/src/graph/open_edge.cc @@ -3,15 +3,15 @@ namespace FlexFlow { bool is_input_edge(OpenMultiDiEdge const &e) { - return holds_alternative(e); + return std::holds_alternative(e); } bool is_output_edge(OpenMultiDiEdge const &e) { - return holds_alternative(e); + return std::holds_alternative(e); } bool is_standard_edge(OpenMultiDiEdge const &e) { - return holds_alternative(e); + return std::holds_alternative(e); } OpenMultiDiEdgeQuery::OpenMultiDiEdgeQuery( diff --git a/lib/utils/src/graph/serialparallel.cc b/lib/utils/src/graph/serialparallel.cc index 4a6a056d59..f1c9e41005 100644 --- a/lib/utils/src/graph/serialparallel.cc +++ b/lib/utils/src/graph/serialparallel.cc @@ -142,7 +142,7 @@ SplitASTNode::SplitASTNode(SplitType type, struct FlattenAST { void add_flattened_child_to_parent(SplitASTNode &parent, SplitAST const &child) { - if (holds_alternative(child)) { + if (std::holds_alternative(child)) { parent.children.push_back(child); return; } @@ -178,11 +178,11 @@ struct ToFinalAST { std::variant operator()(SplitASTNode const &node) { if (node.type == SplitType::SERIAL) { return Serial{transform(node.children, [](SplitAST const &s) { - return narrow>(to_final_ast(s)).value(); + return 
narrow>(to_final_ast(s)).value(); })}; } else { return Parallel{transform(node.children, [](SplitAST const &s) { - return narrow>(to_final_ast(s)).value(); + return narrow>(to_final_ast(s)).value(); })}; } } diff --git a/lib/utils/test/src/test_variant.cc b/lib/utils/test/src/test_variant.cc index 031defd417..1494f0ac27 100644 --- a/lib/utils/test/src/test_variant.cc +++ b/lib/utils/test/src/test_variant.cc @@ -3,61 +3,61 @@ TEST_CASE("widen and narrow functions") { SUBCASE("widen function") { - variant v1 = 42; - variant result = widen>(v1); - variant expected = 42; + std::variant v1 = 42; + std::variant result = widen>(v1); + std::variant expected = 42; CHECK(result == expected); } SUBCASE("narrow function fail") { - variant v2 = + std::variant v2 = 3.14; // this is a doule, because 3.14 default to double - optional> result = narrow>(v2); - optional> expected = float(3.14); + std::optional> result = narrow>(v2); + std::optional> expected = float(3.14); CHECK(!result.has_value()); // result should be empty due to narrowing } SUBCASE("narrow function success") { - variant v2 = + std::variant v2 = 3.14; // this is a doule, because 3.14 default to double - optional> result = narrow>(v2); - optional> expected = 3.14; + std::optional> result = narrow>(v2); + std::optional> expected = 3.14; CHECK(result == expected); // } SUBCASE("cast function") { - variant v3 = 42; - optional> result = cast>(v3); - optional> expected = 42; + std::variant v3 = 42; + std::optional> result = cast>(v3); + std::optional> expected = 42; CHECK(result == expected); } } TEST_CASE("Narrow and cast variants") { - variant original_variant = 42; + std::variant original_variant = 42; // narrow - optional> narrow_result = - narrow>(original_variant); + std::optional> narrow_result = + narrow>(original_variant); CHECK(narrow_result.has_value()); // assert narrow has value // cast - optional> cast_result = - cast>(narrow_result.value()); + std::optional> cast_result = + cast>(narrow_result.value()); 
CHECK(cast_result.has_value()); // assert cast has value CHECK(get(cast_result.value()) == 42); } TEST_CASE("casting and widening a variant") { - variant smaller_variant = 42; - variant wider_variant; + std::variant smaller_variant = 42; + std::variant wider_variant; // Perform the cast operation - optional> cast_result = cast>(smaller_variant); + std::optional> cast_result = cast>(smaller_variant); REQUIRE(cast_result); // Ensure the cast was successful // Perform the widening operation - wider_variant = widen>(cast_result.value()); + wider_variant = widen>(cast_result.value()); // Check the result CHECK(get(wider_variant) == 42); From d6e10bb0d579f2328e9ce6d355205ca69bc1a6dc Mon Sep 17 00:00:00 2001 From: Colin Unger Date: Fri, 22 Mar 2024 17:11:35 -0700 Subject: [PATCH 20/37] Add shell hook for sapling development --- flake.nix | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/flake.nix b/flake.nix index 540d0f9a94..595fedac46 100644 --- a/flake.nix +++ b/flake.nix @@ -53,6 +53,10 @@ devShells = rec { ci = mkShell { + shellHook = '' + export PATH="$HOME/ff/.scripts/:$HOME/ff/.modules/proj/bin/:$PATH" + ''; + CMAKE_FLAGS = lib.strings.concatStringsSep " " [ "-DFF_USE_EXTERNAL_LEGION=ON" "-DFF_USE_EXTERNAL_NCCL=ON" From 95fb4cc529a7de643bd4e2af532a2aa88a81f60f Mon Sep 17 00:00:00 2001 From: Pietro Max Marsella Date: Fri, 22 Mar 2024 17:24:18 -0700 Subject: [PATCH 21/37] changed from nullopt to std::nullopt --- lib/substitutions/test/src/test_substitution.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/substitutions/test/src/test_substitution.cc b/lib/substitutions/test/src/test_substitution.cc index a8f5283eda..552d46a98f 100644 --- a/lib/substitutions/test/src/test_substitution.cc +++ b/lib/substitutions/test/src/test_substitution.cc @@ -89,7 +89,7 @@ TEST_CASE("apply_substitution") { Node n4 = pcg.add_node(Operator{InputAttrs{}, "input"}); Node n5 = pcg.add_node(Operator{ - LinearAttrs{1, false, DataType::FLOAT, Activation::RELU, 
nullopt}, + LinearAttrs{1, false, DataType::FLOAT, Activation::RELU, std::nullopt}, "linear"}); NodePort p4 = pcg.add_node_port(); NodePort p5 = pcg.add_node_port(); From c09147959670d605f87c8020ef73b0632f0a1faf Mon Sep 17 00:00:00 2001 From: wmdi Date: Sat, 23 Mar 2024 15:48:50 -0400 Subject: [PATCH 22/37] fix cast issue --- lib/compiler/test/CMakeLists.txt | 2 +- .../test/src/test_labelled_open_graph.cc | 220 ++++++++---------- .../utils/graph/labelled/node_labelled.h | 6 +- .../utils/graph/labelled/node_labelled_open.h | 6 +- .../utils/graph/labelled/output_labelled.h | 6 +- .../graph/labelled/output_labelled_open.h | 6 +- .../utils/graph/labelled/standard_labelled.h | 6 +- lib/utils/src/graph/digraph.cc | 6 +- lib/utils/src/graph/multidigraph.cc | 6 +- lib/utils/src/graph/node.cc | 4 +- lib/utils/src/graph/open_graphs.cc | 18 +- lib/utils/src/graph/undirected.cc | 6 +- 12 files changed, 138 insertions(+), 154 deletions(-) diff --git a/lib/compiler/test/CMakeLists.txt b/lib/compiler/test/CMakeLists.txt index 3d35fdabfd..cbd7e233c0 100644 --- a/lib/compiler/test/CMakeLists.txt +++ b/lib/compiler/test/CMakeLists.txt @@ -2,7 +2,7 @@ ff_add_test_executable( NAME compiler-test SRC_PATTERNS - src/test_labelled_open_graph.cc + src/*.cc PRIVATE_INCLUDE src/ DEPS diff --git a/lib/compiler/test/src/test_labelled_open_graph.cc b/lib/compiler/test/src/test_labelled_open_graph.cc index a3b6319528..74071160cb 100644 --- a/lib/compiler/test/src/test_labelled_open_graph.cc +++ b/lib/compiler/test/src/test_labelled_open_graph.cc @@ -4,113 +4,85 @@ using namespace FlexFlow; -// TEST_CASE("get_subgraph(OpenMultiDiGraphView)") { -// auto g = OpenMultiDiGraph::create(); - -// Node n0 = g.add_node(); -// Node n1 = g.add_node(); -// Node n2 = g.add_node(); -// Node n3 = g.add_node(); -// Node n4 = g.add_node(); - -// NodePort p0 = g.add_node_port(); -// NodePort p1 = g.add_node_port(); -// NodePort p2 = g.add_node_port(); -// NodePort p3 = g.add_node_port(); -// NodePort p4 = 
g.add_node_port(); -// NodePort p5 = g.add_node_port(); -// NodePort p6 = g.add_node_port(); -// NodePort p7 = g.add_node_port(); -// NodePort p8 = g.add_node_port(); -// NodePort p9 = g.add_node_port(); - -// MultiDiEdge e0{n1, p1, n0, p0}; -// MultiDiEdge e1{n2, p2, n0, p0}; -// MultiDiEdge e2{n3, p5, n1, p3}; -// MultiDiEdge e3{n3, p6, n2, p4}; -// MultiDiEdge e4{n4, p8, n3, p7}; -// OutputMultiDiEdge e5{n4, p9, std::make_pair(p9.value(), p9.value())}; - -// g.add_edge(e0); -// g.add_edge(e1); -// g.add_edge(e2); -// g.add_edge(e3); -// g.add_edge(e4); -// g.add_edge(e5); - -// std::unordered_set node_set0{n3, n4}; - -// auto subgraph0 = get_subgraph(g, node_set0); -// auto subgraph1 = get_subgraph(g, node_set0); -// auto subgraph2 = get_subgraph(g, -// node_set0); auto subgraph3 = get_subgraph(g, -// node_set0); - -// CHECK(get_nodes(subgraph0) == node_set0); -// CHECK(get_nodes(subgraph1) == node_set0); -// CHECK(get_nodes(subgraph2) == node_set0); -// CHECK(get_nodes(subgraph3) == node_set0); - -// std::unordered_set input_set{split_edge(e2).second, -// split_edge(e3).second}; -// std::unordered_set output_set{e5}; - -// CHECK(bool(get_open_inputs(subgraph0) == input_set)); -// CHECK(bool(get_open_inputs(subgraph1) == input_set)); -// CHECK(bool(get_open_inputs(subgraph2).empty())); -// CHECK(bool(get_open_inputs(subgraph3).empty())); - -// CHECK(bool(get_open_outputs(subgraph0) == output_set)); -// CHECK(bool(get_open_outputs(subgraph1).empty())); -// CHECK(bool(get_open_outputs(subgraph2) == output_set)); -// CHECK(bool(get_open_outputs(subgraph3).empty())); - -// CHECK(bool(get_edges(subgraph0) == -// std::unordered_set{ -// split_edge(e2).second, split_edge(e3).second, e4, e5})); -// CHECK(bool(get_edges(subgraph1) == -// std::unordered_set{ -// split_edge(e2).second, split_edge(e3).second, e4})); -// CHECK(bool(get_edges(subgraph2) == -// std::unordered_set{e4, e5})); -// CHECK(bool(get_edges(subgraph3) == -// std::unordered_set{e4})); - -// 
CHECK(get_closed_sources(subgraph2) == std::unordered_set{n3}); -// } - -// TEST_CASE("view OutputLabelledMultiDiGraph as open") { -// OutputLabelledMultiDiGraph g = -// OutputLabelledMultiDiGraph::create>(); - -// Node n0 = g.add_node(0); -// Node n1 = g.add_node(1); - -// NodePort p0 = g.add_node_port(); -// NodePort p1 = g.add_node_port(); - -// MultiDiEdge e0{n1, p1, n0, p0}; - -// g.add_edge(e0); -// g.add_output(e0, 2); - -// CHECK(get_edges(g).size() == 1); - -// OutputLabelledOpenMultiDiGraphView open_graph = -// view_output_labelled_as_output_labelled_open(g); - -// CHECK(open_graph.at(n0) == 0); -// CHECK(open_graph.at(n1) == 1); -// CHECK(open_graph.at(e0) == 2); - -// // CHECK(get_edges(open_graph).size() == 1); -// } +TEST_CASE("get_subgraph(OpenMultiDiGraphView)") { + auto g = OpenMultiDiGraph::create(); -TEST_CASE("OutputLabelledOpenMultiDiGraph") { - OutputLabelledOpenMultiDiGraph g = - OutputLabelledOpenMultiDiGraph::create< - UnorderedOutputLabelledOpenMultiDiGraph>(); + Node n0 = g.add_node(); + Node n1 = g.add_node(); + Node n2 = g.add_node(); + Node n3 = g.add_node(); + Node n4 = g.add_node(); + + NodePort p0 = g.add_node_port(); + NodePort p1 = g.add_node_port(); + NodePort p2 = g.add_node_port(); + NodePort p3 = g.add_node_port(); + NodePort p4 = g.add_node_port(); + NodePort p5 = g.add_node_port(); + NodePort p6 = g.add_node_port(); + NodePort p7 = g.add_node_port(); + NodePort p8 = g.add_node_port(); + NodePort p9 = g.add_node_port(); + + MultiDiEdge e0{n1, p1, n0, p0}; + MultiDiEdge e1{n2, p2, n0, p0}; + MultiDiEdge e2{n3, p5, n1, p3}; + MultiDiEdge e3{n3, p6, n2, p4}; + MultiDiEdge e4{n4, p8, n3, p7}; + OutputMultiDiEdge e5{n4, p9, std::make_pair(p9.value(), p9.value())}; + + g.add_edge(e0); + g.add_edge(e1); + g.add_edge(e2); + g.add_edge(e3); + g.add_edge(e4); + g.add_edge(e5); + + std::unordered_set node_set0{n3, n4}; + + auto subgraph0 = get_subgraph(g, node_set0); + auto subgraph1 = get_subgraph(g, node_set0); + auto subgraph2 = 
get_subgraph(g, + node_set0); auto subgraph3 = get_subgraph(g, + node_set0); + + CHECK(get_nodes(subgraph0) == node_set0); + CHECK(get_nodes(subgraph1) == node_set0); + CHECK(get_nodes(subgraph2) == node_set0); + CHECK(get_nodes(subgraph3) == node_set0); + + std::unordered_set input_set{split_edge(e2).second, + split_edge(e3).second}; + std::unordered_set output_set{e5}; + + CHECK(bool(get_open_inputs(subgraph0) == input_set)); + CHECK(bool(get_open_inputs(subgraph1) == input_set)); + CHECK(bool(get_open_inputs(subgraph2).empty())); + CHECK(bool(get_open_inputs(subgraph3).empty())); + + CHECK(bool(get_open_outputs(subgraph0) == output_set)); + CHECK(bool(get_open_outputs(subgraph1).empty())); + CHECK(bool(get_open_outputs(subgraph2) == output_set)); + CHECK(bool(get_open_outputs(subgraph3).empty())); + + CHECK(bool(get_edges(subgraph0) == + std::unordered_set{ + split_edge(e2).second, split_edge(e3).second, e4, e5})); + CHECK(bool(get_edges(subgraph1) == + std::unordered_set{ + split_edge(e2).second, split_edge(e3).second, e4})); + CHECK(bool(get_edges(subgraph2) == + std::unordered_set{e4, e5})); + CHECK(bool(get_edges(subgraph3) == + std::unordered_set{e4})); + + CHECK(get_closed_sources(subgraph2) == std::unordered_set{n3}); +} + +TEST_CASE("view OutputLabelledMultiDiGraph as open") { + OutputLabelledMultiDiGraph g = + OutputLabelledMultiDiGraph::create>(); Node n0 = g.add_node(0); Node n1 = g.add_node(1); @@ -121,24 +93,36 @@ TEST_CASE("OutputLabelledOpenMultiDiGraph") { MultiDiEdge e0{n1, p1, n0, p0}; g.add_edge(e0); - g.add_label(e0, 2); + g.add_output(e0, 2); - CHECK(g.query_edges(OpenMultiDiEdgeQuery::all()).size() == 1); CHECK(get_edges(g).size() == 1); + + OutputLabelledOpenMultiDiGraphView open_graph = + view_output_labelled_as_output_labelled_open(g); + + CHECK(open_graph.at(n0) == 0); + CHECK(open_graph.at(n1) == 1); + CHECK(open_graph.at(e0) == 2); + + CHECK(get_edges(open_graph).size() == 1); } -// TEST_CASE("OpenMultiDiGraph") { -// OpenMultiDiGraph 
g = OpenMultiDiGraph::create(); +TEST_CASE("OutputLabelledOpenMultiDiGraph") { + OutputLabelledOpenMultiDiGraph g = + OutputLabelledOpenMultiDiGraph::create< + UnorderedOutputLabelledOpenMultiDiGraph>(); -// Node n0 = g.add_node(); -// Node n1 = g.add_node(); + Node n0 = g.add_node(0); + Node n1 = g.add_node(1); -// NodePort p0 = g.add_node_port(); -// NodePort p1 = g.add_node_port(); + NodePort p0 = g.add_node_port(); + NodePort p1 = g.add_node_port(); -// MultiDiEdge e0{n1, p1, n0, p0}; + MultiDiEdge e0{n1, p1, n0, p0}; -// g.add_edge(e0); + g.add_edge(e0); + g.add_label(e0, 2); -// CHECK(get_edges(g).size() == 1); -// } + CHECK(g.query_edges(OpenMultiDiEdgeQuery::all()).size() == 1); + CHECK(get_edges(g).size() == 1); +} diff --git a/lib/utils/include/utils/graph/labelled/node_labelled.h b/lib/utils/include/utils/graph/labelled/node_labelled.h index 9d8874fb14..9aed91f107 100644 --- a/lib/utils/include/utils/graph/labelled/node_labelled.h +++ b/lib/utils/include/utils/graph/labelled/node_labelled.h @@ -41,7 +41,7 @@ struct NodeLabelledMultiDiGraphView : virtual public MultiDiGraphView { private: Interface const &get_ptr() const { - return *std::reinterpret_pointer_cast( + return *std::dynamic_pointer_cast( GraphView::ptr.get()); } }; @@ -97,12 +97,12 @@ struct NodeLabelledMultiDiGraph NodeLabelledMultiDiGraph(cow_ptr_t ptr) : GraphView(ptr) {} Interface &get_ptr() { - return *std::reinterpret_pointer_cast( + return *std::dynamic_pointer_cast( GraphView::ptr.get_mutable()); } Interface const &get_ptr() const { - return *std::reinterpret_pointer_cast( + return *std::dynamic_pointer_cast( GraphView::ptr.get()); } }; diff --git a/lib/utils/include/utils/graph/labelled/node_labelled_open.h b/lib/utils/include/utils/graph/labelled/node_labelled_open.h index 826a8387cb..0fea57cab7 100644 --- a/lib/utils/include/utils/graph/labelled/node_labelled_open.h +++ b/lib/utils/include/utils/graph/labelled/node_labelled_open.h @@ -55,7 +55,7 @@ struct 
NodeLabelledOpenMultiDiGraphView private: Interface const &get_ptr() const { - return *std::reinterpret_pointer_cast( + return *std::dynamic_pointer_cast( GraphView::ptr.get()); } }; @@ -121,12 +121,12 @@ struct NodeLabelledOpenMultiDiGraph NodeLabelledOpenMultiDiGraph(cow_ptr_t ptr) : GraphView(ptr) {} Interface &get_ptr() { - return *std::reinterpret_pointer_cast( + return *std::dynamic_pointer_cast( GraphView::ptr.get_mutable()); } Interface const &get_ptr() const { - return *std::reinterpret_pointer_cast( + return *std::dynamic_pointer_cast( GraphView::ptr.get()); } }; diff --git a/lib/utils/include/utils/graph/labelled/output_labelled.h b/lib/utils/include/utils/graph/labelled/output_labelled.h index c6c521c38b..8aab0320b5 100644 --- a/lib/utils/include/utils/graph/labelled/output_labelled.h +++ b/lib/utils/include/utils/graph/labelled/output_labelled.h @@ -47,7 +47,7 @@ struct OutputLabelledMultiDiGraphView private: Interface const &get_ptr() const { - return *std::reinterpret_pointer_cast( + return *std::dynamic_pointer_cast( GraphView::ptr.get()); } }; @@ -119,12 +119,12 @@ struct OutputLabelledMultiDiGraph private: Interface &get_ptr() { - return *std::reinterpret_pointer_cast( + return *std::dynamic_pointer_cast( GraphView::ptr.get_mutable()); } Interface const &get_ptr() const { - return *std::reinterpret_pointer_cast( + return *std::dynamic_pointer_cast( GraphView::ptr.get()); } }; diff --git a/lib/utils/include/utils/graph/labelled/output_labelled_open.h b/lib/utils/include/utils/graph/labelled/output_labelled_open.h index 24235bee4c..2be56cb477 100644 --- a/lib/utils/include/utils/graph/labelled/output_labelled_open.h +++ b/lib/utils/include/utils/graph/labelled/output_labelled_open.h @@ -64,7 +64,7 @@ struct OutputLabelledOpenMultiDiGraphView private: Interface const &get_ptr() const { - return *std::reinterpret_pointer_cast( + return *std::dynamic_pointer_cast( GraphView::ptr.get()); } }; @@ -145,12 +145,12 @@ struct OutputLabelledOpenMultiDiGraph 
OutputLabelledOpenMultiDiGraph(cow_ptr_t ptr) : GraphView(ptr) {} Interface &get_ptr() { - return *std::reinterpret_pointer_cast( + return *std::dynamic_pointer_cast( GraphView::ptr.get_mutable()); } Interface const &get_ptr() const { - return *std::reinterpret_pointer_cast( + return *std::dynamic_pointer_cast( GraphView::ptr.get()); } }; diff --git a/lib/utils/include/utils/graph/labelled/standard_labelled.h b/lib/utils/include/utils/graph/labelled/standard_labelled.h index e1c8e91634..c6d1521471 100644 --- a/lib/utils/include/utils/graph/labelled/standard_labelled.h +++ b/lib/utils/include/utils/graph/labelled/standard_labelled.h @@ -47,7 +47,7 @@ struct LabelledMultiDiGraphView : NodeLabelledMultiDiGraphView(ptr) {} Interface const &get_ptr() const { - return *std::reinterpret_pointer_cast( + return *std::dynamic_pointer_cast( GraphView::ptr.get()); } }; @@ -104,12 +104,12 @@ struct LabelledMultiDiGraph LabelledMultiDiGraph(cow_ptr_t ptr) : GraphView(ptr) {} Interface &get_ptr() { - return *std::reinterpret_pointer_cast( + return *std::dynamic_pointer_cast( GraphView::ptr.get_mutable()); } Interface const &get_ptr() const { - return *std::reinterpret_pointer_cast( + return *std::dynamic_pointer_cast( GraphView::ptr.get()); } }; diff --git a/lib/utils/src/graph/digraph.cc b/lib/utils/src/graph/digraph.cc index dda9eef5e0..ecad1db3f0 100644 --- a/lib/utils/src/graph/digraph.cc +++ b/lib/utils/src/graph/digraph.cc @@ -14,7 +14,7 @@ std::unordered_set } IDiGraphView const &DiGraphView::get_ptr() const { - return *std::reinterpret_pointer_cast( + return *std::dynamic_pointer_cast( GraphView::ptr.get()); } @@ -48,11 +48,11 @@ std::unordered_set } IDiGraph &DiGraph::get_ptr() { - return *std::reinterpret_pointer_cast(GraphView::ptr.get_mutable()); + return *std::dynamic_pointer_cast(GraphView::ptr.get_mutable()); } IDiGraph const &DiGraph::get_ptr() const { - return *std::reinterpret_pointer_cast( + return *std::dynamic_pointer_cast( GraphView::ptr.get_mutable()); } } 
// namespace FlexFlow diff --git a/lib/utils/src/graph/multidigraph.cc b/lib/utils/src/graph/multidigraph.cc index 99a7ea86fa..41ae3e1aa3 100644 --- a/lib/utils/src/graph/multidigraph.cc +++ b/lib/utils/src/graph/multidigraph.cc @@ -24,7 +24,7 @@ std::unordered_set } IMultiDiGraphView const &MultiDiGraphView::get_ptr() const { - return *std::reinterpret_pointer_cast( + return *std::dynamic_pointer_cast( GraphView::ptr.get()); } @@ -66,12 +66,12 @@ std::unordered_set MultiDiGraph::query_nodes(NodeQuery const &q) const { } IMultiDiGraph const &MultiDiGraph::get_ptr() const { - return *std::reinterpret_pointer_cast( + return *std::dynamic_pointer_cast( GraphView::ptr.get()); } IMultiDiGraph &MultiDiGraph::get_ptr() { - return *std::reinterpret_pointer_cast( + return *std::dynamic_pointer_cast( GraphView::ptr.get_mutable()); } diff --git a/lib/utils/src/graph/node.cc b/lib/utils/src/graph/node.cc index 9854afffbf..72caa3136e 100644 --- a/lib/utils/src/graph/node.cc +++ b/lib/utils/src/graph/node.cc @@ -53,11 +53,11 @@ std::unordered_set Graph::query_nodes(NodeQuery const &q) const { } IGraph const &Graph::get_ptr() const { - return *std::reinterpret_pointer_cast(GraphView::ptr.get()); + return *std::dynamic_pointer_cast(GraphView::ptr.get()); } IGraph &Graph::get_ptr() { - return *std::reinterpret_pointer_cast(GraphView::ptr.get_mutable()); + return *std::dynamic_pointer_cast(GraphView::ptr.get_mutable()); } } // namespace FlexFlow diff --git a/lib/utils/src/graph/open_graphs.cc b/lib/utils/src/graph/open_graphs.cc index c32ff6ded5..387dd7e75b 100644 --- a/lib/utils/src/graph/open_graphs.cc +++ b/lib/utils/src/graph/open_graphs.cc @@ -22,7 +22,7 @@ std::unordered_set } IOpenMultiDiGraphView const &OpenMultiDiGraphView::get_ptr() const { - return *std::reinterpret_pointer_cast( + return *std::dynamic_pointer_cast( GraphView::ptr.get()); } @@ -56,12 +56,12 @@ NodePort OpenMultiDiGraph::add_node_port() { } IOpenMultiDiGraph &OpenMultiDiGraph::get_ptr() { - return 
*std::reinterpret_pointer_cast( + return *std::dynamic_pointer_cast( GraphView::ptr.get_mutable()); } IOpenMultiDiGraph const &OpenMultiDiGraph::get_ptr() const { - return *std::reinterpret_pointer_cast( + return *std::dynamic_pointer_cast( GraphView::ptr.get()); } @@ -77,7 +77,7 @@ std::unordered_set } IUpwardOpenMultiDiGraphView const &UpwardOpenMultiDiGraphView::get_ptr() const { - return *std::reinterpret_pointer_cast( + return *std::dynamic_pointer_cast( GraphView::ptr.get()); } @@ -107,12 +107,12 @@ std::unordered_set UpwardOpenMultiDiGraph::query_edges( } IUpwardOpenMultiDiGraph const &UpwardOpenMultiDiGraph::get_ptr() const { - return *std::reinterpret_pointer_cast( + return *std::dynamic_pointer_cast( GraphView::ptr.get()); } IUpwardOpenMultiDiGraph &UpwardOpenMultiDiGraph::get_ptr() { - return *std::reinterpret_pointer_cast( + return *std::dynamic_pointer_cast( GraphView::ptr.get_mutable()); } @@ -129,7 +129,7 @@ std::unordered_set IDownwardOpenMultiDiGraphView const & DownwardOpenMultiDiGraphView::get_ptr() const { - return *std::reinterpret_pointer_cast( + return *std::dynamic_pointer_cast( GraphView::ptr.get()); } @@ -165,12 +165,12 @@ std::unordered_set } IDownwardOpenMultiDiGraph &DownwardOpenMultiDiGraph::get_ptr() { - return *std::reinterpret_pointer_cast( + return *std::dynamic_pointer_cast( GraphView::ptr.get_mutable()); } IDownwardOpenMultiDiGraph const &DownwardOpenMultiDiGraph::get_ptr() const { - return *std::reinterpret_pointer_cast( + return *std::dynamic_pointer_cast( GraphView::ptr.get()); } diff --git a/lib/utils/src/graph/undirected.cc b/lib/utils/src/graph/undirected.cc index ce42cfe22c..b1e8be7f14 100644 --- a/lib/utils/src/graph/undirected.cc +++ b/lib/utils/src/graph/undirected.cc @@ -26,12 +26,12 @@ void UndirectedGraph::remove_edge(UndirectedEdge const &e) { } IUndirectedGraph const &UndirectedGraph::get_ptr() const { - return *std::reinterpret_pointer_cast( + return *std::dynamic_pointer_cast( GraphView::ptr.get()); } 
IUndirectedGraph &UndirectedGraph::get_ptr() { - return *std::reinterpret_pointer_cast( + return *std::dynamic_pointer_cast( GraphView::ptr.get_mutable()); } @@ -56,7 +56,7 @@ std::unordered_set } IUndirectedGraphView const &UndirectedGraphView::get_ptr() const { - return *std::reinterpret_pointer_cast( + return *std::dynamic_pointer_cast( GraphView::ptr.get()); } From 54c604af83abf974d90c3c08401f930be72f242f Mon Sep 17 00:00:00 2001 From: Colin Unger Date: Sat, 23 Mar 2024 22:10:53 -0700 Subject: [PATCH 23/37] Fix spdlog cmake issue --- cmake/spdlog.cmake | 7 ++++--- flake.nix | 2 +- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/cmake/spdlog.cmake b/cmake/spdlog.cmake index 02021fd51e..5ba1d6cc15 100644 --- a/cmake/spdlog.cmake +++ b/cmake/spdlog.cmake @@ -6,6 +6,7 @@ else() add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/deps/spdlog) endif() -add_library(spdlog INTERFACE) -target_link_libraries(spdlog INTERFACE spdlog::spdlog) -target_compile_definitions(spdlog INTERFACE SPDLOG_FMT_EXTERNAL) +add_library(ff_spdlog INTERFACE) +target_link_libraries(ff_spdlog INTERFACE spdlog::spdlog) +target_compile_definitions(ff_spdlog INTERFACE SPDLOG_FMT_EXTERNAL) +alias_library(spdlog ff_spdlog) diff --git a/flake.nix b/flake.nix index 595fedac46..d402d3c271 100644 --- a/flake.nix +++ b/flake.nix @@ -103,7 +103,7 @@ buildInputs = builtins.concatLists [ (with pkgs; [ - ccls + clang-tools gh-markdown-preview shellcheck plantuml From 8b914cf1e79655ab0f0ab2f5accbdb8282847480 Mon Sep 17 00:00:00 2001 From: Colin Unger Date: Sat, 23 Mar 2024 22:17:37 -0700 Subject: [PATCH 24/37] Re-remove submodules --- deps/any | 1 - deps/boost_preprocessor | 1 - deps/googletest | 1 - deps/invoke | 1 - deps/nameof | 1 - deps/optional | 1 - deps/pybind11 | 1 - deps/variant | 1 - 8 files changed, 8 deletions(-) delete mode 160000 deps/any delete mode 160000 deps/boost_preprocessor delete mode 160000 deps/googletest delete mode 160000 deps/invoke delete mode 160000 deps/nameof delete 
mode 160000 deps/optional delete mode 160000 deps/pybind11 delete mode 160000 deps/variant diff --git a/deps/any b/deps/any deleted file mode 160000 index e88b1bfc16..0000000000 --- a/deps/any +++ /dev/null @@ -1 +0,0 @@ -Subproject commit e88b1bfc160fa9b01e6174dd29c812eeeece3be9 diff --git a/deps/boost_preprocessor b/deps/boost_preprocessor deleted file mode 160000 index 667e87b339..0000000000 --- a/deps/boost_preprocessor +++ /dev/null @@ -1 +0,0 @@ -Subproject commit 667e87b3392db338a919cbe0213979713aca52e3 diff --git a/deps/googletest b/deps/googletest deleted file mode 160000 index 2fe3bd994b..0000000000 --- a/deps/googletest +++ /dev/null @@ -1 +0,0 @@ -Subproject commit 2fe3bd994b3189899d93f1d5a881e725e046fdc2 diff --git a/deps/invoke b/deps/invoke deleted file mode 160000 index 2c1eabc2e2..0000000000 --- a/deps/invoke +++ /dev/null @@ -1 +0,0 @@ -Subproject commit 2c1eabc2e20ab02961f95c704ff0c0818671ddd1 diff --git a/deps/nameof b/deps/nameof deleted file mode 160000 index 8aeb677413..0000000000 --- a/deps/nameof +++ /dev/null @@ -1 +0,0 @@ -Subproject commit 8aeb6774132a01765d8c8679d016b728acd069f5 diff --git a/deps/optional b/deps/optional deleted file mode 160000 index c28fcf74d2..0000000000 --- a/deps/optional +++ /dev/null @@ -1 +0,0 @@ -Subproject commit c28fcf74d207fc667c4ed3dbae4c251ea551c8c1 diff --git a/deps/pybind11 b/deps/pybind11 deleted file mode 160000 index 8de7772cc7..0000000000 --- a/deps/pybind11 +++ /dev/null @@ -1 +0,0 @@ -Subproject commit 8de7772cc72daca8e947b79b83fea46214931604 diff --git a/deps/variant b/deps/variant deleted file mode 160000 index 23cb94f027..0000000000 --- a/deps/variant +++ /dev/null @@ -1 +0,0 @@ -Subproject commit 23cb94f027d4ef33bf48133acc2695c7e5c6f1e7 From 189f32303c9b19739837fed921de94785b1c6c10 Mon Sep 17 00:00:00 2001 From: wmdi Date: Sun, 24 Mar 2024 16:24:01 -0400 Subject: [PATCH 25/37] minor fix & fmt --- ...ive_logger.cc => recursive_logger.cc.todo} | 0 ...rsive_logger.h => recursive_logger.h.todo} | 0 
.../test/src/test_labelled_open_graph.cc | 36 +++--- lib/compiler/test/src/test_optimal_cost.cc | 2 +- .../include/op-attrs/operator_attrs.h | 4 +- .../include/substitutions/get_attribute.h | 56 ++++----- .../include/substitutions/operator_pattern.h | 31 ++--- lib/substitutions/src/graph_pattern.cc | 20 +-- lib/substitutions/src/operator_attributes.cc | 56 ++++----- lib/substitutions/src/substitution.cc | 117 ++++++++++-------- lib/utils/include/utils/containers.decl.h | 2 +- lib/utils/include/utils/dot_file.h | 2 +- .../utils/graph/labelled/node_labelled.h | 9 +- .../utils/graph/labelled/node_labelled_open.h | 9 +- .../utils/graph/labelled/output_labelled.h | 9 +- .../graph/labelled/output_labelled_open.h | 9 +- .../utils/graph/labelled/standard_labelled.h | 9 +- lib/utils/include/utils/variant.h | 13 +- lib/utils/src/graph/digraph.cc | 3 +- lib/utils/src/graph/multidigraph.cc | 3 +- lib/utils/test/src/test_variant.cc | 15 ++- 21 files changed, 203 insertions(+), 202 deletions(-) rename lib/compiler/src/utils/{recursive_logger.cc => recursive_logger.cc.todo} (100%) rename lib/compiler/src/utils/{recursive_logger.h => recursive_logger.h.todo} (100%) diff --git a/lib/compiler/src/utils/recursive_logger.cc b/lib/compiler/src/utils/recursive_logger.cc.todo similarity index 100% rename from lib/compiler/src/utils/recursive_logger.cc rename to lib/compiler/src/utils/recursive_logger.cc.todo diff --git a/lib/compiler/src/utils/recursive_logger.h b/lib/compiler/src/utils/recursive_logger.h.todo similarity index 100% rename from lib/compiler/src/utils/recursive_logger.h rename to lib/compiler/src/utils/recursive_logger.h.todo diff --git a/lib/compiler/test/src/test_labelled_open_graph.cc b/lib/compiler/test/src/test_labelled_open_graph.cc index 74071160cb..c59d7ee78a 100644 --- a/lib/compiler/test/src/test_labelled_open_graph.cc +++ b/lib/compiler/test/src/test_labelled_open_graph.cc @@ -1,6 +1,6 @@ #include "compiler/unity_algorithm.h" #include "doctest/doctest.h" 
-#include "rapidcheck.h" +// #include "rapidcheck.h" using namespace FlexFlow; @@ -42,14 +42,13 @@ TEST_CASE("get_subgraph(OpenMultiDiGraphView)") { auto subgraph0 = get_subgraph(g, node_set0); auto subgraph1 = get_subgraph(g, node_set0); - auto subgraph2 = get_subgraph(g, - node_set0); auto subgraph3 = get_subgraph(g, - node_set0); + auto subgraph2 = get_subgraph(g, node_set0); + auto subgraph3 = get_subgraph(g, node_set0); - CHECK(get_nodes(subgraph0) == node_set0); - CHECK(get_nodes(subgraph1) == node_set0); - CHECK(get_nodes(subgraph2) == node_set0); - CHECK(get_nodes(subgraph3) == node_set0); + CHECK(bool(get_nodes(subgraph0) == node_set0)); + CHECK(bool(get_nodes(subgraph1) == node_set0)); + CHECK(bool(get_nodes(subgraph2) == node_set0)); + CHECK(bool(get_nodes(subgraph3) == node_set0)); std::unordered_set input_set{split_edge(e2).second, split_edge(e3).second}; @@ -73,16 +72,15 @@ TEST_CASE("get_subgraph(OpenMultiDiGraphView)") { split_edge(e2).second, split_edge(e3).second, e4})); CHECK(bool(get_edges(subgraph2) == std::unordered_set{e4, e5})); - CHECK(bool(get_edges(subgraph3) == - std::unordered_set{e4})); + CHECK(bool(get_edges(subgraph3) == std::unordered_set{e4})); - CHECK(get_closed_sources(subgraph2) == std::unordered_set{n3}); + CHECK(bool(get_closed_sources(subgraph2) == std::unordered_set{n3})); } TEST_CASE("view OutputLabelledMultiDiGraph as open") { OutputLabelledMultiDiGraph g = - OutputLabelledMultiDiGraph::create>(); + OutputLabelledMultiDiGraph::create< + UnorderedOutputLabelledMultiDiGraph>(); Node n0 = g.add_node(0); Node n1 = g.add_node(1); @@ -95,14 +93,14 @@ TEST_CASE("view OutputLabelledMultiDiGraph as open") { g.add_edge(e0); g.add_output(e0, 2); - CHECK(get_edges(g).size() == 1); + CHECK(bool(get_edges(g).size() == 1)); OutputLabelledOpenMultiDiGraphView open_graph = view_output_labelled_as_output_labelled_open(g); - CHECK(open_graph.at(n0) == 0); - CHECK(open_graph.at(n1) == 1); - CHECK(open_graph.at(e0) == 2); + 
CHECK(bool(open_graph.at(n0) == 0)); + CHECK(bool(open_graph.at(n1) == 1)); + CHECK(bool(open_graph.at(e0) == 2)); CHECK(get_edges(open_graph).size() == 1); } @@ -123,6 +121,6 @@ TEST_CASE("OutputLabelledOpenMultiDiGraph") { g.add_edge(e0); g.add_label(e0, 2); - CHECK(g.query_edges(OpenMultiDiEdgeQuery::all()).size() == 1); - CHECK(get_edges(g).size() == 1); + CHECK(bool(g.query_edges(OpenMultiDiEdgeQuery::all()).size() == 1)); + CHECK(bool(get_edges(g).size() == 1)); } diff --git a/lib/compiler/test/src/test_optimal_cost.cc b/lib/compiler/test/src/test_optimal_cost.cc index 9d90285870..5f5f7d093e 100644 --- a/lib/compiler/test/src/test_optimal_cost.cc +++ b/lib/compiler/test/src/test_optimal_cost.cc @@ -34,7 +34,7 @@ TEST_CASE("optimal_cost_0") { Node n0 = pcg.add_node(Operator{InputAttrs{}, "input"}); Node n1 = pcg.add_node(Operator{ - LinearAttrs{1, false, DataType::FLOAT, Activation::RELU, nullopt}, + LinearAttrs{1, false, DataType::FLOAT, Activation::RELU, std::nullopt}, "linear"}); MultiDiEdge e{n1, pcg.add_node_port(), n0, pcg.add_node_port()}; diff --git a/lib/op-attrs/include/op-attrs/operator_attrs.h b/lib/op-attrs/include/op-attrs/operator_attrs.h index 678a049c3b..b63563cd67 100644 --- a/lib/op-attrs/include/op-attrs/operator_attrs.h +++ b/lib/op-attrs/include/op-attrs/operator_attrs.h @@ -86,8 +86,8 @@ static_assert(is_valid_opattr::value, ""); static_assert(is_valid_opattr::value, ""); static_assert(is_valid_opattr::value, ""); -using ParallelOperatorAttrs = - std::variant; +using ParallelOperatorAttrs = std:: + variant; using ComputationGraphAttrs = variant_join>; diff --git a/lib/substitutions/include/substitutions/get_attribute.h b/lib/substitutions/include/substitutions/get_attribute.h index 7088730c53..0e6dd4c69b 100644 --- a/lib/substitutions/include/substitutions/get_attribute.h +++ b/lib/substitutions/include/substitutions/get_attribute.h @@ -8,57 +8,57 @@ namespace FlexFlow { std::optional get_attribute(PCGOperatorAttrs const &, - 
OperatorAttributeKey); + OperatorAttributeKey); std::optional get_attribute(BatchMatmulAttrs const &p, - OperatorAttributeKey); + OperatorAttributeKey); std::optional get_attribute(CastAttrs const &p, - OperatorAttributeKey); + OperatorAttributeKey); std::optional get_attribute(CombineAttrs const &p, - OperatorAttributeKey); + OperatorAttributeKey); std::optional get_attribute(ConcatAttrs const &p, - OperatorAttributeKey); + OperatorAttributeKey); std::optional get_attribute(Conv2DAttrs const &p, - OperatorAttributeKey); + OperatorAttributeKey); std::optional get_attribute(ElementBinaryAttrs const &p, - OperatorAttributeKey); + OperatorAttributeKey); std::optional get_attribute(ElementUnaryAttrs const &p, - OperatorAttributeKey); + OperatorAttributeKey); std::optional get_attribute(DropoutAttrs const &p, - OperatorAttributeKey); -std::optional get_attribute(ElementScalarUnaryAttrs const &p, - OperatorAttributeKey); + OperatorAttributeKey); +std::optional + get_attribute(ElementScalarUnaryAttrs const &p, OperatorAttributeKey); std::optional get_attribute(EmbeddingAttrs const &p, - OperatorAttributeKey); + OperatorAttributeKey); std::optional get_attribute(FlatAttrs const &p, - OperatorAttributeKey); + OperatorAttributeKey); std::optional get_attribute(GatherAttrs const &p, - OperatorAttributeKey); + OperatorAttributeKey); std::optional get_attribute(LayerNormAttrs const &p, - OperatorAttributeKey); + OperatorAttributeKey); std::optional get_attribute(LinearAttrs const &p, - OperatorAttributeKey); -std::optional get_attribute(MultiHeadAttentionAttrs const &p, - OperatorAttributeKey); + OperatorAttributeKey); +std::optional + get_attribute(MultiHeadAttentionAttrs const &p, OperatorAttributeKey); std::optional get_attribute(Pool2DAttrs const &p, - OperatorAttributeKey); + OperatorAttributeKey); std::optional get_attribute(ReduceAttrs const &p, - OperatorAttributeKey); + OperatorAttributeKey); std::optional get_attribute(ReductionAttrs const &p, - OperatorAttributeKey); 
+ OperatorAttributeKey); std::optional get_attribute(RepartitionAttrs const &p, - OperatorAttributeKey); + OperatorAttributeKey); std::optional get_attribute(ReplicateAttrs const &p, - OperatorAttributeKey); + OperatorAttributeKey); std::optional get_attribute(ReshapeAttrs const &p, - OperatorAttributeKey); + OperatorAttributeKey); std::optional get_attribute(SplitAttrs const &p, - OperatorAttributeKey); + OperatorAttributeKey); std::optional get_attribute(SoftmaxAttrs const &p, - OperatorAttributeKey); + OperatorAttributeKey); std::optional get_attribute(TopKAttrs const &p, - OperatorAttributeKey); + OperatorAttributeKey); std::optional get_attribute(TransposeAttrs const &p, - OperatorAttributeKey); + OperatorAttributeKey); // optional get_attribute(FusedParallelOpAttrs const &p, // OperatorAttributeKey); diff --git a/lib/substitutions/include/substitutions/operator_pattern.h b/lib/substitutions/include/substitutions/operator_pattern.h index 35544f3003..8fc4ebefc2 100644 --- a/lib/substitutions/include/substitutions/operator_pattern.h +++ b/lib/substitutions/include/substitutions/operator_pattern.h @@ -70,21 +70,22 @@ enum class OperatorAttributeKey { NUM_INPUTS }; -using OperatorAttributeValue = std::variant, - stack_vector, - OperatorType, - Activation, - ff_dim_t, - unsigned long long, - AggregateOp, - stack_vector, - std::optional, - PoolOp, - TensorShape, - DataType>; +using OperatorAttributeValue = + std::variant, + stack_vector, + OperatorType, + Activation, + ff_dim_t, + unsigned long long, + AggregateOp, + stack_vector, + std::optional, + PoolOp, + TensorShape, + DataType>; FF_VISITABLE_STRUCT(ListIndexAccess, attribute_key, diff --git a/lib/substitutions/src/graph_pattern.cc b/lib/substitutions/src/graph_pattern.cc index 6f933dd300..296a975626 100644 --- a/lib/substitutions/src/graph_pattern.cc +++ b/lib/substitutions/src/graph_pattern.cc @@ -14,7 +14,8 @@ std::optional std::optional const &v) { if (!v.has_value() || !std::holds_alternative>(v.value()) 
|| - !std::holds_alternative>(v.value())) { + !std::holds_alternative>( + v.value())) { return std::nullopt; } @@ -62,7 +63,8 @@ std::optional struct EvaluateOperatorAttributeExpr { EvaluateOperatorAttributeExpr(Operator const &attrs) : attrs(attrs) {} - std::optional operator()(OperatorAttributeKey const &key) { + std::optional + operator()(OperatorAttributeKey const &key) { return get_attribute(this->attrs.attrs, key); } @@ -148,8 +150,8 @@ std::optional template std::optional satisfies(ConstraintType constraint_type, - V const &constraint_value, - std::optional const &maybe_attribute_value) { + V const &constraint_value, + std::optional const &maybe_attribute_value) { if (!maybe_attribute_value.has_value()) { return std::nullopt; } @@ -167,14 +169,14 @@ std::optional satisfies(ConstraintType constraint_type, } std::optional satisfies(ParallelTensor const &tensor_shape, - TensorAttributeConstraint const &constraint) { + TensorAttributeConstraint const &constraint) { auto value = evaluate_attribute_expr(tensor_shape, constraint.attribute_expr); return satisfies( constraint.constraint_type, constraint.attribute_value, value); } std::optional satisfies(Operator const ¶ms, - OperatorAttributeConstraint const &constraint) { + OperatorAttributeConstraint const &constraint) { auto value = evaluate_attribute_expr(params, constraint.attribute_expr); OperatorAttributeValue v = value.value(); return satisfies( @@ -183,7 +185,7 @@ std::optional satisfies(Operator const ¶ms, template std::optional optional_all_of(Container const &container, - Function const &func) { + Function const &func) { for (auto const &element : container) { std::optional condition = func(element); if (!condition.has_value()) { @@ -198,7 +200,7 @@ std::optional optional_all_of(Container const &container, } std::optional satisfies(Operator const ¶ms, - OperatorPattern const &pattern) { + OperatorPattern const &pattern) { return optional_all_of(pattern.attribute_constraints, 
[&](OperatorAttributeConstraint const &c) { return satisfies(params, c); @@ -206,7 +208,7 @@ std::optional satisfies(Operator const ¶ms, } std::optional satisfies(ParallelTensor const ¶ms, - ParallelTensorPattern const &pattern) { + ParallelTensorPattern const &pattern) { return optional_all_of( pattern.attribute_constraints, [&](TensorAttributeConstraint const &c) { return satisfies(params, c); }); diff --git a/lib/substitutions/src/operator_attributes.cc b/lib/substitutions/src/operator_attributes.cc index 76533507a3..8bd8688194 100644 --- a/lib/substitutions/src/operator_attributes.cc +++ b/lib/substitutions/src/operator_attributes.cc @@ -4,7 +4,7 @@ namespace FlexFlow { std::optional get_attribute(BatchMatmulAttrs const &p, - OperatorAttributeKey key) { + OperatorAttributeKey key) { switch (key) { default: return std::nullopt; @@ -12,7 +12,7 @@ std::optional get_attribute(BatchMatmulAttrs const &p, } std::optional get_attribute(CastAttrs const &p, - OperatorAttributeKey key) { + OperatorAttributeKey key) { switch (key) { case OperatorAttributeKey::DATA_TYPE: return p.dtype; @@ -22,7 +22,7 @@ std::optional get_attribute(CastAttrs const &p, } std::optional get_attribute(CombineAttrs const &p, - OperatorAttributeKey key) { + OperatorAttributeKey key) { switch (key) { case OperatorAttributeKey::PARALLEL_OP_DIM: return p.combine_dim; @@ -34,7 +34,7 @@ std::optional get_attribute(CombineAttrs const &p, } std::optional get_attribute(ConcatAttrs const &p, - OperatorAttributeKey key) { + OperatorAttributeKey key) { switch (key) { case OperatorAttributeKey::AXIS: return p.axis; @@ -44,7 +44,7 @@ std::optional get_attribute(ConcatAttrs const &p, } std::optional get_attribute(Conv2DAttrs const &p, - OperatorAttributeKey key) { + OperatorAttributeKey key) { switch (key) { case OperatorAttributeKey::KERNEL_H: return p.kernel_h; @@ -70,7 +70,7 @@ std::optional get_attribute(Conv2DAttrs const &p, } std::optional get_attribute(ElementBinaryAttrs const &p, - OperatorAttributeKey 
key) { + OperatorAttributeKey key) { switch (key) { default: return std::nullopt; @@ -78,15 +78,15 @@ std::optional get_attribute(ElementBinaryAttrs const &p, } std::optional get_attribute(ElementUnaryAttrs const &p, - OperatorAttributeKey key) { + OperatorAttributeKey key) { switch (key) { default: return std::nullopt; } } -std::optional get_attribute(ElementScalarUnaryAttrs const &p, - OperatorAttributeKey key) { +std::optional + get_attribute(ElementScalarUnaryAttrs const &p, OperatorAttributeKey key) { switch (key) { default: return std::nullopt; @@ -94,7 +94,7 @@ std::optional get_attribute(ElementScalarUnaryAttrs cons } std::optional get_attribute(DropoutAttrs const &p, - OperatorAttributeKey key) { + OperatorAttributeKey key) { switch (key) { default: return std::nullopt; @@ -102,7 +102,7 @@ std::optional get_attribute(DropoutAttrs const &p, } std::optional get_attribute(EmbeddingAttrs const &p, - OperatorAttributeKey key) { + OperatorAttributeKey key) { switch (key) { case OperatorAttributeKey::DATA_TYPE: return p.data_type; @@ -118,7 +118,7 @@ std::optional get_attribute(EmbeddingAttrs const &p, } std::optional get_attribute(FlatAttrs const &p, - OperatorAttributeKey key) { + OperatorAttributeKey key) { switch (key) { default: return std::nullopt; @@ -126,7 +126,7 @@ std::optional get_attribute(FlatAttrs const &p, } std::optional get_attribute(GatherAttrs const &p, - OperatorAttributeKey key) { + OperatorAttributeKey key) { switch (key) { case OperatorAttributeKey::AXIS: return p.dim; @@ -136,7 +136,7 @@ std::optional get_attribute(GatherAttrs const &p, } std::optional get_attribute(LayerNormAttrs const &p, - OperatorAttributeKey key) { + OperatorAttributeKey key) { switch (key) { default: return std::nullopt; @@ -144,7 +144,7 @@ std::optional get_attribute(LayerNormAttrs const &p, } std::optional get_attribute(LinearAttrs const &p, - OperatorAttributeKey key) { + OperatorAttributeKey key) { switch (key) { case OperatorAttributeKey::OUT_CHANNELS: return 
p.out_channels; @@ -163,8 +163,8 @@ std::optional get_attribute(LinearAttrs const &p, } } -std::optional get_attribute(MultiHeadAttentionAttrs const &p, - OperatorAttributeKey key) { +std::optional + get_attribute(MultiHeadAttentionAttrs const &p, OperatorAttributeKey key) { switch (key) { case OperatorAttributeKey::NUM_HEADS: return p.num_heads; @@ -176,7 +176,7 @@ std::optional get_attribute(MultiHeadAttentionAttrs cons } std::optional get_attribute(Pool2DAttrs const &p, - OperatorAttributeKey key) { + OperatorAttributeKey key) { switch (key) { case OperatorAttributeKey::KERNEL_H: return p.kernel_h; @@ -200,7 +200,7 @@ std::optional get_attribute(Pool2DAttrs const &p, } std::optional get_attribute(ReduceAttrs const &p, - OperatorAttributeKey key) { + OperatorAttributeKey key) { switch (key) { default: return std::nullopt; @@ -208,7 +208,7 @@ std::optional get_attribute(ReduceAttrs const &p, } std::optional get_attribute(ReductionAttrs const &p, - OperatorAttributeKey key) { + OperatorAttributeKey key) { switch (key) { case OperatorAttributeKey::PARALLEL_OP_DIM: return p.reduction_dim; @@ -220,7 +220,7 @@ std::optional get_attribute(ReductionAttrs const &p, } std::optional get_attribute(RepartitionAttrs const &p, - OperatorAttributeKey key) { + OperatorAttributeKey key) { switch (key) { case OperatorAttributeKey::PARALLEL_OP_DIM: return p.repartition_dim; @@ -232,7 +232,7 @@ std::optional get_attribute(RepartitionAttrs const &p, } std::optional get_attribute(ReplicateAttrs const &p, - OperatorAttributeKey key) { + OperatorAttributeKey key) { switch (key) { case OperatorAttributeKey::PARALLEL_OP_DIM: return p.replicate_dim; @@ -244,7 +244,7 @@ std::optional get_attribute(ReplicateAttrs const &p, } std::optional get_attribute(ReshapeAttrs const &p, - OperatorAttributeKey key) { + OperatorAttributeKey key) { switch (key) { default: return std::nullopt; @@ -252,7 +252,7 @@ std::optional get_attribute(ReshapeAttrs const &p, } std::optional get_attribute(SplitAttrs 
const &p, - OperatorAttributeKey key) { + OperatorAttributeKey key) { switch (key) { case OperatorAttributeKey::AXIS: return p.axis; @@ -262,7 +262,7 @@ std::optional get_attribute(SplitAttrs const &p, } std::optional get_attribute(SoftmaxAttrs const &p, - OperatorAttributeKey key) { + OperatorAttributeKey key) { switch (key) { case OperatorAttributeKey::AXIS: return p.dim; @@ -272,7 +272,7 @@ std::optional get_attribute(SoftmaxAttrs const &p, } std::optional get_attribute(TopKAttrs const &p, - OperatorAttributeKey key) { + OperatorAttributeKey key) { switch (key) { default: return std::nullopt; @@ -280,7 +280,7 @@ std::optional get_attribute(TopKAttrs const &p, } std::optional get_attribute(TransposeAttrs const &p, - OperatorAttributeKey key) { + OperatorAttributeKey key) { switch (key) { case OperatorAttributeKey::PERMUTATION: return p.perm; @@ -309,7 +309,7 @@ struct GetOpType { }; std::optional get_attribute(PCGOperatorAttrs const &p, - OperatorAttributeKey key) { + OperatorAttributeKey key) { if (key == OperatorAttributeKey::OP_TYPE) { return std::visit(GetOpType{}, p); } diff --git a/lib/substitutions/src/substitution.cc b/lib/substitutions/src/substitution.cc index 4f6572948a..15816185ee 100644 --- a/lib/substitutions/src/substitution.cc +++ b/lib/substitutions/src/substitution.cc @@ -120,14 +120,15 @@ Operator get_operator_attrs(SubParallelComputationGraph const &graph, switch (op_type) { case Op::BATCHMATMUL: return Operator{ - BatchMatmulAttrs{ - std::get(assignments.at(OperatorAttributeKey::A_SEQ_LENGTH_DIM)), - std::get(assignments.at(OperatorAttributeKey::B_SEQ_LENGTH_DIM))}, + BatchMatmulAttrs{std::get(assignments.at( + OperatorAttributeKey::A_SEQ_LENGTH_DIM)), + std::get(assignments.at( + OperatorAttributeKey::B_SEQ_LENGTH_DIM))}, std::nullopt}; case Op::BATCHNORM: - return Operator{ - BatchNormAttrs{std::get(assignments.at(OperatorAttributeKey::RELU))}, - std::nullopt}; + return Operator{BatchNormAttrs{std::get( + 
assignments.at(OperatorAttributeKey::RELU))}, + std::nullopt}; case Op::CAST: return Operator{CastAttrs{std::get( assignments.at(OperatorAttributeKey::DATA_TYPE))}, @@ -135,13 +136,13 @@ Operator get_operator_attrs(SubParallelComputationGraph const &graph, case Op::CONCAT: return Operator{ ConcatAttrs{ - std::get(assignments.at(OperatorAttributeKey::AXIS)), + std::get(assignments.at(OperatorAttributeKey::AXIS)), std::get(assignments.at(OperatorAttributeKey::NUM_INPUTS))}, std::nullopt}; case Op::CONV2D: return Operator{ Conv2DAttrs{ - std::get(assignments.at(OperatorAttributeKey::OUT_CHANNELS)), + std::get(assignments.at(OperatorAttributeKey::OUT_CHANNELS)), std::get(assignments.at(OperatorAttributeKey::KERNEL_H)), std::get(assignments.at(OperatorAttributeKey::KERNEL_W)), std::get(assignments.at(OperatorAttributeKey::STRIDE_H)), @@ -149,15 +150,16 @@ Operator get_operator_attrs(SubParallelComputationGraph const &graph, std::get(assignments.at(OperatorAttributeKey::PADDING_H)), std::get(assignments.at(OperatorAttributeKey::PADDING_W)), std::get(assignments.at(OperatorAttributeKey::GROUPS)), - std::get(assignments.at(OperatorAttributeKey::ACTIVATION)), + std::get( + assignments.at(OperatorAttributeKey::ACTIVATION)), std::get(assignments.at(OperatorAttributeKey::USE_BIAS))}, std::nullopt}; case Op::DROPOUT: - return Operator{ - DropoutAttrs{std::get(assignments.at(OperatorAttributeKey::RATE)), - std::get( - assignments.at(OperatorAttributeKey::SEED))}, - std::nullopt}; + return Operator{DropoutAttrs{std::get(assignments.at( + OperatorAttributeKey::RATE)), + std::get(assignments.at( + OperatorAttributeKey::SEED))}, + std::nullopt}; case Op::EW_ADD: case Op::EW_DIV: case Op::EW_EQUAL: @@ -168,13 +170,13 @@ Operator get_operator_attrs(SubParallelComputationGraph const &graph, case Op::EW_MUL: case Op::EW_SUB: return Operator{ - ElementBinaryAttrs{ - op_type, - std::get(assignments.at(OperatorAttributeKey::DATA_TYPE)), - std::get( - 
assignments.at(OperatorAttributeKey::SHOULD_BROADCAST_LHS)), - std::get( - assignments.at(OperatorAttributeKey::SHOULD_BROADCAST_RHS))}, + ElementBinaryAttrs{op_type, + std::get(assignments.at( + OperatorAttributeKey::DATA_TYPE)), + std::get(assignments.at( + OperatorAttributeKey::SHOULD_BROADCAST_LHS)), + std::get(assignments.at( + OperatorAttributeKey::SHOULD_BROADCAST_RHS))}, std::nullopt}; case Op::SCALAR_ADD: case Op::SCALAR_FLOOR_DIV: @@ -197,23 +199,24 @@ Operator get_operator_attrs(SubParallelComputationGraph const &graph, case Op::EMBEDDING: return Operator{ EmbeddingAttrs{ - std::get(assignments.at(OperatorAttributeKey::NUM_ENTRIES)), + std::get(assignments.at(OperatorAttributeKey::NUM_ENTRIES)), std::get(assignments.at(OperatorAttributeKey::OUT_CHANNELS)), std::get(assignments.at(OperatorAttributeKey::AGGR)), - std::get(assignments.at(OperatorAttributeKey::OP_TYPE))}, + std::get( + assignments.at(OperatorAttributeKey::OP_TYPE))}, std::nullopt}; case Op::FLAT: return Operator{FlatAttrs{}, std::nullopt}; case Op::GATHER: - return Operator{ - GatherAttrs{std::get(assignments.at(OperatorAttributeKey::DIM))}, - std::nullopt}; + return Operator{GatherAttrs{std::get( + assignments.at(OperatorAttributeKey::DIM))}, + std::nullopt}; case Op::INPUT: return Operator{InputAttrs{}, std::nullopt}; case Op::LAYERNORM: return Operator{ LayerNormAttrs{ - std::get>( + std::get>( assignments.at(OperatorAttributeKey::AXES)), std::get( assignments.at(OperatorAttributeKey::ELEMENTWISE_AFFINE)), @@ -222,31 +225,34 @@ Operator get_operator_attrs(SubParallelComputationGraph const &graph, case Op::LINEAR: return Operator{ LinearAttrs{ - std::get(assignments.at(OperatorAttributeKey::OUT_CHANNELS)), + std::get(assignments.at(OperatorAttributeKey::OUT_CHANNELS)), std::get(assignments.at(OperatorAttributeKey::USE_BIAS)), - std::get(assignments.at(OperatorAttributeKey::DATA_TYPE)), - std::get(assignments.at(OperatorAttributeKey::ACTIVATION)), + std::get( + 
assignments.at(OperatorAttributeKey::DATA_TYPE)), + std::get( + assignments.at(OperatorAttributeKey::ACTIVATION)), std::get>( assignments.at(OperatorAttributeKey::REGULARIZER))}, std::nullopt}; case Op::MULTIHEAD_ATTENTION: return Operator{ MultiHeadAttentionAttrs{ - std::get(assignments.at(OperatorAttributeKey::EMBED_DIM)), + std::get(assignments.at(OperatorAttributeKey::EMBED_DIM)), std::get(assignments.at(OperatorAttributeKey::NUM_HEADS)), std::get(assignments.at(OperatorAttributeKey::NUM_HEADS)), std::get(assignments.at(OperatorAttributeKey::VDIM)), std::get(assignments.at(OperatorAttributeKey::DROPOUT)), std::get(assignments.at(OperatorAttributeKey::BIAS)), std::get(assignments.at(OperatorAttributeKey::ADD_BIAS_KV)), - std::get(assignments.at(OperatorAttributeKey::ADD_ZERO_ATTN))}, + std::get( + assignments.at(OperatorAttributeKey::ADD_ZERO_ATTN))}, std::nullopt}; case Op::NOOP: return Operator{NoopAttrs{}, std::nullopt}; case Op::POOL2D: return Operator{ Pool2DAttrs{ - std::get(assignments.at(OperatorAttributeKey::KERNEL_H)), + std::get(assignments.at(OperatorAttributeKey::KERNEL_H)), std::get(assignments.at(OperatorAttributeKey::KERNEL_W)), std::get(assignments.at(OperatorAttributeKey::STRIDE_H)), std::get(assignments.at(OperatorAttributeKey::STRIDE_W)), @@ -265,7 +271,7 @@ Operator get_operator_attrs(SubParallelComputationGraph const &graph, case Op::REDUCE_SUM: return Operator{ ReduceAttrs{ - std::get>( + std::get>( assignments.at(OperatorAttributeKey::AXES)), op_type, std::get(assignments.at(OperatorAttributeKey::KEEP_DIMS))}, @@ -280,9 +286,10 @@ Operator get_operator_attrs(SubParallelComputationGraph const &graph, std::nullopt}; case Op::SPLIT: return Operator{ - SplitAttrs{std::get>( - assignments.at(OperatorAttributeKey::SPLITS)), - std::get(assignments.at(OperatorAttributeKey::AXIS))}, + SplitAttrs{ + std::get>( + assignments.at(OperatorAttributeKey::SPLITS)), + std::get(assignments.at(OperatorAttributeKey::AXIS))}, std::nullopt}; case Op::SOFTMAX: 
return Operator{SoftmaxAttrs{std::get( @@ -290,8 +297,9 @@ Operator get_operator_attrs(SubParallelComputationGraph const &graph, std::nullopt}; case Op::TOPK: return Operator{ - TopKAttrs{std::get(assignments.at(OperatorAttributeKey::K)), - std::get(assignments.at(OperatorAttributeKey::SORTED))}, + TopKAttrs{ + std::get(assignments.at(OperatorAttributeKey::K)), + std::get(assignments.at(OperatorAttributeKey::SORTED))}, std::nullopt}; case Op::TRANSPOSE: return Operator{ @@ -299,28 +307,31 @@ Operator get_operator_attrs(SubParallelComputationGraph const &graph, assignments.at(OperatorAttributeKey::PERMUTATION))}, std::nullopt}; case Op::COMBINE: - return Operator{ - CombineAttrs{ - std::get(assignments.at(OperatorAttributeKey::PARALLEL_DIM)), - std::get(assignments.at(OperatorAttributeKey::PARALLEL_DEGREE))}, - std::nullopt}; + return Operator{CombineAttrs{std::get(assignments.at( + OperatorAttributeKey::PARALLEL_DIM)), + std::get(assignments.at( + OperatorAttributeKey::PARALLEL_DEGREE))}, + std::nullopt}; case Op::REDUCTION: return Operator{ - ReductionAttrs{ - std::get(assignments.at(OperatorAttributeKey::PARALLEL_DIM)), - std::get(assignments.at(OperatorAttributeKey::PARALLEL_DEGREE))}, + ReductionAttrs{std::get(assignments.at( + OperatorAttributeKey::PARALLEL_DIM)), + std::get(assignments.at( + OperatorAttributeKey::PARALLEL_DEGREE))}, std::nullopt}; case Op::REPARTITION: return Operator{ - RepartitionAttrs{ - std::get(assignments.at(OperatorAttributeKey::PARALLEL_DIM)), - std::get(assignments.at(OperatorAttributeKey::PARALLEL_DEGREE))}, + RepartitionAttrs{std::get(assignments.at( + OperatorAttributeKey::PARALLEL_DIM)), + std::get(assignments.at( + OperatorAttributeKey::PARALLEL_DEGREE))}, std::nullopt}; case Op::REPLICATE: return Operator{ - ReplicateAttrs{ - std::get(assignments.at(OperatorAttributeKey::PARALLEL_DIM)), - std::get(assignments.at(OperatorAttributeKey::PARALLEL_DEGREE))}, + ReplicateAttrs{std::get(assignments.at( + 
OperatorAttributeKey::PARALLEL_DIM)), + std::get(assignments.at( + OperatorAttributeKey::PARALLEL_DEGREE))}, std::nullopt}; default: throw mk_runtime_error("Unknown Operator"); diff --git a/lib/utils/include/utils/containers.decl.h b/lib/utils/include/utils/containers.decl.h index 40ac0a4a1c..0332a331b2 100644 --- a/lib/utils/include/utils/containers.decl.h +++ b/lib/utils/include/utils/containers.decl.h @@ -4,9 +4,9 @@ #include "utils/bidict.h" #include "utils/required_core.h" #include "utils/type_traits_core.h" +#include #include #include -#include namespace FlexFlow { diff --git a/lib/utils/include/utils/dot_file.h b/lib/utils/include/utils/dot_file.h index 6cf06d12a7..1fd9813646 100644 --- a/lib/utils/include/utils/dot_file.h +++ b/lib/utils/include/utils/dot_file.h @@ -5,12 +5,12 @@ #include #include #include +#include #include #include #include #include #include -#include template class DotFile { diff --git a/lib/utils/include/utils/graph/labelled/node_labelled.h b/lib/utils/include/utils/graph/labelled/node_labelled.h index 9aed91f107..856dd4434e 100644 --- a/lib/utils/include/utils/graph/labelled/node_labelled.h +++ b/lib/utils/include/utils/graph/labelled/node_labelled.h @@ -41,8 +41,7 @@ struct NodeLabelledMultiDiGraphView : virtual public MultiDiGraphView { private: Interface const &get_ptr() const { - return *std::dynamic_pointer_cast( - GraphView::ptr.get()); + return *std::dynamic_pointer_cast(GraphView::ptr.get()); } }; CHECK_WELL_BEHAVED_VALUE_TYPE_NO_EQ(NodeLabelledMultiDiGraphView); @@ -97,13 +96,11 @@ struct NodeLabelledMultiDiGraph NodeLabelledMultiDiGraph(cow_ptr_t ptr) : GraphView(ptr) {} Interface &get_ptr() { - return *std::dynamic_pointer_cast( - GraphView::ptr.get_mutable()); + return *std::dynamic_pointer_cast(GraphView::ptr.get_mutable()); } Interface const &get_ptr() const { - return *std::dynamic_pointer_cast( - GraphView::ptr.get()); + return *std::dynamic_pointer_cast(GraphView::ptr.get()); } }; 
CHECK_WELL_BEHAVED_VALUE_TYPE_NO_EQ(NodeLabelledMultiDiGraph); diff --git a/lib/utils/include/utils/graph/labelled/node_labelled_open.h b/lib/utils/include/utils/graph/labelled/node_labelled_open.h index 0fea57cab7..c864c7dacf 100644 --- a/lib/utils/include/utils/graph/labelled/node_labelled_open.h +++ b/lib/utils/include/utils/graph/labelled/node_labelled_open.h @@ -55,8 +55,7 @@ struct NodeLabelledOpenMultiDiGraphView private: Interface const &get_ptr() const { - return *std::dynamic_pointer_cast( - GraphView::ptr.get()); + return *std::dynamic_pointer_cast(GraphView::ptr.get()); } }; @@ -121,13 +120,11 @@ struct NodeLabelledOpenMultiDiGraph NodeLabelledOpenMultiDiGraph(cow_ptr_t ptr) : GraphView(ptr) {} Interface &get_ptr() { - return *std::dynamic_pointer_cast( - GraphView::ptr.get_mutable()); + return *std::dynamic_pointer_cast(GraphView::ptr.get_mutable()); } Interface const &get_ptr() const { - return *std::dynamic_pointer_cast( - GraphView::ptr.get()); + return *std::dynamic_pointer_cast(GraphView::ptr.get()); } }; diff --git a/lib/utils/include/utils/graph/labelled/output_labelled.h b/lib/utils/include/utils/graph/labelled/output_labelled.h index 8aab0320b5..ac5648c2e1 100644 --- a/lib/utils/include/utils/graph/labelled/output_labelled.h +++ b/lib/utils/include/utils/graph/labelled/output_labelled.h @@ -47,8 +47,7 @@ struct OutputLabelledMultiDiGraphView private: Interface const &get_ptr() const { - return *std::dynamic_pointer_cast( - GraphView::ptr.get()); + return *std::dynamic_pointer_cast(GraphView::ptr.get()); } }; @@ -119,13 +118,11 @@ struct OutputLabelledMultiDiGraph private: Interface &get_ptr() { - return *std::dynamic_pointer_cast( - GraphView::ptr.get_mutable()); + return *std::dynamic_pointer_cast(GraphView::ptr.get_mutable()); } Interface const &get_ptr() const { - return *std::dynamic_pointer_cast( - GraphView::ptr.get()); + return *std::dynamic_pointer_cast(GraphView::ptr.get()); } }; diff --git 
a/lib/utils/include/utils/graph/labelled/output_labelled_open.h b/lib/utils/include/utils/graph/labelled/output_labelled_open.h index aaf051c83d..bc4fe3d828 100644 --- a/lib/utils/include/utils/graph/labelled/output_labelled_open.h +++ b/lib/utils/include/utils/graph/labelled/output_labelled_open.h @@ -64,8 +64,7 @@ struct OutputLabelledOpenMultiDiGraphView private: Interface const &get_ptr() const { - return *std::dynamic_pointer_cast( - GraphView::ptr.get()); + return *std::dynamic_pointer_cast(GraphView::ptr.get()); } }; @@ -145,13 +144,11 @@ struct OutputLabelledOpenMultiDiGraph OutputLabelledOpenMultiDiGraph(cow_ptr_t ptr) : GraphView(ptr) {} Interface &get_ptr() { - return *std::dynamic_pointer_cast( - GraphView::ptr.get_mutable()); + return *std::dynamic_pointer_cast(GraphView::ptr.get_mutable()); } Interface const &get_ptr() const { - return *std::dynamic_pointer_cast( - GraphView::ptr.get()); + return *std::dynamic_pointer_cast(GraphView::ptr.get()); } }; diff --git a/lib/utils/include/utils/graph/labelled/standard_labelled.h b/lib/utils/include/utils/graph/labelled/standard_labelled.h index c6d1521471..34dabb5391 100644 --- a/lib/utils/include/utils/graph/labelled/standard_labelled.h +++ b/lib/utils/include/utils/graph/labelled/standard_labelled.h @@ -47,8 +47,7 @@ struct LabelledMultiDiGraphView : NodeLabelledMultiDiGraphView(ptr) {} Interface const &get_ptr() const { - return *std::dynamic_pointer_cast( - GraphView::ptr.get()); + return *std::dynamic_pointer_cast(GraphView::ptr.get()); } }; CHECK_WELL_BEHAVED_VALUE_TYPE_NO_EQ(LabelledMultiDiGraphView); @@ -104,13 +103,11 @@ struct LabelledMultiDiGraph LabelledMultiDiGraph(cow_ptr_t ptr) : GraphView(ptr) {} Interface &get_ptr() { - return *std::dynamic_pointer_cast( - GraphView::ptr.get_mutable()); + return *std::dynamic_pointer_cast(GraphView::ptr.get_mutable()); } Interface const &get_ptr() const { - return *std::dynamic_pointer_cast( - GraphView::ptr.get()); + return 
*std::dynamic_pointer_cast(GraphView::ptr.get()); } }; CHECK_WELL_BEHAVED_VALUE_TYPE_NO_EQ(LabelledMultiDiGraph); diff --git a/lib/utils/include/utils/variant.h b/lib/utils/include/utils/variant.h index feb263335a..272caaffde 100644 --- a/lib/utils/include/utils/variant.h +++ b/lib/utils/include/utils/variant.h @@ -194,12 +194,13 @@ auto narrow(Container const &c) { return transform(c, [](VariantIn const &e) { return get(e); }); } -template , VariantIn>::value>> +template < + typename T1, + typename T2, + typename... Trest, + typename VariantIn, + typename = std::enable_if_t< + !is_subeq_variant, VariantIn>::value>> std::optional> narrow(VariantIn const &v) { return visit(VariantNarrowFunctor>{}, v); } diff --git a/lib/utils/src/graph/digraph.cc b/lib/utils/src/graph/digraph.cc index ecad1db3f0..bdfe5ff599 100644 --- a/lib/utils/src/graph/digraph.cc +++ b/lib/utils/src/graph/digraph.cc @@ -14,8 +14,7 @@ std::unordered_set } IDiGraphView const &DiGraphView::get_ptr() const { - return *std::dynamic_pointer_cast( - GraphView::ptr.get()); + return *std::dynamic_pointer_cast(GraphView::ptr.get()); } Node DiGraph::add_node() { diff --git a/lib/utils/src/graph/multidigraph.cc b/lib/utils/src/graph/multidigraph.cc index 41ae3e1aa3..771e01e573 100644 --- a/lib/utils/src/graph/multidigraph.cc +++ b/lib/utils/src/graph/multidigraph.cc @@ -66,8 +66,7 @@ std::unordered_set MultiDiGraph::query_nodes(NodeQuery const &q) const { } IMultiDiGraph const &MultiDiGraph::get_ptr() const { - return *std::dynamic_pointer_cast( - GraphView::ptr.get()); + return *std::dynamic_pointer_cast(GraphView::ptr.get()); } IMultiDiGraph &MultiDiGraph::get_ptr() { diff --git a/lib/utils/test/src/test_variant.cc b/lib/utils/test/src/test_variant.cc index 1494f0ac27..541ff40920 100644 --- a/lib/utils/test/src/test_variant.cc +++ b/lib/utils/test/src/test_variant.cc @@ -4,7 +4,8 @@ TEST_CASE("widen and narrow functions") { SUBCASE("widen function") { std::variant v1 = 42; - std::variant result = 
widen>(v1); + std::variant result = + widen>(v1); std::variant expected = 42; CHECK(result == expected); } @@ -12,7 +13,8 @@ TEST_CASE("widen and narrow functions") { SUBCASE("narrow function fail") { std::variant v2 = 3.14; // this is a doule, because 3.14 default to double - std::optional> result = narrow>(v2); + std::optional> result = + narrow>(v2); std::optional> expected = float(3.14); CHECK(!result.has_value()); // result should be empty due to narrowing } @@ -20,14 +22,16 @@ TEST_CASE("widen and narrow functions") { SUBCASE("narrow function success") { std::variant v2 = 3.14; // this is a doule, because 3.14 default to double - std::optional> result = narrow>(v2); + std::optional> result = + narrow>(v2); std::optional> expected = 3.14; CHECK(result == expected); // } SUBCASE("cast function") { std::variant v3 = 42; - std::optional> result = cast>(v3); + std::optional> result = + cast>(v3); std::optional> expected = 42; CHECK(result == expected); } @@ -53,7 +57,8 @@ TEST_CASE("casting and widening a variant") { std::variant wider_variant; // Perform the cast operation - std::optional> cast_result = cast>(smaller_variant); + std::optional> cast_result = + cast>(smaller_variant); REQUIRE(cast_result); // Ensure the cast was successful // Perform the widening operation From d2eb505120fb3a09abfb6811dee106e7f24ba7f9 Mon Sep 17 00:00:00 2001 From: wmdi Date: Sun, 24 Mar 2024 17:17:22 -0400 Subject: [PATCH 26/37] upd tests name to match ci --- lib/compiler/test/CMakeLists.txt | 2 +- lib/substitutions/test/CMakeLists.txt | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/lib/compiler/test/CMakeLists.txt b/lib/compiler/test/CMakeLists.txt index cbd7e233c0..13b1fd3b83 100644 --- a/lib/compiler/test/CMakeLists.txt +++ b/lib/compiler/test/CMakeLists.txt @@ -1,6 +1,6 @@ ff_add_test_executable( NAME - compiler-test + compiler-tests SRC_PATTERNS src/*.cc PRIVATE_INCLUDE diff --git a/lib/substitutions/test/CMakeLists.txt 
b/lib/substitutions/test/CMakeLists.txt index d7e35ef9af..cfd6383e95 100644 --- a/lib/substitutions/test/CMakeLists.txt +++ b/lib/substitutions/test/CMakeLists.txt @@ -1,6 +1,6 @@ ff_add_test_executable( NAME - substitutions-test + substitutions-tests SRC_PATTERNS src/*.cc PRIVATE_INCLUDE From 371324a505a5f61aca276ecf621e2eb862f2cb5c Mon Sep 17 00:00:00 2001 From: Colin Unger Date: Tue, 26 Mar 2024 16:19:08 -0700 Subject: [PATCH 27/37] Add TEST_SUITE declaration to make tests findable by ctest --- .../test/src/test_labelled_open_graph.cc | 240 +++---- lib/compiler/test/src/test_machine_mapping.cc | 36 +- lib/compiler/test/src/test_open_graph.cc | 136 ++-- lib/compiler/test/src/test_optimal_cost.cc | 102 +-- lib/compiler/test/src/test_unity_algorithm.cc | 45 +- .../test/src/test_pattern_matches.cc | 68 +- .../test/src/test_substitution.cc | 240 +++---- lib/utils/test/src/test_algorithms.cc | 408 +++++------ lib/utils/test/src/test_bidict.cc | 100 +-- lib/utils/test/src/test_containers.cc | 651 +++++++++--------- lib/utils/test/src/test_cow_ptr.cc | 66 +- .../src/test_deduplicated_priority_queue.cc | 48 +- lib/utils/test/src/test_disjoint_set.cc | 76 +- lib/utils/test/src/test_dot_file.cc | 76 +- lib/utils/test/src/test_format.cc | 46 +- lib/utils/test/src/test_hash.cc | 20 +- lib/utils/test/src/test_multidigraph.cc | 138 ++-- lib/utils/test/src/test_random_utils.cc | 72 +- lib/utils/test/src/test_sequence.cc | 308 +++++---- lib/utils/test/src/test_stack_map.cc | 88 +-- lib/utils/test/src/test_stack_string.cc | 124 ++-- lib/utils/test/src/test_stack_vector.cc | 142 ++-- lib/utils/test/src/test_tuple.cc | 118 ++-- lib/utils/test/src/test_type_index.cc | 42 +- lib/utils/test/src/test_undirected_graph.cc | 54 +- lib/utils/test/src/test_variant.cc | 110 +-- lib/utils/test/src/test_vector.cc | 46 +- 27 files changed, 1828 insertions(+), 1772 deletions(-) diff --git a/lib/compiler/test/src/test_labelled_open_graph.cc b/lib/compiler/test/src/test_labelled_open_graph.cc 
index c59d7ee78a..e3498a769a 100644 --- a/lib/compiler/test/src/test_labelled_open_graph.cc +++ b/lib/compiler/test/src/test_labelled_open_graph.cc @@ -4,123 +4,125 @@ using namespace FlexFlow; -TEST_CASE("get_subgraph(OpenMultiDiGraphView)") { - auto g = OpenMultiDiGraph::create(); - - Node n0 = g.add_node(); - Node n1 = g.add_node(); - Node n2 = g.add_node(); - Node n3 = g.add_node(); - Node n4 = g.add_node(); - - NodePort p0 = g.add_node_port(); - NodePort p1 = g.add_node_port(); - NodePort p2 = g.add_node_port(); - NodePort p3 = g.add_node_port(); - NodePort p4 = g.add_node_port(); - NodePort p5 = g.add_node_port(); - NodePort p6 = g.add_node_port(); - NodePort p7 = g.add_node_port(); - NodePort p8 = g.add_node_port(); - NodePort p9 = g.add_node_port(); - - MultiDiEdge e0{n1, p1, n0, p0}; - MultiDiEdge e1{n2, p2, n0, p0}; - MultiDiEdge e2{n3, p5, n1, p3}; - MultiDiEdge e3{n3, p6, n2, p4}; - MultiDiEdge e4{n4, p8, n3, p7}; - OutputMultiDiEdge e5{n4, p9, std::make_pair(p9.value(), p9.value())}; - - g.add_edge(e0); - g.add_edge(e1); - g.add_edge(e2); - g.add_edge(e3); - g.add_edge(e4); - g.add_edge(e5); - - std::unordered_set node_set0{n3, n4}; - - auto subgraph0 = get_subgraph(g, node_set0); - auto subgraph1 = get_subgraph(g, node_set0); - auto subgraph2 = get_subgraph(g, node_set0); - auto subgraph3 = get_subgraph(g, node_set0); - - CHECK(bool(get_nodes(subgraph0) == node_set0)); - CHECK(bool(get_nodes(subgraph1) == node_set0)); - CHECK(bool(get_nodes(subgraph2) == node_set0)); - CHECK(bool(get_nodes(subgraph3) == node_set0)); - - std::unordered_set input_set{split_edge(e2).second, - split_edge(e3).second}; - std::unordered_set output_set{e5}; - - CHECK(bool(get_open_inputs(subgraph0) == input_set)); - CHECK(bool(get_open_inputs(subgraph1) == input_set)); - CHECK(bool(get_open_inputs(subgraph2).empty())); - CHECK(bool(get_open_inputs(subgraph3).empty())); - - CHECK(bool(get_open_outputs(subgraph0) == output_set)); - 
CHECK(bool(get_open_outputs(subgraph1).empty())); - CHECK(bool(get_open_outputs(subgraph2) == output_set)); - CHECK(bool(get_open_outputs(subgraph3).empty())); - - CHECK(bool(get_edges(subgraph0) == - std::unordered_set{ - split_edge(e2).second, split_edge(e3).second, e4, e5})); - CHECK(bool(get_edges(subgraph1) == - std::unordered_set{ - split_edge(e2).second, split_edge(e3).second, e4})); - CHECK(bool(get_edges(subgraph2) == - std::unordered_set{e4, e5})); - CHECK(bool(get_edges(subgraph3) == std::unordered_set{e4})); - - CHECK(bool(get_closed_sources(subgraph2) == std::unordered_set{n3})); -} - -TEST_CASE("view OutputLabelledMultiDiGraph as open") { - OutputLabelledMultiDiGraph g = - OutputLabelledMultiDiGraph::create< - UnorderedOutputLabelledMultiDiGraph>(); - - Node n0 = g.add_node(0); - Node n1 = g.add_node(1); - - NodePort p0 = g.add_node_port(); - NodePort p1 = g.add_node_port(); - - MultiDiEdge e0{n1, p1, n0, p0}; - - g.add_edge(e0); - g.add_output(e0, 2); - - CHECK(bool(get_edges(g).size() == 1)); - - OutputLabelledOpenMultiDiGraphView open_graph = - view_output_labelled_as_output_labelled_open(g); - - CHECK(bool(open_graph.at(n0) == 0)); - CHECK(bool(open_graph.at(n1) == 1)); - CHECK(bool(open_graph.at(e0) == 2)); - - CHECK(get_edges(open_graph).size() == 1); -} - -TEST_CASE("OutputLabelledOpenMultiDiGraph") { - OutputLabelledOpenMultiDiGraph g = - OutputLabelledOpenMultiDiGraph::create< - UnorderedOutputLabelledOpenMultiDiGraph>(); - - Node n0 = g.add_node(0); - Node n1 = g.add_node(1); - - NodePort p0 = g.add_node_port(); - NodePort p1 = g.add_node_port(); - - MultiDiEdge e0{n1, p1, n0, p0}; - - g.add_edge(e0); - g.add_label(e0, 2); - - CHECK(bool(g.query_edges(OpenMultiDiEdgeQuery::all()).size() == 1)); - CHECK(bool(get_edges(g).size() == 1)); +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("get_subgraph(OpenMultiDiGraphView)") { + auto g = OpenMultiDiGraph::create(); + + Node n0 = g.add_node(); + Node n1 = g.add_node(); + Node n2 = g.add_node(); + Node n3 
= g.add_node(); + Node n4 = g.add_node(); + + NodePort p0 = g.add_node_port(); + NodePort p1 = g.add_node_port(); + NodePort p2 = g.add_node_port(); + NodePort p3 = g.add_node_port(); + NodePort p4 = g.add_node_port(); + NodePort p5 = g.add_node_port(); + NodePort p6 = g.add_node_port(); + NodePort p7 = g.add_node_port(); + NodePort p8 = g.add_node_port(); + NodePort p9 = g.add_node_port(); + + MultiDiEdge e0{n1, p1, n0, p0}; + MultiDiEdge e1{n2, p2, n0, p0}; + MultiDiEdge e2{n3, p5, n1, p3}; + MultiDiEdge e3{n3, p6, n2, p4}; + MultiDiEdge e4{n4, p8, n3, p7}; + OutputMultiDiEdge e5{n4, p9, std::make_pair(p9.value(), p9.value())}; + + g.add_edge(e0); + g.add_edge(e1); + g.add_edge(e2); + g.add_edge(e3); + g.add_edge(e4); + g.add_edge(e5); + + std::unordered_set node_set0{n3, n4}; + + auto subgraph0 = get_subgraph(g, node_set0); + auto subgraph1 = get_subgraph(g, node_set0); + auto subgraph2 = get_subgraph(g, node_set0); + auto subgraph3 = get_subgraph(g, node_set0); + + CHECK(bool(get_nodes(subgraph0) == node_set0)); + CHECK(bool(get_nodes(subgraph1) == node_set0)); + CHECK(bool(get_nodes(subgraph2) == node_set0)); + CHECK(bool(get_nodes(subgraph3) == node_set0)); + + std::unordered_set input_set{split_edge(e2).second, + split_edge(e3).second}; + std::unordered_set output_set{e5}; + + CHECK(bool(get_open_inputs(subgraph0) == input_set)); + CHECK(bool(get_open_inputs(subgraph1) == input_set)); + CHECK(bool(get_open_inputs(subgraph2).empty())); + CHECK(bool(get_open_inputs(subgraph3).empty())); + + CHECK(bool(get_open_outputs(subgraph0) == output_set)); + CHECK(bool(get_open_outputs(subgraph1).empty())); + CHECK(bool(get_open_outputs(subgraph2) == output_set)); + CHECK(bool(get_open_outputs(subgraph3).empty())); + + CHECK(bool(get_edges(subgraph0) == + std::unordered_set{ + split_edge(e2).second, split_edge(e3).second, e4, e5})); + CHECK(bool(get_edges(subgraph1) == + std::unordered_set{ + split_edge(e2).second, split_edge(e3).second, e4})); + 
CHECK(bool(get_edges(subgraph2) == + std::unordered_set{e4, e5})); + CHECK(bool(get_edges(subgraph3) == std::unordered_set{e4})); + + CHECK(bool(get_closed_sources(subgraph2) == std::unordered_set{n3})); + } + + TEST_CASE("view OutputLabelledMultiDiGraph as open") { + OutputLabelledMultiDiGraph g = + OutputLabelledMultiDiGraph::create< + UnorderedOutputLabelledMultiDiGraph>(); + + Node n0 = g.add_node(0); + Node n1 = g.add_node(1); + + NodePort p0 = g.add_node_port(); + NodePort p1 = g.add_node_port(); + + MultiDiEdge e0{n1, p1, n0, p0}; + + g.add_edge(e0); + g.add_output(e0, 2); + + CHECK(bool(get_edges(g).size() == 1)); + + OutputLabelledOpenMultiDiGraphView open_graph = + view_output_labelled_as_output_labelled_open(g); + + CHECK(bool(open_graph.at(n0) == 0)); + CHECK(bool(open_graph.at(n1) == 1)); + CHECK(bool(open_graph.at(e0) == 2)); + + CHECK(get_edges(open_graph).size() == 1); + } + + TEST_CASE("OutputLabelledOpenMultiDiGraph") { + OutputLabelledOpenMultiDiGraph g = + OutputLabelledOpenMultiDiGraph::create< + UnorderedOutputLabelledOpenMultiDiGraph>(); + + Node n0 = g.add_node(0); + Node n1 = g.add_node(1); + + NodePort p0 = g.add_node_port(); + NodePort p1 = g.add_node_port(); + + MultiDiEdge e0{n1, p1, n0, p0}; + + g.add_edge(e0); + g.add_label(e0, 2); + + CHECK(bool(g.query_edges(OpenMultiDiEdgeQuery::all()).size() == 1)); + CHECK(bool(get_edges(g).size() == 1)); + } } diff --git a/lib/compiler/test/src/test_machine_mapping.cc b/lib/compiler/test/src/test_machine_mapping.cc index b2abc6929d..365ed3e1db 100644 --- a/lib/compiler/test/src/test_machine_mapping.cc +++ b/lib/compiler/test/src/test_machine_mapping.cc @@ -1,21 +1,23 @@ -// #include "doctest/doctest.h" -// #include "test_generator.h" +#include "doctest/doctest.h" +#include "test_generator.h" -// TEST_CASE("MachineMapping::combine") { -// rc::check([](MachineMapping const &m0, MachineMapping const &m1) { -// RC_PRE(MachineMapping::nodes_are_disjoint(m0, m1)); +TEST_SUITE(FF_TEST_SUITE) { + // 
TEST_CASE("MachineMapping::combine") { + // rc::check([](MachineMapping const &m0, MachineMapping const &m1) { + // RC_PRE(MachineMapping::nodes_are_disjoint(m0, m1)); -// MachineMapping comb = MachineMapping::combine(m0, m1); + // MachineMapping comb = MachineMapping::combine(m0, m1); -// RC_ASSERT(comb.machine_views.size() == -// m0.machine_views.size() + m1.machine_views.size()); -// RC_ASSERT(is_submap(comb.machine_views, m0.machine_views)); -// RC_ASSERT(is_submap(comb.machine_views, m1.machine_views)); -// }); -// } + // RC_ASSERT(comb.machine_views.size() == + // m0.machine_views.size() + m1.machine_views.size()); + // RC_ASSERT(is_submap(comb.machine_views, m0.machine_views)); + // RC_ASSERT(is_submap(comb.machine_views, m1.machine_views)); + // }); + // } -// TEST_CASE("OptimalCostResult::infinity") { -// rc::check([](OptimalCostResult const &c) { -// RC_ASSERT(c.runtime <= OptimalCostResult::infinity().runtime); -// }); -// } + // TEST_CASE("OptimalCostResult::infinity") { + // rc::check([](OptimalCostResult const &c) { + // RC_ASSERT(c.runtime <= OptimalCostResult::infinity().runtime); + // }); + // } +} diff --git a/lib/compiler/test/src/test_open_graph.cc b/lib/compiler/test/src/test_open_graph.cc index 7436f213d7..db3630d316 100644 --- a/lib/compiler/test/src/test_open_graph.cc +++ b/lib/compiler/test/src/test_open_graph.cc @@ -4,71 +4,73 @@ using namespace FlexFlow; -TEST_CASE("get_source_sink_open_graph") { - OpenMultiDiGraph g = OpenMultiDiGraph::create(); - - Node n0 = g.add_node(); - NodePort p0 = g.add_node_port(); - InputMultiDiEdge e0{ - n0, g.add_node_port(), std::make_pair(n0.value(), n0.value())}; - g.add_edge(e0); - - CHECK(bool(get_closed_sources(g) == std::unordered_set{})); - CHECK(bool(get_closed_sinks(g) == std::unordered_set{n0})); - - CHECK(bool(get_open_sources(g) == std::unordered_set{n0})); - CHECK(bool(get_open_sinks(g) == std::unordered_set{})); -} - -TEST_CASE("get_source_sink_open_graph:unconnected") { - OpenMultiDiGraph g = 
OpenMultiDiGraph::create(); - - Node n0 = g.add_node(); - Node n1 = g.add_node(); - - NodePort p0 = g.add_node_port(); - NodePort p1 = g.add_node_port(); - - InputMultiDiEdge e0{n0, p0, std::make_pair(p0.value(), p0.value())}; - OutputMultiDiEdge e1{n1, p1, std::make_pair(p1.value(), p1.value())}; - g.add_edge(e0); - g.add_edge(e1); - - /* - g: ->n0 - n1-> - */ - - CHECK(bool(get_closed_sources(g) == std::unordered_set{n1})); - CHECK(bool(get_closed_sinks(g) == std::unordered_set{n0})); - - CHECK(bool(get_open_sources(g) == std::unordered_set{n0})); - CHECK(bool(get_open_sinks(g) == std::unordered_set{n1})); -} - -TEST_CASE("get_cut") { - auto g = OpenMultiDiGraph::create(); - - std::vector ns = add_nodes(g, 5); - - MultiDiEdge e0{ns[1], g.add_node_port(), ns[0], g.add_node_port()}; - MultiDiEdge e1{ns[2], g.add_node_port(), ns[1], g.add_node_port()}; - MultiDiEdge e2{ns[3], g.add_node_port(), ns[1], g.add_node_port()}; - MultiDiEdge e3{ns[4], g.add_node_port(), ns[2], g.add_node_port()}; - MultiDiEdge e4{ns[4], g.add_node_port(), ns[3], g.add_node_port()}; - OutputMultiDiEdge e5{ - ns[4], g.add_node_port(), std::make_pair(ns[4].value(), ns[4].value())}; - - g.add_edge(e0); - g.add_edge(e1); - g.add_edge(e2); - g.add_edge(e3); - g.add_edge(e4); - g.add_edge(e5); - - GraphSplit gs0{{ns[0], ns[1]}, {ns[2], ns[3], ns[4]}}; - CHECK(bool(get_cut_set(g, gs0) == std::unordered_set{e1, e2})); - - GraphSplit gs1{{ns[0], ns[1], ns[2], ns[3]}, {ns[4]}}; - CHECK(bool(get_cut_set(g, gs1) == std::unordered_set{e3, e4})); +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("get_source_sink_open_graph") { + OpenMultiDiGraph g = OpenMultiDiGraph::create(); + + Node n0 = g.add_node(); + NodePort p0 = g.add_node_port(); + InputMultiDiEdge e0{ + n0, g.add_node_port(), std::make_pair(n0.value(), n0.value())}; + g.add_edge(e0); + + CHECK(bool(get_closed_sources(g) == std::unordered_set{})); + CHECK(bool(get_closed_sinks(g) == std::unordered_set{n0})); + + CHECK(bool(get_open_sources(g) == 
std::unordered_set{n0})); + CHECK(bool(get_open_sinks(g) == std::unordered_set{})); + } + + TEST_CASE("get_source_sink_open_graph:unconnected") { + OpenMultiDiGraph g = OpenMultiDiGraph::create(); + + Node n0 = g.add_node(); + Node n1 = g.add_node(); + + NodePort p0 = g.add_node_port(); + NodePort p1 = g.add_node_port(); + + InputMultiDiEdge e0{n0, p0, std::make_pair(p0.value(), p0.value())}; + OutputMultiDiEdge e1{n1, p1, std::make_pair(p1.value(), p1.value())}; + g.add_edge(e0); + g.add_edge(e1); + + /* + g: ->n0 + n1-> + */ + + CHECK(bool(get_closed_sources(g) == std::unordered_set{n1})); + CHECK(bool(get_closed_sinks(g) == std::unordered_set{n0})); + + CHECK(bool(get_open_sources(g) == std::unordered_set{n0})); + CHECK(bool(get_open_sinks(g) == std::unordered_set{n1})); + } + + TEST_CASE("get_cut") { + auto g = OpenMultiDiGraph::create(); + + std::vector ns = add_nodes(g, 5); + + MultiDiEdge e0{ns[1], g.add_node_port(), ns[0], g.add_node_port()}; + MultiDiEdge e1{ns[2], g.add_node_port(), ns[1], g.add_node_port()}; + MultiDiEdge e2{ns[3], g.add_node_port(), ns[1], g.add_node_port()}; + MultiDiEdge e3{ns[4], g.add_node_port(), ns[2], g.add_node_port()}; + MultiDiEdge e4{ns[4], g.add_node_port(), ns[3], g.add_node_port()}; + OutputMultiDiEdge e5{ + ns[4], g.add_node_port(), std::make_pair(ns[4].value(), ns[4].value())}; + + g.add_edge(e0); + g.add_edge(e1); + g.add_edge(e2); + g.add_edge(e3); + g.add_edge(e4); + g.add_edge(e5); + + GraphSplit gs0{{ns[0], ns[1]}, {ns[2], ns[3], ns[4]}}; + CHECK(bool(get_cut_set(g, gs0) == std::unordered_set{e1, e2})); + + GraphSplit gs1{{ns[0], ns[1], ns[2], ns[3]}, {ns[4]}}; + CHECK(bool(get_cut_set(g, gs1) == std::unordered_set{e3, e4})); + } } diff --git a/lib/compiler/test/src/test_optimal_cost.cc b/lib/compiler/test/src/test_optimal_cost.cc index 5f5f7d093e..da303e3ccc 100644 --- a/lib/compiler/test/src/test_optimal_cost.cc +++ b/lib/compiler/test/src/test_optimal_cost.cc @@ -4,63 +4,65 @@ using namespace FlexFlow; -// 
Rapidcheck infrastructures for graphs does not work for now -/* -Tests whether optimal_cost can give a valid result given random PCG, trivial -allowed machine views, trivial cost estimator and random machine specification. -*/ -// TEST_CASE("optimal_cost") { -// auto test_allowed_machine_views = [](Operator const &, -// MachineSpecification const &) { -// return std::unordered_set{make_1d_machine_view(0, 1, 1)}; -// }; -// rc::check([](ParallelComputationGraph const &g, -// MachineSpecification const &machine_spec) { -// OptimalCostCache cached_subgraph_costs; -// OptimalCostResult result = optimal_cost(g, -// test_allowed_machine_views, -// TestCostEstimator{}, -// machine_spec, -// cached_subgraph_costs); -// RC_ASSERT(result.runtime > 0); -// RC_ASSERT(keys(result.machine_mapping.machine_views) == get_nodes(g)); -// }); -// } +TEST_SUITE(FF_TEST_SUITE) { + // Rapidcheck infrastructures for graphs does not work for now + /* + Tests whether optimal_cost can give a valid result given random PCG, trivial + allowed machine views, trivial cost estimator and random machine specification. 
+ */ + // TEST_CASE("optimal_cost") { + // auto test_allowed_machine_views = [](Operator const &, + // MachineSpecification const &) { + // return std::unordered_set{make_1d_machine_view(0, 1, 1)}; + // }; + // rc::check([](ParallelComputationGraph const &g, + // MachineSpecification const &machine_spec) { + // OptimalCostCache cached_subgraph_costs; + // OptimalCostResult result = optimal_cost(g, + // test_allowed_machine_views, + // TestCostEstimator{}, + // machine_spec, + // cached_subgraph_costs); + // RC_ASSERT(result.runtime > 0); + // RC_ASSERT(keys(result.machine_mapping.machine_views) == get_nodes(g)); + // }); + // } -TEST_CASE("optimal_cost_0") { - auto pcg = - OutputLabelledMultiDiGraph::template create< - UnorderedOutputLabelledMultiDiGraph>(); + TEST_CASE("optimal_cost_0") { + auto pcg = + OutputLabelledMultiDiGraph::template create< + UnorderedOutputLabelledMultiDiGraph>(); - Node n0 = pcg.add_node(Operator{InputAttrs{}, "input"}); - Node n1 = pcg.add_node(Operator{ - LinearAttrs{1, false, DataType::FLOAT, Activation::RELU, std::nullopt}, - "linear"}); + Node n0 = pcg.add_node(Operator{InputAttrs{}, "input"}); + Node n1 = pcg.add_node(Operator{ + LinearAttrs{1, false, DataType::FLOAT, Activation::RELU, std::nullopt}, + "linear"}); - MultiDiEdge e{n1, pcg.add_node_port(), n0, pcg.add_node_port()}; - pcg.add_edge(e); - pcg.add_output(e, - ParallelTensor(ParallelTensorDims({2, 1}), - DataType::FLOAT, - CreateGrad::YES)); + MultiDiEdge e{n1, pcg.add_node_port(), n0, pcg.add_node_port()}; + pcg.add_edge(e); + pcg.add_output(e, + ParallelTensor(ParallelTensorDims({2, 1}), + DataType::FLOAT, + CreateGrad::YES)); - auto test_allowed_machine_views = [](Operator const &, - MachineSpecification const &) { - return std::unordered_set{ - make_1d_machine_view(gpu_id_t(1), gpu_id_t(2))}; - }; + auto test_allowed_machine_views = [](Operator const &, + MachineSpecification const &) { + return std::unordered_set{ + make_1d_machine_view(gpu_id_t(1), gpu_id_t(2))}; + 
}; - CostEstimator estimator = CostEstimator::create(); + CostEstimator estimator = CostEstimator::create(); - MachineSpecification machine_spec{1, 1, 1, 1, 1}; + MachineSpecification machine_spec{1, 1, 1, 1, 1}; - OptimalCostCache cached_results; + OptimalCostCache cached_results; - OptimalCostResult result = optimal_cost(ParallelComputationGraph(pcg), - test_allowed_machine_views, - estimator, - machine_spec, - cached_results); + OptimalCostResult result = optimal_cost(ParallelComputationGraph(pcg), + test_allowed_machine_views, + estimator, + machine_spec, + cached_results); - CHECK(bool(result.runtime > 0)); + CHECK(bool(result.runtime > 0)); + } } diff --git a/lib/compiler/test/src/test_unity_algorithm.cc b/lib/compiler/test/src/test_unity_algorithm.cc index c39b3ef14f..b8fde91c51 100644 --- a/lib/compiler/test/src/test_unity_algorithm.cc +++ b/lib/compiler/test/src/test_unity_algorithm.cc @@ -1,25 +1,28 @@ #include "compiler/unity_algorithm.h" #include "test_cost_estimator.h" #include "test_generator.h" +#include "doctest/doctest.h" -// Rapidcheck does not work for now -// TEST_CASE("graph_optimize") { -// rc::check([](ComputationGraph const &g, -// float alpha, -// int budget, -// float threshold, -// int max_num_ops) { -// Strategy s = graph_optimize( -// g, -// TestCostEstimator{}, -// MachineSpecification{1, 1, 4, 0.1, 0.2}, -// [](Operator const &, MachineSpecification const &) { -// return std::unordered_set{make_1d_machine_view(0, 1, -// 1)}; -// }, -// OptimizerConfig{alpha, budget, threshold, max_num_ops}); -// RC_ASSERT(get_nodes(s.pcg).size() > 0); -// RC_ASSERT(s.machine_mapping.runtime > 0); -// RC_ASSERT(keys(s.machine_mapping.machine_views) == get_nodes(s.pcg)); -// }); -// } +TEST_SUITE(FF_TEST_SUITE) { + // Rapidcheck does not work for now + // TEST_CASE("graph_optimize") { + // rc::check([](ComputationGraph const &g, + // float alpha, + // int budget, + // float threshold, + // int max_num_ops) { + // Strategy s = graph_optimize( + // g, + 
// TestCostEstimator{}, + // MachineSpecification{1, 1, 4, 0.1, 0.2}, + // [](Operator const &, MachineSpecification const &) { + // return std::unordered_set{make_1d_machine_view(0, 1, + // 1)}; + // }, + // OptimizerConfig{alpha, budget, threshold, max_num_ops}); + // RC_ASSERT(get_nodes(s.pcg).size() > 0); + // RC_ASSERT(s.machine_mapping.runtime > 0); + // RC_ASSERT(keys(s.machine_mapping.machine_views) == get_nodes(s.pcg)); + // }); + // } +} diff --git a/lib/substitutions/test/src/test_pattern_matches.cc b/lib/substitutions/test/src/test_pattern_matches.cc index cc8a5cd5bd..f1abd5c17e 100644 --- a/lib/substitutions/test/src/test_pattern_matches.cc +++ b/lib/substitutions/test/src/test_pattern_matches.cc @@ -62,46 +62,48 @@ struct Arbitrary { // }); // } -TEST_CASE("find_pattern_matches_small") { - MultiDiGraph g = MultiDiGraph::template create(); - - { - Node n0 = g.add_node(); - Node n1 = g.add_node(); - Node n2 = g.add_node(); - Node n3 = g.add_node(); - - MultiDiEdge e0{n1, g.add_node_port(), n0, g.add_node_port()}; - MultiDiEdge e1{n2, g.add_node_port(), n1, g.add_node_port()}; - MultiDiEdge e2{n3, g.add_node_port(), n2, g.add_node_port()}; - - g.add_edge(e0); - g.add_edge(e1); - g.add_edge(e2); - } +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("find_pattern_matches_small") { + MultiDiGraph g = MultiDiGraph::template create(); - MultiDiGraph sg0 = MultiDiGraph::template create(); + { + Node n0 = g.add_node(); + Node n1 = g.add_node(); + Node n2 = g.add_node(); + Node n3 = g.add_node(); - { - Node n0 = sg0.add_node(); - Node n1 = sg0.add_node(); + MultiDiEdge e0{n1, g.add_node_port(), n0, g.add_node_port()}; + MultiDiEdge e1{n2, g.add_node_port(), n1, g.add_node_port()}; + MultiDiEdge e2{n3, g.add_node_port(), n2, g.add_node_port()}; - MultiDiEdge e0{n1, sg0.add_node_port(), n0, sg0.add_node_port()}; + g.add_edge(e0); + g.add_edge(e1); + g.add_edge(e2); + } - sg0.add_edge(e0); - } + MultiDiGraph sg0 = MultiDiGraph::template create(); + + { + Node n0 = 
sg0.add_node(); + Node n1 = sg0.add_node(); + + MultiDiEdge e0{n1, sg0.add_node_port(), n0, sg0.add_node_port()}; + + sg0.add_edge(e0); + } - MatchAdditionalCriterion always_true{ - [](Node const &, Node const &) { return true; }, - [](OpenMultiDiEdge const &, OpenMultiDiEdge const &) { return true; }}; + MatchAdditionalCriterion always_true{ + [](Node const &, Node const &) { return true; }, + [](OpenMultiDiEdge const &, OpenMultiDiEdge const &) { return true; }}; - std::vector matches = find_pattern_matches( - as_openmultidigraph(sg0), as_openmultidigraph(g), always_true); + std::vector matches = find_pattern_matches( + as_openmultidigraph(sg0), as_openmultidigraph(g), always_true); - RC_ASSERT(matches.size() == 3); + RC_ASSERT(matches.size() == 3); - for (MultiDiGraphPatternMatch const &match : matches) { - RC_ASSERT(pattern_matches( - as_openmultidigraph(sg0), as_openmultidigraph(g), match, always_true)); + for (MultiDiGraphPatternMatch const &match : matches) { + RC_ASSERT(pattern_matches( + as_openmultidigraph(sg0), as_openmultidigraph(g), match, always_true)); + } } } diff --git a/lib/substitutions/test/src/test_substitution.cc b/lib/substitutions/test/src/test_substitution.cc index 552d46a98f..86ee087a29 100644 --- a/lib/substitutions/test/src/test_substitution.cc +++ b/lib/substitutions/test/src/test_substitution.cc @@ -5,123 +5,125 @@ using namespace FlexFlow; -TEST_CASE("apply_substitution") { - OperatorPattern operator_pattern_n0{ - std::vector{OperatorAttributeConstraint{ - ConstraintType::EQUAL, OperatorAttributeKey::OP_TYPE, Op::LINEAR}}}; - - ParallelTensorPattern tensor_pattern_e0{ - std::vector{TensorAttributeConstraint{ - ConstraintType::EQUAL, - ListIndexAccess{TensorAttributeKey::DIM_SIZES, 0}, - 2}}}; - - ParallelTensorPattern tensor_pattern_empty{ - std::vector{}}; - - auto ig = OutputLabelledOpenMultiDiGraph:: - create>(); - Node n0 = ig.add_node(operator_pattern_n0); - NodePort p0 = ig.add_node_port(); - InputMultiDiEdge e0{n0, p0, 
std::make_pair(p0.value(), p0.value())}; - ig.add_edge(e0); - ig.add_label(e0, tensor_pattern_e0); - - RC_ASSERT(get_nodes(ig).size() == 1); - RC_ASSERT(get_edges(ig).size() == 1); - - GraphPattern input_graph{ig}; - - OperatorAttrAssignment op_ass_n1{ - {{OperatorAttributeKey::OP_TYPE, AttrConstant{Op::REPARTITION}}, - {OperatorAttributeKey::PARALLEL_DIM, AttrConstant{ff_dim_t{0}}}, - {OperatorAttributeKey::PARALLEL_DEGREE, AttrConstant{2}}}}; - - OperatorAttrAssignment op_ass_n2{ - {{OperatorAttributeKey::OP_TYPE, AttrConstant{Op::LINEAR}}, - {OperatorAttributeKey::OUT_CHANNELS, - OperatorAttrAccess{n0, OperatorAttributeKey::OUT_CHANNELS}}, - {OperatorAttributeKey::USE_BIAS, - OperatorAttrAccess{n0, OperatorAttributeKey::USE_BIAS}}, - {OperatorAttributeKey::DATA_TYPE, - OperatorAttrAccess{n0, OperatorAttributeKey::DATA_TYPE}}, - {OperatorAttributeKey::ACTIVATION, - OperatorAttrAccess{n0, OperatorAttributeKey::ACTIVATION}}, - {OperatorAttributeKey::REGULARIZER, - OperatorAttrAccess{n0, OperatorAttributeKey::REGULARIZER}}}}; - - OperatorAttrAssignment op_ass_n3{ - {{OperatorAttributeKey::OP_TYPE, AttrConstant{Op::REDUCTION}}, - {OperatorAttributeKey::PARALLEL_DIM, AttrConstant{ff_dim_t{0}}}, - {OperatorAttributeKey::PARALLEL_DEGREE, AttrConstant{2}}}}; - - auto og = NodeLabelledOpenMultiDiGraph::create< - UnorderedNodeLabelledOpenMultiDiGraph>(); - Node n1 = og.add_node(op_ass_n1); - Node n2 = og.add_node(op_ass_n2); - Node n3 = og.add_node(op_ass_n3); - NodePort p1 = og.add_node_port(); - NodePort p2 = og.add_node_port(); - NodePort p3 = og.add_node_port(); - InputMultiDiEdge e1{n1, p1, {p1.value(), p1.value()}}; - MultiDiEdge e2{n2, p2, n1, p1}; - MultiDiEdge e3{n3, p3, n2, p2}; - og.add_edge(e1); - og.add_edge(e2); - og.add_edge(e3); - OutputGraphExpr output_graph_expr{og}; - - RC_ASSERT(get_nodes(og).size() == 3); - RC_ASSERT(get_edges(og).size() == 3); - - bidict input_mapping; - input_mapping.equate(e0, e1); - bidict output_mapping; - - Substitution 
substitution{ - input_graph, output_graph_expr, input_mapping, output_mapping}; - - SubParallelComputationGraph pcg = - OutputLabelledOpenMultiDiGraph::create< - UnorderedOutputLabelledOpenMultiDiGraph>(); - - Node n4 = pcg.add_node(Operator{InputAttrs{}, "input"}); - Node n5 = pcg.add_node(Operator{ - LinearAttrs{1, false, DataType::FLOAT, Activation::RELU, std::nullopt}, - "linear"}); - NodePort p4 = pcg.add_node_port(); - NodePort p5 = pcg.add_node_port(); - - MultiDiEdge e4{n5, p5, n4, p4}; - pcg.add_edge(e4); - pcg.add_label(e4, - ParallelTensor(ParallelTensorDims({2, 1}), - DataType::FLOAT, - CreateGrad::YES)); - - MatchAdditionalCriterion criterion{ - [&](Node const &pattern_node, Node const &graph_node) { - return operator_satisfies(pcg.at(graph_node), - input_graph.value().at(pattern_node)); - }, - [&](OpenMultiDiEdge const &pattern_edge, - OpenMultiDiEdge const &graph_edge) { - return parallel_tensor_satisfies(pcg.at(graph_edge), - input_graph.value().at(pattern_edge)); - }}; - - RC_ASSERT(criterion.node_criterion(n0, n5)); - - std::vector matches = - find_pattern_matches(input_graph, pcg, criterion); - - RC_ASSERT(matches.size() == 1); - - SubParallelComputationGraph new_pcg = - apply_substitution(pcg, substitution, matches[0]); - - RC_ASSERT(get_nodes(new_pcg).size() == 4); - RC_ASSERT(get_edges(new_pcg).size() == 3); +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("apply_substitution") { + OperatorPattern operator_pattern_n0{ + std::vector{OperatorAttributeConstraint{ + ConstraintType::EQUAL, OperatorAttributeKey::OP_TYPE, Op::LINEAR}}}; + + ParallelTensorPattern tensor_pattern_e0{ + std::vector{TensorAttributeConstraint{ + ConstraintType::EQUAL, + ListIndexAccess{TensorAttributeKey::DIM_SIZES, 0}, + 2}}}; + + ParallelTensorPattern tensor_pattern_empty{ + std::vector{}}; + + auto ig = OutputLabelledOpenMultiDiGraph:: + create>(); + Node n0 = ig.add_node(operator_pattern_n0); + NodePort p0 = ig.add_node_port(); + InputMultiDiEdge e0{n0, p0, 
std::make_pair(p0.value(), p0.value())}; + ig.add_edge(e0); + ig.add_label(e0, tensor_pattern_e0); + + RC_ASSERT(get_nodes(ig).size() == 1); + RC_ASSERT(get_edges(ig).size() == 1); + + GraphPattern input_graph{ig}; + + OperatorAttrAssignment op_ass_n1{ + {{OperatorAttributeKey::OP_TYPE, AttrConstant{Op::REPARTITION}}, + {OperatorAttributeKey::PARALLEL_DIM, AttrConstant{ff_dim_t{0}}}, + {OperatorAttributeKey::PARALLEL_DEGREE, AttrConstant{2}}}}; + + OperatorAttrAssignment op_ass_n2{ + {{OperatorAttributeKey::OP_TYPE, AttrConstant{Op::LINEAR}}, + {OperatorAttributeKey::OUT_CHANNELS, + OperatorAttrAccess{n0, OperatorAttributeKey::OUT_CHANNELS}}, + {OperatorAttributeKey::USE_BIAS, + OperatorAttrAccess{n0, OperatorAttributeKey::USE_BIAS}}, + {OperatorAttributeKey::DATA_TYPE, + OperatorAttrAccess{n0, OperatorAttributeKey::DATA_TYPE}}, + {OperatorAttributeKey::ACTIVATION, + OperatorAttrAccess{n0, OperatorAttributeKey::ACTIVATION}}, + {OperatorAttributeKey::REGULARIZER, + OperatorAttrAccess{n0, OperatorAttributeKey::REGULARIZER}}}}; + + OperatorAttrAssignment op_ass_n3{ + {{OperatorAttributeKey::OP_TYPE, AttrConstant{Op::REDUCTION}}, + {OperatorAttributeKey::PARALLEL_DIM, AttrConstant{ff_dim_t{0}}}, + {OperatorAttributeKey::PARALLEL_DEGREE, AttrConstant{2}}}}; + + auto og = NodeLabelledOpenMultiDiGraph::create< + UnorderedNodeLabelledOpenMultiDiGraph>(); + Node n1 = og.add_node(op_ass_n1); + Node n2 = og.add_node(op_ass_n2); + Node n3 = og.add_node(op_ass_n3); + NodePort p1 = og.add_node_port(); + NodePort p2 = og.add_node_port(); + NodePort p3 = og.add_node_port(); + InputMultiDiEdge e1{n1, p1, {p1.value(), p1.value()}}; + MultiDiEdge e2{n2, p2, n1, p1}; + MultiDiEdge e3{n3, p3, n2, p2}; + og.add_edge(e1); + og.add_edge(e2); + og.add_edge(e3); + OutputGraphExpr output_graph_expr{og}; + + RC_ASSERT(get_nodes(og).size() == 3); + RC_ASSERT(get_edges(og).size() == 3); + + bidict input_mapping; + input_mapping.equate(e0, e1); + bidict output_mapping; + + Substitution 
substitution{ + input_graph, output_graph_expr, input_mapping, output_mapping}; + + SubParallelComputationGraph pcg = + OutputLabelledOpenMultiDiGraph::create< + UnorderedOutputLabelledOpenMultiDiGraph>(); + + Node n4 = pcg.add_node(Operator{InputAttrs{}, "input"}); + Node n5 = pcg.add_node(Operator{ + LinearAttrs{1, false, DataType::FLOAT, Activation::RELU, std::nullopt}, + "linear"}); + NodePort p4 = pcg.add_node_port(); + NodePort p5 = pcg.add_node_port(); + + MultiDiEdge e4{n5, p5, n4, p4}; + pcg.add_edge(e4); + pcg.add_label(e4, + ParallelTensor(ParallelTensorDims({2, 1}), + DataType::FLOAT, + CreateGrad::YES)); + + MatchAdditionalCriterion criterion{ + [&](Node const &pattern_node, Node const &graph_node) { + return operator_satisfies(pcg.at(graph_node), + input_graph.value().at(pattern_node)); + }, + [&](OpenMultiDiEdge const &pattern_edge, + OpenMultiDiEdge const &graph_edge) { + return parallel_tensor_satisfies(pcg.at(graph_edge), + input_graph.value().at(pattern_edge)); + }}; + + RC_ASSERT(criterion.node_criterion(n0, n5)); + + std::vector matches = + find_pattern_matches(input_graph, pcg, criterion); + + RC_ASSERT(matches.size() == 1); + + SubParallelComputationGraph new_pcg = + apply_substitution(pcg, substitution, matches[0]); + + RC_ASSERT(get_nodes(new_pcg).size() == 4); + RC_ASSERT(get_edges(new_pcg).size() == 3); + } } diff --git a/lib/utils/test/src/test_algorithms.cc b/lib/utils/test/src/test_algorithms.cc index 2e97496b6b..d3236a7b1c 100644 --- a/lib/utils/test/src/test_algorithms.cc +++ b/lib/utils/test/src/test_algorithms.cc @@ -12,232 +12,234 @@ using namespace FlexFlow; -TEST_CASE("MultiDiGraph") { - MultiDiGraph g = MultiDiGraph::create(); - std::vector n = add_nodes(g, 4); - std::vector p = add_node_ports(g, 4); - - MultiDiEdge e0{n[3], p[3], n[0], p[0]}; - MultiDiEdge e1{n[2], p[2], n[1], p[0]}; - MultiDiEdge e2{n[3], p[3], n[1], p[1]}; - MultiDiEdge e3{n[3], p[3], n[2], p[2]}; - - std::vector e = {e0, e1, e2, e3}; - - add_edges(g, e); - 
- CHECK(get_incoming_edges(g, {n[1], n[3]}) == - std::unordered_set{e[0], e[2], e[3]}); - CHECK(get_incoming_edges(g, {n[1]}) == std::unordered_set{}); - CHECK(get_outgoing_edges(g, {n[2], n[3]}) == - std::unordered_set{e[3]}); - std::unordered_map> expected_result = - std::unordered_map>{ - {n[1], {}}, - {n[2], {n[1]}}, - {n[3], {n[0], n[1], n[2]}}, - }; - CHECK(get_predecessors(g, {n[1], n[2], n[3]}) == expected_result); -} - -TEST_CASE("DiGraph") { - DiGraph g = DiGraph::create(); - - std::vector n = add_nodes(g, 4); - std::vector e = { - {n[0], n[3]}, - {n[0], n[1]}, - {n[0], n[2]}, - {n[1], n[2]}, - }; - add_edges(g, e); - - CHECK(get_incoming_edges(g, {n[2], n[3]}) == - std::unordered_set{e[0], e[2], e[3]}); - CHECK(get_outgoing_edges(g, {n[2], n[3]}) == - std::unordered_set{}); - auto expected_result = std::unordered_map>{ - {n[1], {n[0]}}, - {n[2], {n[0], n[1]}}, - {n[3], {n[0]}}, - }; - CHECK(get_predecessors(g, {n[1], n[2], n[3]}) == expected_result); - - SUBCASE("get_imm_dominators") { - std::unordered_map> result = get_imm_dominators(g); - - std::unordered_map> expected_result = { - {n[2], n[0]}, - {n[1], n[0]}, - {n[3], n[0]}, - {n[0], nullopt}, - }; - CHECK(result == expected_result); - } - - SUBCASE("get_dominators") { - std::unordered_map> expected = { - {n[0], {n[0]}}, - {n[1], {n[0], n[1]}}, - {n[2], {n[0], n[2]}}, - {n[3], {n[0], n[3]}}, - }; - CHECK(get_dominators(g) == expected); +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("MultiDiGraph") { + MultiDiGraph g = MultiDiGraph::create(); + std::vector n = add_nodes(g, 4); + std::vector p = add_node_ports(g, 4); + + MultiDiEdge e0{n[3], p[3], n[0], p[0]}; + MultiDiEdge e1{n[2], p[2], n[1], p[0]}; + MultiDiEdge e2{n[3], p[3], n[1], p[1]}; + MultiDiEdge e3{n[3], p[3], n[2], p[2]}; + + std::vector e = {e0, e1, e2, e3}; + + add_edges(g, e); + + CHECK(get_incoming_edges(g, {n[1], n[3]}) == + std::unordered_set{e[0], e[2], e[3]}); + CHECK(get_incoming_edges(g, {n[1]}) == std::unordered_set{}); + 
CHECK(get_outgoing_edges(g, {n[2], n[3]}) == + std::unordered_set{e[3]}); + std::unordered_map> expected_result = + std::unordered_map>{ + {n[1], {}}, + {n[2], {n[1]}}, + {n[3], {n[0], n[1], n[2]}}, + }; + CHECK(get_predecessors(g, {n[1], n[2], n[3]}) == expected_result); } - SUBCASE("get_sinks") { - auto expected = std::unordered_set{n[2], n[3]}; - CHECK(get_sinks(g) == expected); - } + TEST_CASE("DiGraph") { + DiGraph g = DiGraph::create(); - SUBCASE("get_bfs") { - std::unordered_set start_points = std::unordered_set{n[0]}; - auto expected = std::vector{n[0], n[2], n[1], n[3]}; - CHECK(get_bfs_ordering(g, start_points) == expected); - } + std::vector n = add_nodes(g, 4); + std::vector e = { + {n[0], n[3]}, + {n[0], n[1]}, + {n[0], n[2]}, + {n[1], n[2]}, + }; + add_edges(g, e); - SUBCASE("get_predecessors") { - std::unordered_map> expected_result = { + CHECK(get_incoming_edges(g, {n[2], n[3]}) == + std::unordered_set{e[0], e[2], e[3]}); + CHECK(get_outgoing_edges(g, {n[2], n[3]}) == + std::unordered_set{}); + auto expected_result = std::unordered_map>{ {n[1], {n[0]}}, {n[2], {n[0], n[1]}}, + {n[3], {n[0]}}, }; - CHECK(get_predecessors(g, {n[1], n[2]}) == expected_result); - } -} + CHECK(get_predecessors(g, {n[1], n[2], n[3]}) == expected_result); -TEST_CASE("traversal") { - DiGraph g = DiGraph::create(); - std::vector const n = add_nodes(g, 5); - std::vector edges = {{n[0], n[1]}, {n[1], n[2]}, {n[2], n[3]}}; - add_edges(g, edges); - - CHECK(get_sources(g) == std::unordered_set{n[0], n[4]}); - CHECK(get_unchecked_dfs_ordering(g, {n[0]}) == - std::vector{n[0], n[1], n[2], n[3]}); - CHECK(get_bfs_ordering(g, {n[0]}) == - std::vector{n[0], n[1], n[2], n[3]}); - CHECK(is_acyclic(g) == true); - CHECK(get_bfs_ordering(g, {n[4]}) == std::vector{n[4]}); - CHECK(get_dfs_ordering(g, {n[4]}) == std::vector{n[4]}); - - SUBCASE("with root") { - g.add_edge({n[3], n[2]}); - - CHECK(get_dfs_ordering(g, {n[0]}) == - std::vector{n[0], n[1], n[2], n[3]}); - CHECK(is_acyclic(g) == 
false); + SUBCASE("get_imm_dominators") { + std::unordered_map> result = get_imm_dominators(g); + + std::unordered_map> expected_result = { + {n[2], n[0]}, + {n[1], n[0]}, + {n[3], n[0]}, + {n[0], nullopt}, + }; + CHECK(result == expected_result); + } + + SUBCASE("get_dominators") { + std::unordered_map> expected = { + {n[0], {n[0]}}, + {n[1], {n[0], n[1]}}, + {n[2], {n[0], n[2]}}, + {n[3], {n[0], n[3]}}, + }; + CHECK(get_dominators(g) == expected); + } + + SUBCASE("get_sinks") { + auto expected = std::unordered_set{n[2], n[3]}; + CHECK(get_sinks(g) == expected); + } + + SUBCASE("get_bfs") { + std::unordered_set start_points = std::unordered_set{n[0]}; + auto expected = std::vector{n[0], n[2], n[1], n[3]}; + CHECK(get_bfs_ordering(g, start_points) == expected); + } + + SUBCASE("get_predecessors") { + std::unordered_map> expected_result = { + {n[1], {n[0]}}, + {n[2], {n[0], n[1]}}, + }; + CHECK(get_predecessors(g, {n[1], n[2]}) == expected_result); + } } - SUBCASE("without root") { - g.add_edge({n[3], n[0]}); + TEST_CASE("traversal") { + DiGraph g = DiGraph::create(); + std::vector const n = add_nodes(g, 5); + std::vector edges = {{n[0], n[1]}, {n[1], n[2]}, {n[2], n[3]}}; + add_edges(g, edges); - CHECK(get_dfs_ordering(g, {n[0]}) == + CHECK(get_sources(g) == std::unordered_set{n[0], n[4]}); + CHECK(get_unchecked_dfs_ordering(g, {n[0]}) == std::vector{n[0], n[1], n[2], n[3]}); - CHECK(is_acyclic(g) == false); - } - SUBCASE("nonlinear") { - g.add_edge({n[1], n[3]}); - CHECK(is_acyclic(g) == true); // TODO, maybe a bug about the unchecked_dfs + CHECK(get_bfs_ordering(g, {n[0]}) == + std::vector{n[0], n[1], n[2], n[3]}); + CHECK(is_acyclic(g) == true); + CHECK(get_bfs_ordering(g, {n[4]}) == std::vector{n[4]}); + CHECK(get_dfs_ordering(g, {n[4]}) == std::vector{n[4]}); + + SUBCASE("with root") { + g.add_edge({n[3], n[2]}); + + CHECK(get_dfs_ordering(g, {n[0]}) == + std::vector{n[0], n[1], n[2], n[3]}); + CHECK(is_acyclic(g) == false); + } + + SUBCASE("without root") { + 
g.add_edge({n[3], n[0]}); + + CHECK(get_dfs_ordering(g, {n[0]}) == + std::vector{n[0], n[1], n[2], n[3]}); + CHECK(is_acyclic(g) == false); + } + SUBCASE("nonlinear") { + g.add_edge({n[1], n[3]}); + CHECK(is_acyclic(g) == true); // TODO, maybe a bug about the unchecked_dfs + } + + SUBCASE("not connected") { + g.remove_edge({n[2], n[3]}); + CHECK(get_dfs_ordering(g, {n[0]}) == std::vector{n[0], n[1], n[2]}); + } } - SUBCASE("not connected") { - g.remove_edge({n[2], n[3]}); - CHECK(get_dfs_ordering(g, {n[0]}) == std::vector{n[0], n[1], n[2]}); + TEST_CASE("bfs") { + DiGraph g = DiGraph::create(); + std::vector const n = add_nodes(g, 7); + + std::vector e = { + {n[0], n[1]}, + {n[0], n[2]}, + {n[1], n[6]}, + {n[2], n[3]}, + {n[3], n[4]}, + {n[4], n[5]}, + {n[5], n[6]}, + {n[6], n[0]}, + }; + + add_edges(g, e); + + std::vector ordering = get_bfs_ordering(g, {n[0]}); + auto CHECK_BEFORE = [&](int l, int r) { + CHECK(index_of(ordering, n[l]).has_value()); + CHECK(index_of(ordering, n[r]).has_value()); + CHECK(index_of(ordering, n[l]).value() < index_of(ordering, n[r]).value()); + }; + + CHECK(ordering.size() == n.size()); + CHECK_BEFORE(0, 1); + CHECK_BEFORE(0, 2); + + CHECK_BEFORE(1, 3); + CHECK_BEFORE(1, 6); + CHECK_BEFORE(2, 3); + CHECK_BEFORE(2, 6); + + CHECK_BEFORE(3, 4); + CHECK_BEFORE(6, 4); + + CHECK_BEFORE(4, 5); } -} -TEST_CASE("bfs") { - DiGraph g = DiGraph::create(); - std::vector const n = add_nodes(g, 7); - - std::vector e = { - {n[0], n[1]}, - {n[0], n[2]}, - {n[1], n[6]}, - {n[2], n[3]}, - {n[3], n[4]}, - {n[4], n[5]}, - {n[5], n[6]}, - {n[6], n[0]}, - }; - - add_edges(g, e); - - std::vector ordering = get_bfs_ordering(g, {n[0]}); - auto CHECK_BEFORE = [&](int l, int r) { - CHECK(index_of(ordering, n[l]).has_value()); - CHECK(index_of(ordering, n[r]).has_value()); - CHECK(index_of(ordering, n[l]).value() < index_of(ordering, n[r]).value()); - }; - - CHECK(ordering.size() == n.size()); - CHECK_BEFORE(0, 1); - CHECK_BEFORE(0, 2); - - CHECK_BEFORE(1, 3); - 
CHECK_BEFORE(1, 6); - CHECK_BEFORE(2, 3); - CHECK_BEFORE(2, 6); - - CHECK_BEFORE(3, 4); - CHECK_BEFORE(6, 4); - - CHECK_BEFORE(4, 5); -} + TEST_CASE("get_topological_ordering") { + DiGraph g = DiGraph::create(); + std::vector n = add_nodes(g, 6); + std::vector edges = {{n[0], n[1]}, + {n[0], n[2]}, + {n[1], n[5]}, + {n[2], n[3]}, + {n[3], n[4]}, + {n[4], n[5]}}; + add_edges(g, edges); + std::vector ordering = get_topological_ordering(g); + auto CHECK_BEFORE = [&](int l, int r) { + CHECK(index_of(ordering, n[l]).has_value()); + CHECK(index_of(ordering, n[r]).has_value()); + CHECK(index_of(ordering, n[l]) < index_of(ordering, n[r])); + }; -TEST_CASE("get_topological_ordering") { - DiGraph g = DiGraph::create(); - std::vector n = add_nodes(g, 6); - std::vector edges = {{n[0], n[1]}, - {n[0], n[2]}, - {n[1], n[5]}, - {n[2], n[3]}, - {n[3], n[4]}, - {n[4], n[5]}}; - add_edges(g, edges); - std::vector ordering = get_topological_ordering(g); - auto CHECK_BEFORE = [&](int l, int r) { - CHECK(index_of(ordering, n[l]).has_value()); - CHECK(index_of(ordering, n[r]).has_value()); - CHECK(index_of(ordering, n[l]) < index_of(ordering, n[r])); - }; - - CHECK(ordering.size() == n.size()); - CHECK_BEFORE(0, 1); - CHECK_BEFORE(0, 2); - CHECK_BEFORE(1, 5); - CHECK_BEFORE(2, 3); - CHECK_BEFORE(3, 4); - CHECK_BEFORE(4, 5); -} + CHECK(ordering.size() == n.size()); + CHECK_BEFORE(0, 1); + CHECK_BEFORE(0, 2); + CHECK_BEFORE(1, 5); + CHECK_BEFORE(2, 3); + CHECK_BEFORE(3, 4); + CHECK_BEFORE(4, 5); + } -TEST_CASE("get_connected_components") { - UndirectedGraph g = UndirectedGraph::create(); - std::vector n = add_nodes(g, 4); - std::vector edges = {{n[0], n[1]}, {n[2], n[1]}}; + TEST_CASE("get_connected_components") { + UndirectedGraph g = UndirectedGraph::create(); + std::vector n = add_nodes(g, 4); + std::vector edges = {{n[0], n[1]}, {n[2], n[1]}}; - add_edges(g, edges); - std::unordered_set> expected_components = { - {n[0], n[1], n[2]}, - {n[3]}, - }; + add_edges(g, edges); + 
std::unordered_set> expected_components = { + {n[0], n[1], n[2]}, + {n[3]}, + }; - CHECK(get_connected_components(g) == expected_components); -} + CHECK(get_connected_components(g) == expected_components); + } -TEST_CASE("get_weakly_connected_components") { - DiGraph g = DiGraph::create(); - std::vector n = add_nodes(g, 4); + TEST_CASE("get_weakly_connected_components") { + DiGraph g = DiGraph::create(); + std::vector n = add_nodes(g, 4); - std::vector edges = {{n[0], n[1]}, {n[2], n[1]}}; + std::vector edges = {{n[0], n[1]}, {n[2], n[1]}}; - add_edges(g, edges); - std::unordered_set> expected_components = { - {n[0], n[1], n[2]}, - {n[3]}, - }; + add_edges(g, edges); + std::unordered_set> expected_components = { + {n[0], n[1], n[2]}, + {n[3]}, + }; - CHECK(get_outgoing_edges(as_digraph(as_undirected(g)), n[0]).size() == 1); + CHECK(get_outgoing_edges(as_digraph(as_undirected(g)), n[0]).size() == 1); - CHECK(get_weakly_connected_components(g) == expected_components); + CHECK(get_weakly_connected_components(g) == expected_components); + } } diff --git a/lib/utils/test/src/test_bidict.cc b/lib/utils/test/src/test_bidict.cc index 6c288089b6..afc32b3658 100644 --- a/lib/utils/test/src/test_bidict.cc +++ b/lib/utils/test/src/test_bidict.cc @@ -3,61 +3,63 @@ using namespace FlexFlow; -TEST_CASE("bidict") { - bidict dict; - dict.equate(1, "one"); - dict.equate(2, "two"); - - // Test the equate() function - SUBCASE("Equate") { - CHECK(dict.at_l(1) == "one"); - CHECK(dict.at_r("one") == 1); - CHECK(dict.at_l(2) == "two"); - CHECK(dict.at_r("two") == 2); - } +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("bidict") { + bidict dict; + dict.equate(1, "one"); + dict.equate(2, "two"); - // Test the erase_l() function - SUBCASE("EraseL") { - dict.erase_l(1); - CHECK(dict.size() == 1); - CHECK_THROWS_AS(dict.at_l(1), std::out_of_range); - CHECK(dict.at_r("two") == 2); - } + // Test the equate() function + SUBCASE("Equate") { + CHECK(dict.at_l(1) == "one"); + CHECK(dict.at_r("one") == 1); 
+ CHECK(dict.at_l(2) == "two"); + CHECK(dict.at_r("two") == 2); + } - // Test the erase_r() function - SUBCASE("EraseR") { - dict.erase_r("one"); - CHECK(dict.size() == 1); - CHECK_THROWS_AS(dict.at_r("one"), std::out_of_range); - CHECK(dict.at_l(2) == "two"); - } + // Test the erase_l() function + SUBCASE("EraseL") { + dict.erase_l(1); + CHECK(dict.size() == 1); + CHECK_THROWS_AS(dict.at_l(1), std::out_of_range); + CHECK(dict.at_r("two") == 2); + } - // Test the reversed() function - SUBCASE("Reversed") { - bidict reversed_dict = dict.reversed(); - CHECK(reversed_dict.at_l("one") == 1); - CHECK(reversed_dict.at_r(2) == "two"); - } + // Test the erase_r() function + SUBCASE("EraseR") { + dict.erase_r("one"); + CHECK(dict.size() == 1); + CHECK_THROWS_AS(dict.at_r("one"), std::out_of_range); + CHECK(dict.at_l(2) == "two"); + } - // Test the size() function - SUBCASE("Size") { - CHECK(dict.size() == 2); - } + // Test the reversed() function + SUBCASE("Reversed") { + bidict reversed_dict = dict.reversed(); + CHECK(reversed_dict.at_l("one") == 1); + CHECK(reversed_dict.at_r(2) == "two"); + } - SUBCASE("implicitly convert to std::unordered_map") { - std::unordered_map res = dict; - std::unordered_map expected = {{1, "one"}, {2, "two"}}; - CHECK(res == expected); - } + // Test the size() function + SUBCASE("Size") { + CHECK(dict.size() == 2); + } - SUBCASE("begin") { - auto it = dict.begin(); - CHECK(it->first == 2); - CHECK(it->second == "two"); - } + SUBCASE("implicitly convert to std::unordered_map") { + std::unordered_map res = dict; + std::unordered_map expected = {{1, "one"}, {2, "two"}}; + CHECK(res == expected); + } + + SUBCASE("begin") { + auto it = dict.begin(); + CHECK(it->first == 2); + CHECK(it->second == "two"); + } - SUBCASE("end") { - auto it = dict.end(); - CHECK(it == dict.end()); + SUBCASE("end") { + auto it = dict.end(); + CHECK(it == dict.end()); + } } } diff --git a/lib/utils/test/src/test_containers.cc b/lib/utils/test/src/test_containers.cc index 
8c37abf877..f6ac6e2d42 100644 --- a/lib/utils/test/src/test_containers.cc +++ b/lib/utils/test/src/test_containers.cc @@ -5,384 +5,387 @@ #include using namespace FlexFlow; -TEST_CASE("join_strings") { - std::vector const v = {"Hello", "world", "!"}; - CHECK(join_strings(v.begin(), v.end(), " ") == "Hello world !"); -} -TEST_CASE("join_strings with container") { - std::vector const v = {"Hello", "world"}; - CHECK(join_strings(v, " ") == "Hello world"); -} +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("join_strings") { + std::vector const v = {"Hello", "world", "!"}; + CHECK(join_strings(v.begin(), v.end(), " ") == "Hello world !"); + } -TEST_CASE("find") { - std::vector v = {1, 2, 3, 4, 5}; - CHECK(find(v, 3) != v.cend()); - CHECK(find(v, 6) == v.cend()); -} + TEST_CASE("join_strings with container") { + std::vector const v = {"Hello", "world"}; + CHECK(join_strings(v, " ") == "Hello world"); + } -TEST_CASE("sum") { - std::vector v = {1, 2, 3, 4, 5}; - CHECK(sum(v) == 15); -} + TEST_CASE("find") { + std::vector v = {1, 2, 3, 4, 5}; + CHECK(find(v, 3) != v.cend()); + CHECK(find(v, 6) == v.cend()); + } -TEST_CASE("sum with condition") { - std::vector v = {1, 2, 3, 4, 5}; - auto condition = [](int x) { return x % 2 == 0; }; // Sum of even numbers only - CHECK(sum_where(v, condition) == 6); -} + TEST_CASE("sum") { + std::vector v = {1, 2, 3, 4, 5}; + CHECK(sum(v) == 15); + } -TEST_CASE("product") { - std::vector v = {1, 2, 3, 4, 5}; - CHECK(product(v) == 120); -} + TEST_CASE("sum with condition") { + std::vector v = {1, 2, 3, 4, 5}; + auto condition = [](int x) { return x % 2 == 0; }; // Sum of even numbers only + CHECK(sum_where(v, condition) == 6); + } -TEST_CASE("product_where") { - std::vector v = {1, 2, 3, 4, 5}; - auto condition = [](int x) { - return x % 2 == 0; - }; // Product of even numbers only - CHECK(product_where(v, condition) == 8); -} + TEST_CASE("product") { + std::vector v = {1, 2, 3, 4, 5}; + CHECK(product(v) == 120); + } -TEST_CASE("contains") { - 
std::vector v = {1, 2, 3, 4, 5}; - CHECK(contains(v, 3)); - CHECK(!contains(v, 6)); -} + TEST_CASE("product_where") { + std::vector v = {1, 2, 3, 4, 5}; + auto condition = [](int x) { + return x % 2 == 0; + }; // Product of even numbers only + CHECK(product_where(v, condition) == 8); + } -TEST_CASE("contains_key") { - std::unordered_map m = { - {"one", 1}, {"two", 2}, {"three", 3}}; - CHECK(contains_key(m, "one")); - CHECK(!contains_key(m, "four")); -} + TEST_CASE("contains") { + std::vector v = {1, 2, 3, 4, 5}; + CHECK(contains(v, 3)); + CHECK(!contains(v, 6)); + } -TEST_CASE("map_keys") { - std::unordered_map m = {{1, "one"}, {2, "two"}}; - auto f = [](int x) { return x * x; }; // Mapping function - auto result = map_keys(m, f); - CHECK(result.size() == 2); - CHECK(result[1] == "one"); - CHECK(result[4] == "two"); -} + TEST_CASE("contains_key") { + std::unordered_map m = { + {"one", 1}, {"two", 2}, {"three", 3}}; + CHECK(contains_key(m, "one")); + CHECK(!contains_key(m, "four")); + } -TEST_CASE("filter_keys") { - std::unordered_map m = { - {1, "one"}, {2, "two"}, {3, "three"}}; - auto f = [](int x) { return x % 2 == 1; }; // Filtering function - std::unordered_map result = filter_keys(m, f); - std::unordered_map expected = {{1, "one"}, {3, "three"}}; - CHECK(result == expected); -} + TEST_CASE("map_keys") { + std::unordered_map m = {{1, "one"}, {2, "two"}}; + auto f = [](int x) { return x * x; }; // Mapping function + auto result = map_keys(m, f); + CHECK(result.size() == 2); + CHECK(result[1] == "one"); + CHECK(result[4] == "two"); + } -TEST_CASE("map_values") { - std::unordered_map m = {{1, "one"}, {2, "two"}}; - auto f = [](std::string const &s) { return s.size(); }; // Mapping function - std::unordered_map result = map_values(m, f); - std::unordered_map expected = {{1, 3}, {2, 3}}; - CHECK(result == expected); -} + TEST_CASE("filter_keys") { + std::unordered_map m = { + {1, "one"}, {2, "two"}, {3, "three"}}; + auto f = [](int x) { return x % 2 == 1; }; // 
Filtering function + std::unordered_map result = filter_keys(m, f); + std::unordered_map expected = {{1, "one"}, {3, "three"}}; + CHECK(result == expected); + } -TEST_CASE("keys") { - std::unordered_map m = { - {1, "one"}, {2, "two"}, {3, "three"}}; - std::unordered_set result = keys(m); - std::unordered_set expected = {3, 2, 1}; - CHECK(result == expected); -} + TEST_CASE("map_values") { + std::unordered_map m = {{1, "one"}, {2, "two"}}; + auto f = [](std::string const &s) { return s.size(); }; // Mapping function + std::unordered_map result = map_values(m, f); + std::unordered_map expected = {{1, 3}, {2, 3}}; + CHECK(result == expected); + } -TEST_CASE("values") { - std::unordered_map m = { - {1, "one"}, {2, "two"}, {3, "three"}}; - std::vector result = values(m); - std::vector expected = {"three", "two", "one"}; - CHECK(result == expected); -} + TEST_CASE("keys") { + std::unordered_map m = { + {1, "one"}, {2, "two"}, {3, "three"}}; + std::unordered_set result = keys(m); + std::unordered_set expected = {3, 2, 1}; + CHECK(result == expected); + } -// TEST_CASE("items") { -// std::unordered_map m = {{1, std::string("one")}, {2, -// std::string("two")}, {3,std::string("three")}}; -// std::cout<<"result type:"< v = {1, 2, 3, 2, 1}; - std::unordered_set result = unique(v); - std::unordered_set expected = {1, 2, 3}; - CHECK(result == expected); -} + TEST_CASE("values") { + std::unordered_map m = { + {1, "one"}, {2, "two"}, {3, "three"}}; + std::vector result = values(m); + std::vector expected = {"three", "two", "one"}; + CHECK(result == expected); + } -TEST_CASE("without_order") { - std::vector v = {1, 4, 6, 4, 6}; - std::unordered_set expected = {1, 4, 6}; - CHECK(without_order(v) == expected); -} + // TEST_CASE("items") { + // std::unordered_map m = {{1, std::string("one")}, {2, + // std::string("two")}, {3,std::string("three")}}; + // std::cout<<"result type:"< v = {1, 2, 3, 2, 1}; + std::unordered_set result = unique(v); + std::unordered_set expected = {1, 2, 3}; 
+ CHECK(result == expected); + } -TEST_CASE("index_of") { - std::vector v = {1, 2, 3, 4, 5}; - CHECK(index_of(v, 3) == 2); - CHECK(!index_of(v, 6).has_value()); -} + TEST_CASE("without_order") { + std::vector v = {1, 4, 6, 4, 6}; + std::unordered_set expected = {1, 4, 6}; + CHECK(without_order(v) == expected); + } -TEST_CASE("intersection") { - std::unordered_set l = {1, 2, 3}; - std::unordered_set r = {2, 3, 4}; - std::unordered_set result = intersection(l, r); - std::unordered_set expected = {2, 3}; - CHECK(result == expected); -} + TEST_CASE("index_of") { + std::vector v = {1, 2, 3, 4, 5}; + CHECK(index_of(v, 3) == 2); + CHECK(!index_of(v, 6).has_value()); + } -TEST_CASE("are_disjoint") { - std::unordered_set l = {1, 2, 3}; - std::unordered_set r = {4, 5, 6}; - CHECK(are_disjoint(l, r)); - r.insert(3); - CHECK_FALSE(are_disjoint(l, r)); -} + TEST_CASE("intersection") { + std::unordered_set l = {1, 2, 3}; + std::unordered_set r = {2, 3, 4}; + std::unordered_set result = intersection(l, r); + std::unordered_set expected = {2, 3}; + CHECK(result == expected); + } -TEST_CASE("restrict_keys") { - std::unordered_map m = { - {1, "one"}, {2, "two"}, {3, "three"}}; - std::unordered_set mask = {2, 3, 4}; - std::unordered_map result = restrict_keys(m, mask); - std::unordered_map expected = {{2, "two"}, {3, "three"}}; - CHECK(result == expected); -} + TEST_CASE("are_disjoint") { + std::unordered_set l = {1, 2, 3}; + std::unordered_set r = {4, 5, 6}; + CHECK(are_disjoint(l, r)); + r.insert(3); + CHECK_FALSE(are_disjoint(l, r)); + } -TEST_CASE("merge_maps(unordered_map)") { - std::unordered_map lhs = {{1, "one"}, {2, "two"}}; - std::unordered_map rhs = {{3, "three"}, {4, "four"}}; - std::unordered_map result = merge_maps(lhs, rhs); - std::unordered_map expected = { - {1, "one"}, {2, "two"}, {3, "three"}, {4, "four"}}; - CHECK(result == expected); -} + TEST_CASE("restrict_keys") { + std::unordered_map m = { + {1, "one"}, {2, "two"}, {3, "three"}}; + std::unordered_set mask = 
{2, 3, 4}; + std::unordered_map result = restrict_keys(m, mask); + std::unordered_map expected = {{2, "two"}, {3, "three"}}; + CHECK(result == expected); + } -TEST_CASE("merge_maps(bidict)") { - std::unordered_map fwd_map1 = {{1, "one"}, {2, "two"}}; - std::unordered_map bwd_map1 = {{"one", 1}, {"two", 2}}; - std::unordered_map fwd_map2 = {{3, "three"}, {4, "four"}}; - std::unordered_map bwd_map2 = {{"three", 3}, {"four", 4}}; - bidict lhs{fwd_map1, bwd_map1}; - bidict rhs{fwd_map2, bwd_map2}; - - std::unordered_map result = - merge_maps(lhs, rhs); // impicit conversion - std::unordered_map expected = { - {1, "one"}, {2, "two"}, {3, "three"}, {4, "four"}}; - CHECK(result == expected); -} + TEST_CASE("merge_maps(unordered_map)") { + std::unordered_map lhs = {{1, "one"}, {2, "two"}}; + std::unordered_map rhs = {{3, "three"}, {4, "four"}}; + std::unordered_map result = merge_maps(lhs, rhs); + std::unordered_map expected = { + {1, "one"}, {2, "two"}, {3, "three"}, {4, "four"}}; + CHECK(result == expected); + } -TEST_CASE("lookup_in") { - std::unordered_map m = { - {1, "one"}, {2, "two"}, {3, "three"}}; - auto f = lookup_in(m); - CHECK(f(1) == "one"); - CHECK(f(2) == "two"); - CHECK(f(3) == "three"); -} + TEST_CASE("merge_maps(bidict)") { + std::unordered_map fwd_map1 = {{1, "one"}, {2, "two"}}; + std::unordered_map bwd_map1 = {{"one", 1}, {"two", 2}}; + std::unordered_map fwd_map2 = {{3, "three"}, {4, "four"}}; + std::unordered_map bwd_map2 = {{"three", 3}, {"four", 4}}; + bidict lhs{fwd_map1, bwd_map1}; + bidict rhs{fwd_map2, bwd_map2}; + + std::unordered_map result = + merge_maps(lhs, rhs); // impicit conversion + std::unordered_map expected = { + {1, "one"}, {2, "two"}, {3, "three"}, {4, "four"}}; + CHECK(result == expected); + } -TEST_CASE("lookup_in_l") { - bidict m; - m.equate(1, "one"); - m.equate(2, "two"); - auto f = lookup_in_l(m); - CHECK(f(1) == "one"); - CHECK(f(2) == "two"); -} + TEST_CASE("lookup_in") { + std::unordered_map m = { + {1, "one"}, {2, 
"two"}, {3, "three"}}; + auto f = lookup_in(m); + CHECK(f(1) == "one"); + CHECK(f(2) == "two"); + CHECK(f(3) == "three"); + } -TEST_CASE("lookup_in_r") { - bidict m; - m.equate(1, "one"); - m.equate(2, "two"); - auto f = lookup_in_r(m); - CHECK(f("one") == 1); - CHECK(f("two") == 2); -} + TEST_CASE("lookup_in_l") { + bidict m; + m.equate(1, "one"); + m.equate(2, "two"); + auto f = lookup_in_l(m); + CHECK(f(1) == "one"); + CHECK(f(2) == "two"); + } -TEST_CASE("set_union") { - std::unordered_set s1 = {1, 2, 3}; - std::unordered_set s2 = {2, 3, 4}; - std::unordered_set result = set_union(s1, s2); - std::unordered_set expected = {1, 2, 3, 4}; - CHECK(result == expected); -} + TEST_CASE("lookup_in_r") { + bidict m; + m.equate(1, "one"); + m.equate(2, "two"); + auto f = lookup_in_r(m); + CHECK(f("one") == 1); + CHECK(f("two") == 2); + } -TEST_CASE("is_subseteq_of") { - std::unordered_set s1 = {1, 2}; - std::unordered_set s2 = {1, 2, 3}; - CHECK(is_subseteq_of(s1, s2) == true); - CHECK(is_subseteq_of(s2, s1) == false); - CHECK(is_subseteq_of(s1, s1) == true); - CHECK(is_subseteq_of(s2, s2) == true); -} + TEST_CASE("set_union") { + std::unordered_set s1 = {1, 2, 3}; + std::unordered_set s2 = {2, 3, 4}; + std::unordered_set result = set_union(s1, s2); + std::unordered_set expected = {1, 2, 3, 4}; + CHECK(result == expected); + } -TEST_CASE("is_superseteq_of") { - std::unordered_set s1 = {1, 2, 3}; - std::unordered_set s2 = {1, 2}; - CHECK(is_supserseteq_of(s1, s2) == true); - CHECK(is_supserseteq_of(s2, s1) == false); -} + TEST_CASE("is_subseteq_of") { + std::unordered_set s1 = {1, 2}; + std::unordered_set s2 = {1, 2, 3}; + CHECK(is_subseteq_of(s1, s2) == true); + CHECK(is_subseteq_of(s2, s1) == false); + CHECK(is_subseteq_of(s1, s1) == true); + CHECK(is_subseteq_of(s2, s2) == true); + } -TEST_CASE("get_only") { - std::unordered_set s = {42}; - CHECK(get_only(s) == 42); -} + TEST_CASE("is_superseteq_of") { + std::unordered_set s1 = {1, 2, 3}; + std::unordered_set s2 = {1, 
2}; + CHECK(is_supserseteq_of(s1, s2) == true); + CHECK(is_supserseteq_of(s2, s1) == false); + } -TEST_CASE("get_first") { - std::unordered_set s = {1, 2, 3}; - CHECK(s.count(get_first(s)) == 1); -} + TEST_CASE("get_only") { + std::unordered_set s = {42}; + CHECK(get_only(s) == 42); + } -TEST_CASE("extend") { - std::vector v = {1, 2, 3}; - std::unordered_set s = {4, 5, 6}; - extend(v, s); - CHECK(v.size() == 6); - std::vector expected = {1, 2, 3, 6, 5, 4}; - CHECK(v == expected); -} + TEST_CASE("get_first") { + std::unordered_set s = {1, 2, 3}; + CHECK(s.count(get_first(s)) == 1); + } -TEST_CASE("all_of") { - std::vector v = {2, 4, 6, 8}; - CHECK(all_of(v, [](int x) { return x % 2 == 0; }) == true); - CHECK(all_of(v, [](int x) { return x % 2 == 1; }) == false); -} + TEST_CASE("extend") { + std::vector v = {1, 2, 3}; + std::unordered_set s = {4, 5, 6}; + extend(v, s); + CHECK(v.size() == 6); + std::vector expected = {1, 2, 3, 6, 5, 4}; + CHECK(v == expected); + } -TEST_CASE("count") { - std::vector v = {1, 2, 3, 4, 5}; - CHECK(count(v, [](int x) { return x % 2 == 0; }) == 2); - CHECK(count(v, [](int x) { return x % 2 == 1; }) == 3); -} + TEST_CASE("all_of") { + std::vector v = {2, 4, 6, 8}; + CHECK(all_of(v, [](int x) { return x % 2 == 0; }) == true); + CHECK(all_of(v, [](int x) { return x % 2 == 1; }) == false); + } -TEST_CASE("are_all_same") { - std::vector v1 = {2, 2, 2, 2}; - std::vector v2 = {1, 2, 3, 4}; - CHECK(are_all_same(v1) == true); - CHECK(are_all_same(v2) == false); -} + TEST_CASE("count") { + std::vector v = {1, 2, 3, 4, 5}; + CHECK(count(v, [](int x) { return x % 2 == 0; }) == 2); + CHECK(count(v, [](int x) { return x % 2 == 1; }) == 3); + } -TEST_CASE("vector_transform") { - std::vector v = {1, 2, 3}; - auto result = vector_transform([](int x) { return x * 2; }, v); - CHECK(result == std::vector({2, 4, 6})); -} + TEST_CASE("are_all_same") { + std::vector v1 = {2, 2, 2, 2}; + std::vector v2 = {1, 2, 3, 4}; + CHECK(are_all_same(v1) == true); + 
CHECK(are_all_same(v2) == false); + } -TEST_CASE("as_vector") { - std::unordered_set s = {1, 2, 3}; - std::vector result = as_vector(s); - CHECK(result == std::vector({3, 2, 1})); -} + TEST_CASE("vector_transform") { + std::vector v = {1, 2, 3}; + auto result = vector_transform([](int x) { return x * 2; }, v); + CHECK(result == std::vector({2, 4, 6})); + } -TEST_CASE("transform_vector") { - std::vector v = {1, 2, 3}; - auto result = transform(v, [](int x) { return x * 2; }); - CHECK(result == std::vector({2, 4, 6})); -} + TEST_CASE("as_vector") { + std::unordered_set s = {1, 2, 3}; + std::vector result = as_vector(s); + CHECK(result == std::vector({3, 2, 1})); + } -TEST_CASE("transform_unordered_set") { - std::unordered_set s = {1, 2, 3}; - auto result = transform(s, [](int x) { return x * 2; }); - CHECK(result == std::unordered_set({2, 4, 6})); -} + TEST_CASE("transform_vector") { + std::vector v = {1, 2, 3}; + auto result = transform(v, [](int x) { return x * 2; }); + CHECK(result == std::vector({2, 4, 6})); + } -TEST_CASE("transform_string") { - std::string s = "abc"; - auto result = transform(s, ::toupper); - CHECK(result == "ABC"); -} + TEST_CASE("transform_unordered_set") { + std::unordered_set s = {1, 2, 3}; + auto result = transform(s, [](int x) { return x * 2; }); + CHECK(result == std::unordered_set({2, 4, 6})); + } -TEST_CASE("repeat") { - int ctr = 0; - std::vector result = repeat(5, [&] { return ctr++; }); + TEST_CASE("transform_string") { + std::string s = "abc"; + auto result = transform(s, ::toupper); + CHECK(result == "ABC"); + } - CHECK(result == std::vector{0, 1, 2, 3, 4}); -} + TEST_CASE("repeat") { + int ctr = 0; + std::vector result = repeat(5, [&] { return ctr++; }); -TEST_CASE("Testing the 'enumerate' function") { - std::unordered_set input_set = {1, 2, 3, 4, 5}; - std::unordered_map result = enumerate(input_set); - std::unordered_map expected = { - {1, 4}, {2, 3}, {3, 2}, {4, 1}, {0, 5}}; - CHECK(result == expected); -} + CHECK(result == 
std::vector{0, 1, 2, 3, 4}); + } -TEST_CASE("Testing the 'maximum' function") { - std::vector input_vec = {1, 2, 3, 4, 5}; - auto result = maximum(input_vec); + TEST_CASE("Testing the 'enumerate' function") { + std::unordered_set input_set = {1, 2, 3, 4, 5}; + std::unordered_map result = enumerate(input_set); + std::unordered_map expected = { + {1, 4}, {2, 3}, {3, 2}, {4, 1}, {0, 5}}; + CHECK(result == expected); + } - // Checking the maximum is as expected - REQUIRE(result == 5); -} + TEST_CASE("Testing the 'maximum' function") { + std::vector input_vec = {1, 2, 3, 4, 5}; + auto result = maximum(input_vec); -TEST_CASE("Testing the 'reversed' function") { - std::vector input_vec = {1, 2, 3, 4, 5}; - std::vector result = reversed(input_vec); - std::vector expected = {5, 4, 3, 2, 1}; + // Checking the maximum is as expected + REQUIRE(result == 5); + } - // Checking the reversed sequence is as expected - CHECK(result == expected); -} + TEST_CASE("Testing the 'reversed' function") { + std::vector input_vec = {1, 2, 3, 4, 5}; + std::vector result = reversed(input_vec); + std::vector expected = {5, 4, 3, 2, 1}; + + // Checking the reversed sequence is as expected + CHECK(result == expected); + } -TEST_CASE("Testing sorted_by function") { - std::unordered_set s = {5, 2, 3, 4, 1}; - auto sorted_s = sorted_by(s, [](int a, int b) { return a < b; }); - CHECK(sorted_s == std::vector({1, 2, 3, 4, 5})); + TEST_CASE("Testing sorted_by function") { + std::unordered_set s = {5, 2, 3, 4, 1}; + auto sorted_s = sorted_by(s, [](int a, int b) { return a < b; }); + CHECK(sorted_s == std::vector({1, 2, 3, 4, 5})); - std::unordered_set s2 = {-5, -1, -3, -2, -4}; - auto sorted_s2 = sorted_by(s2, [](int a, int b) { return a > b; }); - CHECK(sorted_s2 == std::vector({-1, -2, -3, -4, -5})); -} + std::unordered_set s2 = {-5, -1, -3, -2, -4}; + auto sorted_s2 = sorted_by(s2, [](int a, int b) { return a > b; }); + CHECK(sorted_s2 == std::vector({-1, -2, -3, -4, -5})); + } -TEST_CASE("Testing 
compare_by function") { - std::unordered_set s = {5, 2, 3, 4, 1}; - std::vector result = - sorted_by(s, compare_by([](int i) { return (-i); })); - CHECK(result == std::vector{5, 4, 3, 2, 1}); -} + TEST_CASE("Testing compare_by function") { + std::unordered_set s = {5, 2, 3, 4, 1}; + std::vector result = + sorted_by(s, compare_by([](int i) { return (-i); })); + CHECK(result == std::vector{5, 4, 3, 2, 1}); + } -TEST_CASE("Testing vector_split function") { - std::vector v = {1, 2, 3, 4, 5}; - auto result = vector_split(v, 2); - std::vector prefix = result.first; - std::vector postfix = result.second; - CHECK(prefix == std::vector({1, 2})); - CHECK(postfix == std::vector({3, 4, 5})); -} + TEST_CASE("Testing vector_split function") { + std::vector v = {1, 2, 3, 4, 5}; + auto result = vector_split(v, 2); + std::vector prefix = result.first; + std::vector postfix = result.second; + CHECK(prefix == std::vector({1, 2})); + CHECK(postfix == std::vector({3, 4, 5})); + } -TEST_CASE("Testing value_all function") { - std::vector> v = {1, 2, 3, 4, 5}; - auto value_all_v = value_all(v); - CHECK(value_all_v == std::vector({1, 2, 3, 4, 5})); -} + TEST_CASE("Testing value_all function") { + std::vector> v = {1, 2, 3, 4, 5}; + auto value_all_v = value_all(v); + CHECK(value_all_v == std::vector({1, 2, 3, 4, 5})); + } -TEST_CASE("Testing subvec function") { - std::vector v = {1, 2, 3, 4, 5}; - auto subvec_v = subvec(v, tl::optional(1), tl::optional(4)); + TEST_CASE("Testing subvec function") { + std::vector v = {1, 2, 3, 4, 5}; + auto subvec_v = subvec(v, tl::optional(1), tl::optional(4)); - CHECK(subvec_v == std::vector({2, 3, 4})); + CHECK(subvec_v == std::vector({2, 3, 4})); - auto subvec_v2 = subvec(v, tl::nullopt, tl::optional(3)); - CHECK(subvec_v2 == std::vector({1, 2, 3})); -} + auto subvec_v2 = subvec(v, tl::nullopt, tl::optional(3)); + CHECK(subvec_v2 == std::vector({1, 2, 3})); + } -auto get_factors = [](int x) -> std::vector { - // Returns a vector of factors of x - 
std::vector factors; - for (int i = 1; i <= x; i++) { - if (x % i == 0) { - factors.push_back(i); + auto get_factors = [](int x) -> std::vector { + // Returns a vector of factors of x + std::vector factors; + for (int i = 1; i <= x; i++) { + if (x % i == 0) { + factors.push_back(i); + } } + return factors; + }; + + // Example for vector + TEST_CASE("Test for flatmap function on vectors") { + std::vector v = {2, 3, 4, 5}; + auto result = flatmap(v, get_factors); + CHECK(result == std::vector({1, 2, 1, 3, 1, 2, 4, 1, 5})); } - return factors; -}; - -// Example for vector -TEST_CASE("Test for flatmap function on vectors") { - std::vector v = {2, 3, 4, 5}; - auto result = flatmap(v, get_factors); - CHECK(result == std::vector({1, 2, 1, 3, 1, 2, 4, 1, 5})); } diff --git a/lib/utils/test/src/test_cow_ptr.cc b/lib/utils/test/src/test_cow_ptr.cc index 62406bddec..de573d0c9b 100644 --- a/lib/utils/test/src/test_cow_ptr.cc +++ b/lib/utils/test/src/test_cow_ptr.cc @@ -22,39 +22,41 @@ struct TestObjectDerived : public TestObject { } }; -TEST_CASE("cow_ptr_t constructor") { - std::shared_ptr sp = std::make_shared(1); - cow_ptr_t p1(sp); - cow_ptr_t p2(std::make_shared(3)); - cow_ptr_t p3(TestObject(2)); - cow_ptr_t p4(p3); - cow_ptr_t p5 = p1; - CHECK(p1->x == 1); - CHECK(p2->x == 3); - CHECK(p3->x == 2); - CHECK(p4->x == p3->x); - CHECK(p5->x == p1->x); -} +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("cow_ptr_t constructor") { + std::shared_ptr sp = std::make_shared(1); + cow_ptr_t p1(sp); + cow_ptr_t p2(std::make_shared(3)); + cow_ptr_t p3(TestObject(2)); + cow_ptr_t p4(p3); + cow_ptr_t p5 = p1; + CHECK(p1->x == 1); + CHECK(p2->x == 3); + CHECK(p3->x == 2); + CHECK(p4->x == p3->x); + CHECK(p5->x == p1->x); + } -TEST_CASE("cow_ptr_t copy") { - cow_ptr_t p1(std::make_shared(1)); - cow_ptr_t p2(std::make_shared(2)); - p1 = p2; - CHECK(p1->x == p2->x); -} + TEST_CASE("cow_ptr_t copy") { + cow_ptr_t p1(std::make_shared(1)); + cow_ptr_t p2(std::make_shared(2)); + p1 = p2; + CHECK(p1->x 
== p2->x); + } -TEST_CASE("cow_ptr_t cast") { - cow_ptr_t p1(std::make_shared(1, 2)); - cow_ptr_t p2(p1); - CHECK(p2->x == 1); -} + TEST_CASE("cow_ptr_t cast") { + cow_ptr_t p1(std::make_shared(1, 2)); + cow_ptr_t p2(p1); + CHECK(p2->x == 1); + } -TEST_CASE("cow_ptr_t get_mutable") { - cow_ptr_t p1(std::make_shared(1)); - cow_ptr_t p2(p1); - p1.get_mutable()->x = 3; - CHECK(p1->x == 3); - CHECK(p2->x == 1); - p2.get_mutable()->x = 2; - CHECK(p1->x == 3); + TEST_CASE("cow_ptr_t get_mutable") { + cow_ptr_t p1(std::make_shared(1)); + cow_ptr_t p2(p1); + p1.get_mutable()->x = 3; + CHECK(p1->x == 3); + CHECK(p2->x == 1); + p2.get_mutable()->x = 2; + CHECK(p1->x == 3); + } } diff --git a/lib/utils/test/src/test_deduplicated_priority_queue.cc b/lib/utils/test/src/test_deduplicated_priority_queue.cc index a5c97fa0f8..66cfd395bc 100644 --- a/lib/utils/test/src/test_deduplicated_priority_queue.cc +++ b/lib/utils/test/src/test_deduplicated_priority_queue.cc @@ -1,34 +1,36 @@ #include "test/utils/doctest.h" #include "utils/deduplicated_priority_queue.h" -TEST_CASE("DeduplicatedPriorityQueue push and pop") { - DeduplicatedPriorityQueue queue; +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("DeduplicatedPriorityQueue push and pop") { + DeduplicatedPriorityQueue queue; - SUBCASE("Push elements") { - queue.push(5); - queue.push(2); - queue.push(7); - queue.push(2); + SUBCASE("Push elements") { + queue.push(5); + queue.push(2); + queue.push(7); + queue.push(2); - CHECK(queue.size() == 3); - CHECK(queue.top() == 7); - CHECK_FALSE(queue.empty()); - } + CHECK(queue.size() == 3); + CHECK(queue.top() == 7); + CHECK_FALSE(queue.empty()); + } - SUBCASE("Pop elements") { - queue.push(5); - queue.push(2); - queue.push(7); + SUBCASE("Pop elements") { + queue.push(5); + queue.push(2); + queue.push(7); - queue.pop(); - CHECK(queue.size() == 2); - CHECK(queue.top() == 5); + queue.pop(); + CHECK(queue.size() == 2); + CHECK(queue.top() == 5); - queue.pop(); - CHECK(queue.size() == 1); - 
CHECK(queue.top() == 2); + queue.pop(); + CHECK(queue.size() == 1); + CHECK(queue.top() == 2); - queue.pop(); - CHECK(queue.empty()); + queue.pop(); + CHECK(queue.empty()); + } } } diff --git a/lib/utils/test/src/test_disjoint_set.cc b/lib/utils/test/src/test_disjoint_set.cc index fe2c4bae33..8bcf2e533f 100644 --- a/lib/utils/test/src/test_disjoint_set.cc +++ b/lib/utils/test/src/test_disjoint_set.cc @@ -16,53 +16,55 @@ std::string generate_element(int seed) { return "Element" + std::to_string(seed); } -TEST_CASE_TEMPLATE("DisjointSetUnionAndFind", T, int, std::string) { - disjoint_set> ds; +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE_TEMPLATE("DisjointSetUnionAndFind", T, int, std::string) { + disjoint_set> ds; - SUBCASE("SingleElementSets") { - optional element = generate_element(1); - CHECK(ds.find(element) == element); + SUBCASE("SingleElementSets") { + optional element = generate_element(1); + CHECK(ds.find(element) == element); - element = generate_element(2); - CHECK(ds.find(element) == element); - } + element = generate_element(2); + CHECK(ds.find(element) == element); + } - SUBCASE("UnionAndFind") { - optional element1 = generate_element(1); - optional element2 = generate_element(2); - optional element3 = generate_element(3); - optional element4 = generate_element(4); + SUBCASE("UnionAndFind") { + optional element1 = generate_element(1); + optional element2 = generate_element(2); + optional element3 = generate_element(3); + optional element4 = generate_element(4); - ds.m_union(element1, element2); - CHECK(ds.find(element1) == ds.find(element2)); + ds.m_union(element1, element2); + CHECK(ds.find(element1) == ds.find(element2)); - ds.m_union(element3, element4); - CHECK(ds.find(element3) == ds.find(element4)); + ds.m_union(element3, element4); + CHECK(ds.find(element3) == ds.find(element4)); - ds.m_union(element1, element3); - CHECK(ds.find(element1) == ds.find(element3)); - CHECK(ds.find(element2) == ds.find(element4)); - CHECK(ds.find(element1) == 
ds.find(element2)); - CHECK(ds.find(element1) == ds.find(element4)); + ds.m_union(element1, element3); + CHECK(ds.find(element1) == ds.find(element3)); + CHECK(ds.find(element2) == ds.find(element4)); + CHECK(ds.find(element1) == ds.find(element2)); + CHECK(ds.find(element1) == ds.find(element4)); + } } -} -TEST_CASE_TEMPLATE("DisjointSetMapping", T, int, std::string) { - disjoint_set ds; - ds.m_union(1, 2); - ds.m_union(3, 4); - ds.m_union(1, 4); - ds.m_union(5, 6); + TEST_CASE_TEMPLATE("DisjointSetMapping", T, int, std::string) { + disjoint_set ds; + ds.m_union(1, 2); + ds.m_union(3, 4); + ds.m_union(1, 4); + ds.m_union(5, 6); - std::map, optional, OptionalComparator> - expectedMapping = {{1, 4}, {2, 4}, {3, 4}, {4, 4}, {5, 6}, {6, 6}}; + std::map, optional, OptionalComparator> + expectedMapping = {{1, 4}, {2, 4}, {3, 4}, {4, 4}, {5, 6}, {6, 6}}; - std::map, optional, OptionalComparator> mapping = - ds.get_mapping(); + std::map, optional, OptionalComparator> mapping = + ds.get_mapping(); - for (auto const &kv : mapping) { - CHECK( - *kv.second == - *expectedMapping[kv.first]); // Compare the values inside the optionals + for (auto const &kv : mapping) { + CHECK( + *kv.second == + *expectedMapping[kv.first]); // Compare the values inside the optionals + } } } diff --git a/lib/utils/test/src/test_dot_file.cc b/lib/utils/test/src/test_dot_file.cc index a65265afbd..ed4c32bb1c 100644 --- a/lib/utils/test/src/test_dot_file.cc +++ b/lib/utils/test/src/test_dot_file.cc @@ -2,67 +2,68 @@ #include "utils/dot_file.h" #include -TEST_CASE("DotFile") { - std::ostringstream oss; - DotFile dotFile(oss); - SUBCASE("add_node") { - dotFile.add_node("A", {{"shape", "circle"}, {"label", "Node A"}}); - dotFile.add_node("B", {{"shape", "rectangle"}, {"label", "Node B"}}); +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("DotFile") { + std::ostringstream oss; + DotFile dotFile(oss); + SUBCASE("add_node") { + dotFile.add_node("A", {{"shape", "circle"}, {"label", "Node A"}}); + 
dotFile.add_node("B", {{"shape", "rectangle"}, {"label", "Node B"}}); - dotFile.close(); + dotFile.close(); - std::string expectedOutput = R"EXPECTED_OUTPUT(digraph taskgraph { + std::string expectedOutput = R"EXPECTED_OUTPUT(digraph taskgraph { node0 [label=Node A,shape=circle]; node1 [label=Node B,shape=rectangle]; })EXPECTED_OUTPUT"; - CHECK(oss.str() == expectedOutput); - } + CHECK(oss.str() == expectedOutput); + } - SUBCASE("add_edge") { - dotFile.add_edge("A", "B"); - dotFile.add_edge("B", "C"); + SUBCASE("add_edge") { + dotFile.add_edge("A", "B"); + dotFile.add_edge("B", "C"); - dotFile.close(); + dotFile.close(); - std::string expectedOutput = R"EXPECTED_OUTPUT(digraph taskgraph { + std::string expectedOutput = R"EXPECTED_OUTPUT(digraph taskgraph { node0 -> node1; node1 -> node2; })EXPECTED_OUTPUT"; - CHECK(oss.str() == expectedOutput); - } + CHECK(oss.str() == expectedOutput); + } - SUBCASE("add_record_node") { - RecordFormatter rf; + SUBCASE("add_record_node") { + RecordFormatter rf; - rf << "Field1"; - rf << 42; - rf << "Field2"; - rf << float(3.14); + rf << "Field1"; + rf << 42; + rf << "Field2"; + rf << float(3.14); - dotFile.add_record_node("A", rf); + dotFile.add_record_node("A", rf); - dotFile.close(); + dotFile.close(); - std::string expectedOutput = - R"EXPECTED_OUTPUT(digraph taskgraph { + std::string expectedOutput = + R"EXPECTED_OUTPUT(digraph taskgraph { node0 [label="{ Field1 | 42 | Field2 | 3.140000e+00 }",shape=record]; })EXPECTED_OUTPUT"; - CHECK(oss.str() == expectedOutput); - } + CHECK(oss.str() == expectedOutput); + } - SUBCASE("add_node_to_subgraph") { - size_t subgraph1 = dotFile.add_subgraph(); - size_t subgraph2 = dotFile.add_subgraph(subgraph1); + SUBCASE("add_node_to_subgraph") { + size_t subgraph1 = dotFile.add_subgraph(); + size_t subgraph2 = dotFile.add_subgraph(subgraph1); - dotFile.add_node_to_subgraph("A", subgraph1); - dotFile.add_node_to_subgraph("B", subgraph2); + dotFile.add_node_to_subgraph("A", subgraph1); + 
dotFile.add_node_to_subgraph("B", subgraph2); - dotFile.close(); + dotFile.close(); - std::string expectedOutput = R"EXPECTED_OUTPUT(digraph taskgraph { + std::string expectedOutput = R"EXPECTED_OUTPUT(digraph taskgraph { subgraph cluster_0 { node1; node0; @@ -72,6 +73,7 @@ node1; } })EXPECTED_OUTPUT"; - CHECK(oss.str() == expectedOutput); + CHECK(oss.str() == expectedOutput); + } } } diff --git a/lib/utils/test/src/test_format.cc b/lib/utils/test/src/test_format.cc index 2f653c85af..eeed2eae81 100644 --- a/lib/utils/test/src/test_format.cc +++ b/lib/utils/test/src/test_format.cc @@ -7,32 +7,34 @@ std::string formatRecord(RecordFormatter const &formatter) { return oss.str(); } -TEST_CASE("RecordFormatter") { - RecordFormatter formatter; - SUBCASE("Appending string") { - formatter << "Hello"; - formatter << "World"; - CHECK(formatRecord(formatter) == "{ Hello | World }"); - } +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("RecordFormatter") { + RecordFormatter formatter; + SUBCASE("Appending string") { + formatter << "Hello"; + formatter << "World"; + CHECK(formatRecord(formatter) == "{ Hello | World }"); + } - SUBCASE("Appending integer and float") { - formatter << 42; - formatter << 3.14f; - CHECK(formatRecord(formatter) == "{ 42 | 3.140000e+00 }"); - } + SUBCASE("Appending integer and float") { + formatter << 42; + formatter << 3.14f; + CHECK(formatRecord(formatter) == "{ 42 | 3.140000e+00 }"); + } - SUBCASE("Appending another RecordFormatter") { - RecordFormatter subFormatter; - subFormatter << "Sub"; - subFormatter << "Formatter"; + SUBCASE("Appending another RecordFormatter") { + RecordFormatter subFormatter; + subFormatter << "Sub"; + subFormatter << "Formatter"; - RecordFormatter formatter; - formatter << "Hello"; - formatter << subFormatter; + RecordFormatter formatter; + formatter << "Hello"; + formatter << subFormatter; - std::ostringstream oss; - oss << formatter; + std::ostringstream oss; + oss << formatter; - CHECK(formatRecord(formatter) == "{ Hello | { Sub 
| Formatter } }"); + CHECK(formatRecord(formatter) == "{ Hello | { Sub | Formatter } }"); + } } } diff --git a/lib/utils/test/src/test_hash.cc b/lib/utils/test/src/test_hash.cc index f0d907b741..b38c43fe30 100644 --- a/lib/utils/test/src/test_hash.cc +++ b/lib/utils/test/src/test_hash.cc @@ -3,16 +3,18 @@ using namespace FlexFlow; -TEST_CASE("hash:unordered_map") { - std::unordered_map map1{{1, 2}}; - std::unordered_map map2{{1, 2}, {3, 4}}; +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("hash:unordered_map") { + std::unordered_map map1{{1, 2}}; + std::unordered_map map2{{1, 2}, {3, 4}}; - size_t hash1 = get_std_hash(map1); - size_t hash2 = get_std_hash(map2); + size_t hash1 = get_std_hash(map1); + size_t hash2 = get_std_hash(map2); - CHECK(hash1 != hash2); + CHECK(hash1 != hash2); - map1.insert({1, 2}); - hash1 = get_std_hash(map1); - CHECK(hash1 == hash2); + map1.insert({1, 2}); + hash1 = get_std_hash(map1); + CHECK(hash1 == hash2); + } } diff --git a/lib/utils/test/src/test_multidigraph.cc b/lib/utils/test/src/test_multidigraph.cc index 944ff0b7ca..91631f0391 100644 --- a/lib/utils/test/src/test_multidigraph.cc +++ b/lib/utils/test/src/test_multidigraph.cc @@ -5,86 +5,88 @@ using namespace FlexFlow; -TEST_CASE_TEMPLATE("MultiDiGraph implementations", T, AdjacencyMultiDiGraph) { - MultiDiGraph g = MultiDiGraph::create(); - - std::vector n = repeat(3, [&] { return g.add_node(); }); - std::vector p = repeat(3, [&] { return g.add_node_port(); }); - - std::vector e = {{n[1], p[1], n[0], p[0]}, - {n[2], p[2], n[0], p[0]}, - {n[0], p[0], n[2], p[2]}, - {n[1], p[1], n[2], p[2]}}; - for (MultiDiEdge const &edge : e) { - g.add_edge(edge); - } +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE_TEMPLATE("MultiDiGraph implementations", T, AdjacencyMultiDiGraph) { + MultiDiGraph g = MultiDiGraph::create(); + + std::vector n = repeat(3, [&] { return g.add_node(); }); + std::vector p = repeat(3, [&] { return g.add_node_port(); }); - CHECK(g.query_nodes(NodeQuery::all()) == - 
std::unordered_set{n[0], n[1], n[2]}); - - CHECK(g.query_nodes(NodeQuery{query_set{{n[0], n[2]}}}) == - std::unordered_set{n[0], n[2]}); - - CHECK(g.query_edges(MultiDiEdgeQuery::all()) == - std::unordered_set{e[0], e[1], e[2], e[3]}); - - CHECK(g.query_edges(MultiDiEdgeQuery::all().with_src_nodes({n[1]})) == - std::unordered_set{}); - CHECK(g.query_edges(MultiDiEdgeQuery::all().with_dst_nodes({n[1]})) == - std::unordered_set{e[0], e[3]}); - CHECK(g.query_edges(MultiDiEdgeQuery::all().with_src_idxs({p[1]})) == - std::unordered_set{}); - CHECK(g.query_edges(MultiDiEdgeQuery::all().with_dst_idxs({p[1]})) == - std::unordered_set{e[0], e[3]}); - CHECK(g.query_edges(MultiDiEdgeQuery::all().with_src_nodes(query_set( - {n[1], n[2]}))) == std::unordered_set{e[2], e[3]}); - CHECK(g.query_edges(MultiDiEdgeQuery::all().with_dst_nodes(query_set( - {n[0], n[2]}))) == std::unordered_set{e[1], e[2]}); - CHECK(g.query_edges(MultiDiEdgeQuery::all().with_src_idxs(query_set( - {p[1], p[2]}))) == std::unordered_set{e[2], e[3]}); - CHECK(g.query_edges(MultiDiEdgeQuery::all().with_dst_idxs(query_set( - {p[0], p[2]}))) == std::unordered_set{e[1], e[2]}); - CHECK(g.query_edges(MultiDiEdgeQuery::all() - .with_src_nodes({n[1]}) - .with_dst_nodes({n[2]}) - .with_src_idxs({p[1]}) - .with_dst_idxs({p[2]})) == - std::unordered_set{}); - CHECK(g.query_edges(MultiDiEdgeQuery::all().with_dst_idxs({p[2]})) == - std::unordered_set{e[1]}); - - SUBCASE("remove node") { - g.remove_node_unsafe(n[0]); + std::vector e = {{n[1], p[1], n[0], p[0]}, + {n[2], p[2], n[0], p[0]}, + {n[0], p[0], n[2], p[2]}, + {n[1], p[1], n[2], p[2]}}; + for (MultiDiEdge const &edge : e) { + g.add_edge(edge); + } CHECK(g.query_nodes(NodeQuery::all()) == - std::unordered_set{n[1], n[2]}); + std::unordered_set{n[0], n[1], n[2]}); + + CHECK(g.query_nodes(NodeQuery{query_set{{n[0], n[2]}}}) == + std::unordered_set{n[0], n[2]}); CHECK(g.query_edges(MultiDiEdgeQuery::all()) == - std::unordered_set{e[2], e[3]}); + 
std::unordered_set{e[0], e[1], e[2], e[3]}); - CHECK(g.query_edges(MultiDiEdgeQuery::all().with_src_nodes({n[0]})) == + CHECK(g.query_edges(MultiDiEdgeQuery::all().with_src_nodes({n[1]})) == std::unordered_set{}); + CHECK(g.query_edges(MultiDiEdgeQuery::all().with_dst_nodes({n[1]})) == + std::unordered_set{e[0], e[3]}); + CHECK(g.query_edges(MultiDiEdgeQuery::all().with_src_idxs({p[1]})) == + std::unordered_set{}); + CHECK(g.query_edges(MultiDiEdgeQuery::all().with_dst_idxs({p[1]})) == + std::unordered_set{e[0], e[3]}); + CHECK(g.query_edges(MultiDiEdgeQuery::all().with_src_nodes(query_set( + {n[1], n[2]}))) == std::unordered_set{e[2], e[3]}); + CHECK(g.query_edges(MultiDiEdgeQuery::all().with_dst_nodes(query_set( + {n[0], n[2]}))) == std::unordered_set{e[1], e[2]}); + CHECK(g.query_edges(MultiDiEdgeQuery::all().with_src_idxs(query_set( + {p[1], p[2]}))) == std::unordered_set{e[2], e[3]}); + CHECK(g.query_edges(MultiDiEdgeQuery::all().with_dst_idxs(query_set( + {p[0], p[2]}))) == std::unordered_set{e[1], e[2]}); + CHECK(g.query_edges(MultiDiEdgeQuery::all() + .with_src_nodes({n[1]}) + .with_dst_nodes({n[2]}) + .with_src_idxs({p[1]}) + .with_dst_idxs({p[2]})) == + std::unordered_set{}); + CHECK(g.query_edges(MultiDiEdgeQuery::all().with_dst_idxs({p[2]})) == + std::unordered_set{e[1]}); - CHECK(g.query_edges(MultiDiEdgeQuery::all().with_dst_nodes({n[0]})) == - std::unordered_set{e[2]}); + SUBCASE("remove node") { + g.remove_node_unsafe(n[0]); - CHECK(g.query_edges(MultiDiEdgeQuery::all().with_src_idxs({p[2]})) == - std::unordered_set{e[2], e[3]}); - CHECK(g.query_edges(MultiDiEdgeQuery::all().with_dst_idxs({p[0]})) == - std::unordered_set{e[2]}); - } + CHECK(g.query_nodes(NodeQuery::all()) == + std::unordered_set{n[1], n[2]}); - SUBCASE("remove_edge") { - g.remove_edge(e[0]); + CHECK(g.query_edges(MultiDiEdgeQuery::all()) == + std::unordered_set{e[2], e[3]}); - CHECK(g.query_edges( - MultiDiEdgeQuery::all().with_src_nodes({n[0]}).with_dst_nodes( - {n[1]})) == 
std::unordered_set{}); + CHECK(g.query_edges(MultiDiEdgeQuery::all().with_src_nodes({n[0]})) == + std::unordered_set{}); - CHECK(g.query_edges(MultiDiEdgeQuery::all().with_dst_nodes({n[2]})) == - std::unordered_set{e[1]}); + CHECK(g.query_edges(MultiDiEdgeQuery::all().with_dst_nodes({n[0]})) == + std::unordered_set{e[2]}); + + CHECK(g.query_edges(MultiDiEdgeQuery::all().with_src_idxs({p[2]})) == + std::unordered_set{e[2], e[3]}); + CHECK(g.query_edges(MultiDiEdgeQuery::all().with_dst_idxs({p[0]})) == + std::unordered_set{e[2]}); + } + + SUBCASE("remove_edge") { + g.remove_edge(e[0]); + + CHECK(g.query_edges( + MultiDiEdgeQuery::all().with_src_nodes({n[0]}).with_dst_nodes( + {n[1]})) == std::unordered_set{}); + + CHECK(g.query_edges(MultiDiEdgeQuery::all().with_dst_nodes({n[2]})) == + std::unordered_set{e[1]}); - CHECK(g.query_edges(MultiDiEdgeQuery::all().with_src_idxs({p[2]})) == - std::unordered_set{e[2], e[3]}); + CHECK(g.query_edges(MultiDiEdgeQuery::all().with_src_idxs({p[2]})) == + std::unordered_set{e[2], e[3]}); + } } } diff --git a/lib/utils/test/src/test_random_utils.cc b/lib/utils/test/src/test_random_utils.cc index dd7c320d85..88a566a198 100644 --- a/lib/utils/test/src/test_random_utils.cc +++ b/lib/utils/test/src/test_random_utils.cc @@ -14,52 +14,54 @@ void checkProbabilities(std::vector const &counts, } } -TEST_CASE("select_random") { - std::vector values = {1, 2, 3, 4, 5}; +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("select_random") { + std::vector values = {1, 2, 3, 4, 5}; - SUBCASE("Select random value") { - int result = select_random(values); + SUBCASE("Select random value") { + int result = select_random(values); - CHECK(std::find(values.begin(), values.end(), result) != values.end()); - } + CHECK(std::find(values.begin(), values.end(), result) != values.end()); + } - SUBCASE("Invalid arguments") { - std::vector weights = {0.1f, 0.3f, 0.2f}; - CHECK(select_random(values, weights) == 2); + SUBCASE("Invalid arguments") { + std::vector weights = 
{0.1f, 0.3f, 0.2f}; + CHECK(select_random(values, weights) == 2); + } } -} -TEST_CASE("select_random - Weighted Random Selection") { - SUBCASE("Test with equal weights") { - std::vector values = {1, 2, 3, 4, 5}; - std::vector weights = {1.0f, 1.0f, 1.0f, 1.0f, 1.0f}; + TEST_CASE("select_random - Weighted Random Selection") { + SUBCASE("Test with equal weights") { + std::vector values = {1, 2, 3, 4, 5}; + std::vector weights = {1.0f, 1.0f, 1.0f, 1.0f, 1.0f}; - std::vector counts(values.size(), 0); - int const numIterations = 10000; - for (int i = 0; i < numIterations; i++) { - int selected = select_random(values, weights); - counts[selected - 1]++; + std::vector counts(values.size(), 0); + int const numIterations = 10000; + for (int i = 0; i < numIterations; i++) { + int selected = select_random(values, weights); + counts[selected - 1]++; + } + + checkProbabilities(counts, numIterations, weights, values.size()); } - checkProbabilities(counts, numIterations, weights, values.size()); - } + SUBCASE("Test with different weights") { + std::vector values = {1, 2, 3, 4, 5}; + std::vector weights = {0.1f, 0.2f, 0.3f, 0.2f, 0.2f}; - SUBCASE("Test with different weights") { - std::vector values = {1, 2, 3, 4, 5}; - std::vector weights = {0.1f, 0.2f, 0.3f, 0.2f, 0.2f}; + std::vector counts(values.size(), 0); + int const numIterations = 10000; + for (int i = 0; i < numIterations; i++) { + int selected = select_random(values, weights); + counts[selected - 1]++; + } - std::vector counts(values.size(), 0); - int const numIterations = 10000; - for (int i = 0; i < numIterations; i++) { - int selected = select_random(values, weights); - counts[selected - 1]++; - } + float totalWeight = 0.0f; + for (float weight : weights) { + totalWeight += weight; + } - float totalWeight = 0.0f; - for (float weight : weights) { - totalWeight += weight; + checkProbabilities(counts, numIterations, weights, totalWeight); } - - checkProbabilities(counts, numIterations, weights, totalWeight); } } diff 
--git a/lib/utils/test/src/test_sequence.cc b/lib/utils/test/src/test_sequence.cc index 576271a858..ee72febe05 100644 --- a/lib/utils/test/src/test_sequence.cc +++ b/lib/utils/test/src/test_sequence.cc @@ -3,169 +3,171 @@ using namespace FlexFlow; -TEST_CASE("seq_head") { - SUBCASE("seq_head with non-empty sequence") { - using Seq = seq<1, 2, 3, 4>; - constexpr int result = seq_head::value; - CHECK(result == 1); +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("seq_head") { + SUBCASE("seq_head with non-empty sequence") { + using Seq = seq<1, 2, 3, 4>; + constexpr int result = seq_head::value; + CHECK(result == 1); + } + + SUBCASE("seq_head with empty sequence") { + using Seq = seq<>; + constexpr int result = seq_head::value; + CHECK(result == -1); + } } - SUBCASE("seq_head with empty sequence") { - using Seq = seq<>; - constexpr int result = seq_head::value; - CHECK(result == -1); + TEST_CASE("seq_tail") { + SUBCASE("seq_tail with non-empty sequence") { + using Seq = seq<1, 2, 3, 4>; + using ResultType = typename seq_tail::type; + using ExpectedType = seq<2, 3, 4>; + CHECK(std::is_same::value); + } + + SUBCASE("seq_tail with empty sequence") { + using Seq = seq<>; + using ResultType = typename seq_tail::type; + using ExpectedType = seq<>; + CHECK(std::is_same::value); + } } -} -TEST_CASE("seq_tail") { - SUBCASE("seq_tail with non-empty sequence") { - using Seq = seq<1, 2, 3, 4>; - using ResultType = typename seq_tail::type; - using ExpectedType = seq<2, 3, 4>; + TEST_CASE("seq_prepend") { + using ResultType = typename FlexFlow::seq_prepend<1, 2, 3>::type; + using ExpectedType = FlexFlow::seq<1, 2, 3>; CHECK(std::is_same::value); } - SUBCASE("seq_tail with empty sequence") { - using Seq = seq<>; - using ResultType = typename seq_tail::type; - using ExpectedType = seq<>; + TEST_CASE("seq_append") { + using Seq = seq<1, 2, 3>; + using ResultType = typename seq_append::type; + using ExpectedType = seq<1, 2, 3, 4>; CHECK(std::is_same::value); } -} -TEST_CASE("seq_prepend") { - 
using ResultType = typename FlexFlow::seq_prepend<1, 2, 3>::type; - using ExpectedType = FlexFlow::seq<1, 2, 3>; - CHECK(std::is_same::value); -} - -TEST_CASE("seq_append") { - using Seq = seq<1, 2, 3>; - using ResultType = typename seq_append::type; - using ExpectedType = seq<1, 2, 3, 4>; - CHECK(std::is_same::value); -} + TEST_CASE("seq_count") { + using ResultType = seq_count_t<5>; + using ExpectedType = seq<1, 2, 3, 4, 5>; + CHECK(!std::is_same::value); + } -TEST_CASE("seq_count") { - using ResultType = seq_count_t<5>; - using ExpectedType = seq<1, 2, 3, 4, 5>; - CHECK(!std::is_same::value); -} + TEST_CASE("seq_enumerate_args") { + using Args = std::tuple; + using ResultType = seq_enumerate_args_t; + using ExpectedType = seq<0, 1, 2>; + CHECK(std::is_same::value); + } -TEST_CASE("seq_enumerate_args") { - using Args = std::tuple; - using ResultType = seq_enumerate_args_t; - using ExpectedType = seq<0, 1, 2>; - CHECK(std::is_same::value); + // template + // int square(std::integral_constant) { + // return X * X; + // } + + // TEST_CASE("seq_select") { + // SUBCASE("Valid index") { + // using Seq = seq<1, 2, 3>; + // int result = seq_select(square, 1, seq<1, 2, 3>); + // CHECK(result == 4); + // } + + // SUBCASE("Invalid index") { + // using Seq = seq<1, 2, 3>; + // CHECK_THROWS_AS(seq_select(square, 3, Seq{}), std::runtime_error); + // } + // } + + // TEST_CASE("seq_get") { + // SUBCASE("Valid index") { + // using Seq = seq<1, 2, 3>; + // int result = seq_get(square, 2, Seq{}); + // CHECK(result == 9); + // } + + // SUBCASE("Invalid index") { + // using Seq = seq<1, 2, 3>; + // CHECK_THROWS_AS(seq_get(square, 3, Seq{}), std::runtime_error); + // } + // } + + // TEST_CASE("seq_get") { + // struct F { + // template + // int operator()(std::integral_constant) const { + // return X * X; + // } + // }; + + // SUBCASE("Valid index") { + // using Seq = seq<1, 2, 3>; + // int result = seq_get(F{}, 2, Seq{}); + // CHECK(result == 9); + // } + + // SUBCASE("Invalid index") 
{ + // using Seq = seq<1, 2, 3>; + // CHECK_THROWS_AS(seq_get(F{}, 3, Seq{}), std::runtime_error); + // } + // } + + // struct F { + // template + // struct type { + // using result = std::integral_constant; + // }; + // }; + + // TEST_CASE("seq_transform_type") { + // using Seq = seq<1, 2, 3>; + // using ResultType = seq_transform_type_t; + // using ExpectedType = std::tuple, + // std::integral_constant, + // std::integral_constant>; + // CHECK(std::is_same::value); + // } + + // TEST_CASE("seq_transform") { + // struct F { + // template + // int operator()(std::integral_constant) { + // return X * X; + // } + // }; + + // using Seq = seq<1, 2, 3>; + // auto result = seq_transform(F{}, Seq{}); + // std::tuple expected{1, 4, 9}; + // CHECK(result == expected); + // } + + // TEST_CASE("seq_select") { + // struct F { + // template + // tl::optional operator()(std::integral_constant) { + // if (X % 2 == 0) { + // return X; + // } else { + // return tl::nullopt; + // } + // } + // }; + + // using Seq = seq<1, 2, 3, 4, 5>; + // int result = seq_select(F{}, Seq{}); + // CHECK(result == 2); + // } + + // TEST_CASE("seq_get") { + // struct F { + // template + // int operator()(std::integral_constant) { + // return X * X; + // } + // }; + + // using Seq = seq<1, 2, 3, 4, 5>; + // int result = seq_get(F{}, 3, Seq{}); + // CHECK(result == 16); + // } } - -// template -// int square(std::integral_constant) { -// return X * X; -// } - -// TEST_CASE("seq_select") { -// SUBCASE("Valid index") { -// using Seq = seq<1, 2, 3>; -// int result = seq_select(square, 1, seq<1, 2, 3>); -// CHECK(result == 4); -// } - -// SUBCASE("Invalid index") { -// using Seq = seq<1, 2, 3>; -// CHECK_THROWS_AS(seq_select(square, 3, Seq{}), std::runtime_error); -// } -// } - -// TEST_CASE("seq_get") { -// SUBCASE("Valid index") { -// using Seq = seq<1, 2, 3>; -// int result = seq_get(square, 2, Seq{}); -// CHECK(result == 9); -// } - -// SUBCASE("Invalid index") { -// using Seq = seq<1, 2, 3>; -// 
CHECK_THROWS_AS(seq_get(square, 3, Seq{}), std::runtime_error); -// } -// } - -// TEST_CASE("seq_get") { -// struct F { -// template -// int operator()(std::integral_constant) const { -// return X * X; -// } -// }; - -// SUBCASE("Valid index") { -// using Seq = seq<1, 2, 3>; -// int result = seq_get(F{}, 2, Seq{}); -// CHECK(result == 9); -// } - -// SUBCASE("Invalid index") { -// using Seq = seq<1, 2, 3>; -// CHECK_THROWS_AS(seq_get(F{}, 3, Seq{}), std::runtime_error); -// } -// } - -// struct F { -// template -// struct type { -// using result = std::integral_constant; -// }; -// }; - -// TEST_CASE("seq_transform_type") { -// using Seq = seq<1, 2, 3>; -// using ResultType = seq_transform_type_t; -// using ExpectedType = std::tuple, -// std::integral_constant, -// std::integral_constant>; -// CHECK(std::is_same::value); -// } - -// TEST_CASE("seq_transform") { -// struct F { -// template -// int operator()(std::integral_constant) { -// return X * X; -// } -// }; - -// using Seq = seq<1, 2, 3>; -// auto result = seq_transform(F{}, Seq{}); -// std::tuple expected{1, 4, 9}; -// CHECK(result == expected); -// } - -// TEST_CASE("seq_select") { -// struct F { -// template -// tl::optional operator()(std::integral_constant) { -// if (X % 2 == 0) { -// return X; -// } else { -// return tl::nullopt; -// } -// } -// }; - -// using Seq = seq<1, 2, 3, 4, 5>; -// int result = seq_select(F{}, Seq{}); -// CHECK(result == 2); -// } - -// TEST_CASE("seq_get") { -// struct F { -// template -// int operator()(std::integral_constant) { -// return X * X; -// } -// }; - -// using Seq = seq<1, 2, 3, 4, 5>; -// int result = seq_get(F{}, 3, Seq{}); -// CHECK(result == 16); -// } diff --git a/lib/utils/test/src/test_stack_map.cc b/lib/utils/test/src/test_stack_map.cc index 11d332afa4..21c1b07d1b 100644 --- a/lib/utils/test/src/test_stack_map.cc +++ b/lib/utils/test/src/test_stack_map.cc @@ -3,48 +3,50 @@ using namespace FlexFlow; -TEST_CASE("stack_map") { - stack_map map; - // Test the [] 
operator to insert and access elements - SUBCASE("BracketOperator") { - map[1] = 10; - map[2] = 20; - - CHECK(map[1] == 10); - CHECK(map[2] == 20); - } - - // Test the insert() function - SUBCASE("Insert") { - map.insert(1, 10); - map.insert(2, 20); - - CHECK(map[1] == 10); - CHECK(map[2] == 20); - } - - // Test the at() function to access elements - SUBCASE("At") { - map[1] = 10; - map[2] = 20; - - CHECK(map.at(1) == 10); - CHECK(map.at(2) == 20); - CHECK(map.at(1) != 20); - // Test const version of at() function - stack_map const &const_map = map; - CHECK(const_map.at(1) == 10); - CHECK(const_map.at(2) == 20); - } - - // Test the begin() and end() functions for iterator - SUBCASE("Iterator") { - map[1] = 10; - map[2] = 20; - map[3] = 30; - - std::vector> expected = {{1, 10}, {2, 20}, {3, 30}}; - std::vector> actual = map; - CHECK(actual == expected); +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("stack_map") { + stack_map map; + // Test the [] operator to insert and access elements + SUBCASE("BracketOperator") { + map[1] = 10; + map[2] = 20; + + CHECK(map[1] == 10); + CHECK(map[2] == 20); + } + + // Test the insert() function + SUBCASE("Insert") { + map.insert(1, 10); + map.insert(2, 20); + + CHECK(map[1] == 10); + CHECK(map[2] == 20); + } + + // Test the at() function to access elements + SUBCASE("At") { + map[1] = 10; + map[2] = 20; + + CHECK(map.at(1) == 10); + CHECK(map.at(2) == 20); + CHECK(map.at(1) != 20); + // Test const version of at() function + stack_map const &const_map = map; + CHECK(const_map.at(1) == 10); + CHECK(const_map.at(2) == 20); + } + + // Test the begin() and end() functions for iterator + SUBCASE("Iterator") { + map[1] = 10; + map[2] = 20; + map[3] = 30; + + std::vector> expected = {{1, 10}, {2, 20}, {3, 30}}; + std::vector> actual = map; + CHECK(actual == expected); + } } } diff --git a/lib/utils/test/src/test_stack_string.cc b/lib/utils/test/src/test_stack_string.cc index 700b7d6a0f..1836e0824a 100644 --- 
a/lib/utils/test/src/test_stack_string.cc +++ b/lib/utils/test/src/test_stack_string.cc @@ -3,79 +3,81 @@ using namespace FlexFlow; -TEST_CASE_TEMPLATE("StackStringConstruction", T, char) { - constexpr std::size_t MAXSIZE = 5; - using StackString = stack_string; +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE_TEMPLATE("StackStringConstruction", T, char) { + constexpr std::size_t MAXSIZE = 5; + using StackString = stack_string; - SUBCASE("DefaultConstruction") { - StackString str; - CHECK(str.size() == 0); - CHECK(str.length() == 0); - CHECK(static_cast(str) == ""); - } + SUBCASE("DefaultConstruction") { + StackString str; + CHECK(str.size() == 0); + CHECK(str.length() == 0); + CHECK(static_cast(str) == ""); + } - SUBCASE("CStringConstruction") { - char const *cstr = "Hello"; - StackString str(cstr); - CHECK(str.size() == 5); - CHECK(str.length() == 5); - CHECK(static_cast(str) == "Hello"); - } + SUBCASE("CStringConstruction") { + char const *cstr = "Hello"; + StackString str(cstr); + CHECK(str.size() == 5); + CHECK(str.length() == 5); + CHECK(static_cast(str) == "Hello"); + } - SUBCASE("ShortCStringConstruction") { - char const *cstr = "CMU"; - StackString str(cstr); - CHECK(str.size() == 3); - CHECK(str.length() == 3); - CHECK(static_cast(str) == "CMU"); - } + SUBCASE("ShortCStringConstruction") { + char const *cstr = "CMU"; + StackString str(cstr); + CHECK(str.size() == 3); + CHECK(str.length() == 3); + CHECK(static_cast(str) == "CMU"); + } - SUBCASE("StdStringConstruction") { - std::basic_string stdStr = "World"; - StackString str(stdStr); - CHECK(str.size() == 5); - CHECK(str.length() == 5); - CHECK(static_cast(str) == "World"); + SUBCASE("StdStringConstruction") { + std::basic_string stdStr = "World"; + StackString str(stdStr); + CHECK(str.size() == 5); + CHECK(str.length() == 5); + CHECK(static_cast(str) == "World"); + } } -} -TEST_CASE_TEMPLATE("StackStringComparison", T, char) { - constexpr std::size_t MAXSIZE = 5; - using StackString = stack_string; + 
TEST_CASE_TEMPLATE("StackStringComparison", T, char) { + constexpr std::size_t MAXSIZE = 5; + using StackString = stack_string; - StackString str1{"abc"}; - StackString str2{"def"}; - StackString str3{"abc"}; + StackString str1{"abc"}; + StackString str2{"def"}; + StackString str3{"abc"}; - CHECK(str1 == str1); - CHECK(str1 == str3); - CHECK(str1 != str2); - CHECK(str2 != str3); - CHECK(str1 < str2); -} + CHECK(str1 == str1); + CHECK(str1 == str3); + CHECK(str1 != str2); + CHECK(str2 != str3); + CHECK(str1 < str2); + } -TEST_CASE_TEMPLATE("StackStringSize", T, char) { - constexpr std::size_t MAXSIZE = 5; - using StackString = stack_string; + TEST_CASE_TEMPLATE("StackStringSize", T, char) { + constexpr std::size_t MAXSIZE = 5; + using StackString = stack_string; - SUBCASE("EmptyString") { - StackString str; - CHECK(str.size() == 0); - CHECK(str.length() == 0); - } + SUBCASE("EmptyString") { + StackString str; + CHECK(str.size() == 0); + CHECK(str.length() == 0); + } - SUBCASE("NonEmptyString") { - StackString str{"Hello"}; - CHECK(str.size() == 5); - CHECK(str.length() == 5); + SUBCASE("NonEmptyString") { + StackString str{"Hello"}; + CHECK(str.size() == 5); + CHECK(str.length() == 5); + } } -} -TEST_CASE_TEMPLATE("StackStringConversion", T, char) { - constexpr std::size_t MAXSIZE = 5; - using StackString = stack_string; + TEST_CASE_TEMPLATE("StackStringConversion", T, char) { + constexpr std::size_t MAXSIZE = 5; + using StackString = stack_string; - StackString str{"Hello"}; - std::string stdStr = static_cast(str); - CHECK(stdStr == "Hello"); + StackString str{"Hello"}; + std::string stdStr = static_cast(str); + CHECK(stdStr == "Hello"); + } } diff --git a/lib/utils/test/src/test_stack_vector.cc b/lib/utils/test/src/test_stack_vector.cc index 08101527f9..6c0ecf36f3 100644 --- a/lib/utils/test/src/test_stack_vector.cc +++ b/lib/utils/test/src/test_stack_vector.cc @@ -4,74 +4,76 @@ using namespace FlexFlow; -TEST_CASE_TEMPLATE("PushBack", T, int, double, char) { - 
constexpr std::size_t MAXSIZE = 5; - using StackVector = stack_vector; - StackVector vector; - - vector.push_back(10); - std::vector res = vector; - std::vector expected = {10}; - CHECK(res == expected); - - vector.push_back(20); - expected = {10, 20}; - res = vector; - CHECK(res == expected); -} - -TEST_CASE_TEMPLATE("OperatorIndex", T, int, double, char) { - constexpr std::size_t MAXSIZE = 5; - using StackVector = stack_vector; - StackVector vector; - - vector.push_back(10); - vector.push_back(20); - vector.push_back(30); - - CHECK(vector[0] == 10); - CHECK(vector[1] == 20); - CHECK(vector[2] == 30); -} - -TEST_CASE_TEMPLATE("Size", T, int, double, char) { - constexpr std::size_t MAXSIZE = 5; - using StackVector = stack_vector; - StackVector vector; - - CHECK(vector.size() == 0); - - vector.push_back(10); - CHECK(vector.size() == 1); - - vector.push_back(20); - CHECK(vector.size() == 2); -} - -TEST_CASE_TEMPLATE("==", T, int, double, char) { - constexpr std::size_t MAXSIZE = 5; - using StackVector = stack_vector; - StackVector vector1, vector2; - - vector1.push_back(10); - vector1.push_back(15); - vector1.push_back(20); - - vector2.push_back(10); - vector2.push_back(15); - vector2.push_back(20); - - CHECK(vector1 == vector2); -} - -TEST_CASE_TEMPLATE("EmplaceBack", T, int, double, char) { - constexpr std::size_t MAXSIZE = 5; - using StackVector = stack_vector; - StackVector vector; - - vector.push_back(10); - CHECK(vector.back() == 10); - - vector.push_back(20); - CHECK(vector.back() == 20); +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE_TEMPLATE("PushBack", T, int, double, char) { + constexpr std::size_t MAXSIZE = 5; + using StackVector = stack_vector; + StackVector vector; + + vector.push_back(10); + std::vector res = vector; + std::vector expected = {10}; + CHECK(res == expected); + + vector.push_back(20); + expected = {10, 20}; + res = vector; + CHECK(res == expected); + } + + TEST_CASE_TEMPLATE("OperatorIndex", T, int, double, char) { + constexpr std::size_t 
MAXSIZE = 5; + using StackVector = stack_vector; + StackVector vector; + + vector.push_back(10); + vector.push_back(20); + vector.push_back(30); + + CHECK(vector[0] == 10); + CHECK(vector[1] == 20); + CHECK(vector[2] == 30); + } + + TEST_CASE_TEMPLATE("Size", T, int, double, char) { + constexpr std::size_t MAXSIZE = 5; + using StackVector = stack_vector; + StackVector vector; + + CHECK(vector.size() == 0); + + vector.push_back(10); + CHECK(vector.size() == 1); + + vector.push_back(20); + CHECK(vector.size() == 2); + } + + TEST_CASE_TEMPLATE("==", T, int, double, char) { + constexpr std::size_t MAXSIZE = 5; + using StackVector = stack_vector; + StackVector vector1, vector2; + + vector1.push_back(10); + vector1.push_back(15); + vector1.push_back(20); + + vector2.push_back(10); + vector2.push_back(15); + vector2.push_back(20); + + CHECK(vector1 == vector2); + } + + TEST_CASE_TEMPLATE("EmplaceBack", T, int, double, char) { + constexpr std::size_t MAXSIZE = 5; + using StackVector = stack_vector; + StackVector vector; + + vector.push_back(10); + CHECK(vector.back() == 10); + + vector.push_back(20); + CHECK(vector.back() == 20); + } } diff --git a/lib/utils/test/src/test_tuple.cc b/lib/utils/test/src/test_tuple.cc index 344a2cd0fb..31308dec2c 100644 --- a/lib/utils/test/src/test_tuple.cc +++ b/lib/utils/test/src/test_tuple.cc @@ -6,74 +6,76 @@ using namespace FlexFlow; -TEST_CASE("get function") { - std::tuple t(42, 3.14f, 2.71828); - - SUBCASE("get mutable reference") { - int &result = get(t); - CHECK(result == 42); - - result = 100; - CHECK(std::get<0>(t) == 100); +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("get function") { + std::tuple t(42, 3.14f, 2.71828); + + SUBCASE("get mutable reference") { + int &result = get(t); + CHECK(result == 42); + + result = 100; + CHECK(std::get<0>(t) == 100); + } + + SUBCASE("get rvalue reference") { + int &&result = get(std::move(t)); + CHECK(result == 42); + + // t is in a valid but unspecified state after move + CHECK(std::get<0>(t) == 
42); // Uncomment this line to check the behavior + } + + SUBCASE("get const reference") { + int const &result = get(t); + CHECK(result == 42); + } + + SUBCASE("get const rvalue reference") { + int const &&result = get(std::move(t)); + CHECK(result == 42); + } } - SUBCASE("get rvalue reference") { - int &&result = get(std::move(t)); - CHECK(result == 42); + TEST_CASE("tuple_prepend function") { + std::tuple t1(3.14f, 2.71828); + int value = 42; - // t is in a valid but unspecified state after move - CHECK(std::get<0>(t) == 42); // Uncomment this line to check the behavior + auto result = tuple_prepend(value, t1); + std::tuple expected(42, 3.14f, 2.71828); + CHECK(result == expected); } - SUBCASE("get const reference") { - int const &result = get(t); - CHECK(result == 42); + TEST_CASE("Testing tuple_head_t") { + CHECK(std::is_same>, + std::tuple>::value); + CHECK(std::is_same>, + std::tuple<>>::value); } - SUBCASE("get const rvalue reference") { - int const &&result = get(std::move(t)); - CHECK(result == 42); + TEST_CASE("Testing tuple_slice_t") { + CHECK(std::is_same>, + std::tuple>::value); + CHECK(std::is_same>, + std::tuple>::value); + CHECK(std::is_same>, + std::tuple>::value); } -} - -TEST_CASE("tuple_prepend function") { - std::tuple t1(3.14f, 2.71828); - int value = 42; - auto result = tuple_prepend(value, t1); - std::tuple expected(42, 3.14f, 2.71828); - CHECK(result == expected); -} - -TEST_CASE("Testing tuple_head_t") { - CHECK(std::is_same>, - std::tuple>::value); - CHECK(std::is_same>, - std::tuple<>>::value); -} + TEST_CASE("Testing tuple_compare function") { + std::tuple tup1{1, 3.14, 'a'}; + std::tuple tup2{1, 3.14, 'a'}; + std::tuple tup3{2, 3.14, 'b'}; -TEST_CASE("Testing tuple_slice_t") { - CHECK(std::is_same>, - std::tuple>::value); - CHECK(std::is_same>, - std::tuple>::value); - CHECK(std::is_same>, - std::tuple>::value); -} - -TEST_CASE("Testing tuple_compare function") { - std::tuple tup1{1, 3.14, 'a'}; - std::tuple tup2{1, 3.14, 'a'}; - 
std::tuple tup3{2, 3.14, 'b'}; - - CHECK(tuple_compare(tup1, tup2)); - CHECK(!tuple_compare(tup1, tup3)); -} + CHECK(tuple_compare(tup1, tup2)); + CHECK(!tuple_compare(tup1, tup3)); + } -TEST_CASE("Testing get function with valid index") { - std::tuple tup{1, 3.14, 'a'}; + TEST_CASE("Testing get function with valid index") { + std::tuple tup{1, 3.14, 'a'}; - CHECK(get(tup) == 1); - CHECK(get(tup) == 3.14); - CHECK(get(tup) == 'a'); + CHECK(get(tup) == 1); + CHECK(get(tup) == 3.14); + CHECK(get(tup) == 'a'); + } } diff --git a/lib/utils/test/src/test_type_index.cc b/lib/utils/test/src/test_type_index.cc index 1b9a811846..b2d8aea848 100644 --- a/lib/utils/test/src/test_type_index.cc +++ b/lib/utils/test/src/test_type_index.cc @@ -4,30 +4,32 @@ using namespace FlexFlow; -TEST_CASE("type_index function") { - SUBCASE("int type") { - std::type_index idx = type_index(); - std::type_index expected_idx = typeid(int); - CHECK(idx == expected_idx); - } +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("type_index function") { + SUBCASE("int type") { + std::type_index idx = type_index(); + std::type_index expected_idx = typeid(int); + CHECK(idx == expected_idx); + } - SUBCASE("string type") { - std::type_index idx = type_index(); - std::type_index expected_idx = typeid(std::string); - CHECK(idx == expected_idx); + SUBCASE("string type") { + std::type_index idx = type_index(); + std::type_index expected_idx = typeid(std::string); + CHECK(idx == expected_idx); + } } -} -TEST_CASE("matches function") { - std::type_index idx = typeid(float); + TEST_CASE("matches function") { + std::type_index idx = typeid(float); - SUBCASE("matching type") { - bool result = matches(idx); - CHECK(result == true); - } + SUBCASE("matching type") { + bool result = matches(idx); + CHECK(result == true); + } - SUBCASE("non-matching type") { - bool result = matches(idx); - CHECK(result == false); + SUBCASE("non-matching type") { + bool result = matches(idx); + CHECK(result == false); + } } } diff --git 
a/lib/utils/test/src/test_undirected_graph.cc b/lib/utils/test/src/test_undirected_graph.cc index c6f2003ee4..a60a330ad3 100644 --- a/lib/utils/test/src/test_undirected_graph.cc +++ b/lib/utils/test/src/test_undirected_graph.cc @@ -31,30 +31,32 @@ using namespace rc; /* static_assert(is_streamable::value, ""); */ /* static_assert(is_fmtable::value, ""); */ -TEST_CASE_TEMPLATE("UndirectedGraph implementations", - T, - HashmapUndirectedGraph) { - - rc::dc_check("Full", [&]() { - UndirectedGraph g = UndirectedGraph::create(); - int num_nodes = *gen::inRange(1, 10); - std::vector n = repeat(num_nodes, [&] { return g.add_node(); }); - int num_edges = *gen::inRange(0, num_nodes); - std::vector e; - if (num_nodes > 0) { - e = *gen::unique>( - num_edges, - gen::construct(gen::elementOf(n), gen::elementOf(n))); - } - for (UndirectedEdge const &edge : e) { - g.add_edge(edge); - } - - CHECK(g.query_nodes(NodeQuery::all()) == without_order(n)); - - auto subset = *rc::subset_of(n); - CHECK(g.query_nodes(NodeQuery{query_set{subset}}) == subset); - - CHECK(g.query_edges(UndirectedEdgeQuery::all()) == without_order(e)); - }); +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE_TEMPLATE("UndirectedGraph implementations", + T, + HashmapUndirectedGraph) { + + rc::dc_check("Full", [&]() { + UndirectedGraph g = UndirectedGraph::create(); + int num_nodes = *gen::inRange(1, 10); + std::vector n = repeat(num_nodes, [&] { return g.add_node(); }); + int num_edges = *gen::inRange(0, num_nodes); + std::vector e; + if (num_nodes > 0) { + e = *gen::unique>( + num_edges, + gen::construct(gen::elementOf(n), gen::elementOf(n))); + } + for (UndirectedEdge const &edge : e) { + g.add_edge(edge); + } + + CHECK(g.query_nodes(NodeQuery::all()) == without_order(n)); + + auto subset = *rc::subset_of(n); + CHECK(g.query_nodes(NodeQuery{query_set{subset}}) == subset); + + CHECK(g.query_edges(UndirectedEdgeQuery::all()) == without_order(e)); + }); + } } diff --git a/lib/utils/test/src/test_variant.cc 
b/lib/utils/test/src/test_variant.cc index 541ff40920..f7d08889de 100644 --- a/lib/utils/test/src/test_variant.cc +++ b/lib/utils/test/src/test_variant.cc @@ -1,69 +1,71 @@ #include "test/utils/doctest.h" #include "utils/variant.h" -TEST_CASE("widen and narrow functions") { - SUBCASE("widen function") { - std::variant v1 = 42; - std::variant result = - widen>(v1); - std::variant expected = 42; - CHECK(result == expected); - } +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("widen and narrow functions") { + SUBCASE("widen function") { + std::variant v1 = 42; + std::variant result = + widen>(v1); + std::variant expected = 42; + CHECK(result == expected); + } - SUBCASE("narrow function fail") { - std::variant v2 = - 3.14; // this is a doule, because 3.14 default to double - std::optional> result = - narrow>(v2); - std::optional> expected = float(3.14); - CHECK(!result.has_value()); // result should be empty due to narrowing - } + SUBCASE("narrow function fail") { + std::variant v2 = + 3.14; // this is a doule, because 3.14 default to double + std::optional> result = + narrow>(v2); + std::optional> expected = float(3.14); + CHECK(!result.has_value()); // result should be empty due to narrowing + } - SUBCASE("narrow function success") { - std::variant v2 = - 3.14; // this is a doule, because 3.14 default to double - std::optional> result = - narrow>(v2); - std::optional> expected = 3.14; - CHECK(result == expected); // - } + SUBCASE("narrow function success") { + std::variant v2 = + 3.14; // this is a doule, because 3.14 default to double + std::optional> result = + narrow>(v2); + std::optional> expected = 3.14; + CHECK(result == expected); // + } - SUBCASE("cast function") { - std::variant v3 = 42; - std::optional> result = - cast>(v3); - std::optional> expected = 42; - CHECK(result == expected); + SUBCASE("cast function") { + std::variant v3 = 42; + std::optional> result = + cast>(v3); + std::optional> expected = 42; + CHECK(result == expected); + } } -} -TEST_CASE("Narrow 
and cast variants") { - std::variant original_variant = 42; + TEST_CASE("Narrow and cast variants") { + std::variant original_variant = 42; - // narrow - std::optional> narrow_result = - narrow>(original_variant); - CHECK(narrow_result.has_value()); // assert narrow has value + // narrow + std::optional> narrow_result = + narrow>(original_variant); + CHECK(narrow_result.has_value()); // assert narrow has value - // cast - std::optional> cast_result = - cast>(narrow_result.value()); - CHECK(cast_result.has_value()); // assert cast has value - CHECK(get(cast_result.value()) == 42); -} + // cast + std::optional> cast_result = + cast>(narrow_result.value()); + CHECK(cast_result.has_value()); // assert cast has value + CHECK(get(cast_result.value()) == 42); + } -TEST_CASE("casting and widening a variant") { - std::variant smaller_variant = 42; - std::variant wider_variant; + TEST_CASE("casting and widening a variant") { + std::variant smaller_variant = 42; + std::variant wider_variant; - // Perform the cast operation - std::optional> cast_result = - cast>(smaller_variant); - REQUIRE(cast_result); // Ensure the cast was successful + // Perform the cast operation + std::optional> cast_result = + cast>(smaller_variant); + REQUIRE(cast_result); // Ensure the cast was successful - // Perform the widening operation - wider_variant = widen>(cast_result.value()); + // Perform the widening operation + wider_variant = widen>(cast_result.value()); - // Check the result - CHECK(get(wider_variant) == 42); + // Check the result + CHECK(get(wider_variant) == 42); + } } diff --git a/lib/utils/test/src/test_vector.cc b/lib/utils/test/src/test_vector.cc index 5eba16c312..4bdc724dd8 100644 --- a/lib/utils/test/src/test_vector.cc +++ b/lib/utils/test/src/test_vector.cc @@ -1,29 +1,31 @@ #include "test/utils/doctest.h" #include "utils/vector.h" -TEST_CASE("concat function") { - SUBCASE("concatenates two vectors") { - std::vector v1 = {1, 2, 3}; - std::vector v2 = {4, 5, 6}; - std::vector 
result = concat(v1, v2); - std::vector expected = {1, 2, 3, 4, 5, 6}; - CHECK(result == expected); - } +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("concat function") { + SUBCASE("concatenates two vectors") { + std::vector v1 = {1, 2, 3}; + std::vector v2 = {4, 5, 6}; + std::vector result = concat(v1, v2); + std::vector expected = {1, 2, 3, 4, 5, 6}; + CHECK(result == expected); + } - SUBCASE("concatenates two string vectors") { - std::vector v1 = {"1", "2", "3"}; - std::vector v2 = {"4", "5", "6"}; - std::vector result = concat(v1, v2); - std::vector expected = {"1", "2", "3", "4", "5", "6"}; - CHECK(result == expected); - } + SUBCASE("concatenates two string vectors") { + std::vector v1 = {"1", "2", "3"}; + std::vector v2 = {"4", "5", "6"}; + std::vector result = concat(v1, v2); + std::vector expected = {"1", "2", "3", "4", "5", "6"}; + CHECK(result == expected); + } - SUBCASE("concatenates multiple vectors") { - std::vector v1 = {1, 2, 3}; - std::vector v2 = {4, 5, 6}; - std::vector v3 = {7, 8, 9}; - std::vector result = concat(v1, v2, v3); - std::vector expected = {1, 2, 3, 4, 5, 6, 7, 8, 9}; - CHECK(result == expected); + SUBCASE("concatenates multiple vectors") { + std::vector v1 = {1, 2, 3}; + std::vector v2 = {4, 5, 6}; + std::vector v3 = {7, 8, 9}; + std::vector result = concat(v1, v2, v3); + std::vector expected = {1, 2, 3, 4, 5, 6, 7, 8, 9}; + CHECK(result == expected); + } } } From da7481790b5cfdb9a2ed4b30f0840fb7bdbef97f Mon Sep 17 00:00:00 2001 From: Colin Unger Date: Tue, 26 Mar 2024 16:39:53 -0700 Subject: [PATCH 28/37] Remove unnecessary nix files, add utils test to ci --- .flake/pkgs/fmt.nix | 73 ----------------------------- .flake/pkgs/rapidcheck.nix | 48 ------------------- .flake/pkgs/tokenizers-cpp.nix | 43 ----------------- .github/workflows/per-lib-check.yml | 4 ++ 4 files changed, 4 insertions(+), 164 deletions(-) delete mode 100644 .flake/pkgs/fmt.nix delete mode 100644 .flake/pkgs/rapidcheck.nix delete mode 100644 
.flake/pkgs/tokenizers-cpp.nix diff --git a/.flake/pkgs/fmt.nix b/.flake/pkgs/fmt.nix deleted file mode 100644 index e2677bdea2..0000000000 --- a/.flake/pkgs/fmt.nix +++ /dev/null @@ -1,73 +0,0 @@ -{ lib -, stdenv -, fetchFromGitHub, fetchpatch -, cmake -, enableShared ? !stdenv.hostPlatform.isStatic - -# tests -, mpd -, openimageio -, fcitx5 -, spdlog -}: - -let - generic = { version, sha256, patches ? [ ] }: - stdenv.mkDerivation { - pname = "fmt"; - inherit version; - - outputs = [ "out" "dev" ]; - - src = fetchFromGitHub { - owner = "fmtlib"; - repo = "fmt"; - rev = version; - inherit sha256; - }; - - inherit patches; - - nativeBuildInputs = [ cmake ]; - - cmakeFlags = [ - "-DBUILD_SHARED_LIBS=${if enableShared then "ON" else "OFF"}" - ]; - - doCheck = true; - - passthru.tests = { - inherit mpd openimageio fcitx5 spdlog; - }; - - meta = with lib; { - description = "Small, safe and fast formatting library"; - longDescription = '' - fmt (formerly cppformat) is an open-source formatting library. It can be - used as a fast and safe alternative to printf and IOStreams. 
- ''; - homepage = "https://fmt.dev/"; - changelog = "https://github.com/fmtlib/fmt/blob/${version}/ChangeLog.rst"; - downloadPage = "https://github.com/fmtlib/fmt/"; - maintainers = [ maintainers.jdehaas ]; - license = licenses.mit; - platforms = platforms.all; - }; - }; -in -{ - fmt_8 = generic { - version = "8.1.1"; - sha256 = "sha256-leb2800CwdZMJRWF5b1Y9ocK0jXpOX/nwo95icDf308="; - }; - - fmt_9 = generic { - version = "9.1.0"; - sha256 = "sha256-rP6ymyRc7LnKxUXwPpzhHOQvpJkpnRFOt2ctvUNlYI0="; - }; - - fmt_10 = generic { - version = "10.1.1"; - sha256 = "sha256-H9+1lEaHM12nzXSmo9m8S6527t+97e6necayyjCPm1A="; - }; -} diff --git a/.flake/pkgs/rapidcheck.nix b/.flake/pkgs/rapidcheck.nix deleted file mode 100644 index 3ff63207b2..0000000000 --- a/.flake/pkgs/rapidcheck.nix +++ /dev/null @@ -1,48 +0,0 @@ -{ lib -, stdenv -, fetchFromGitHub -, cmake -, unstableGitUpdater -, testers -}: - -stdenv.mkDerivation (finalAttrs: { - pname = "rapidcheck"; - version = "unstable-2023-12-14"; - - src = fetchFromGitHub { - owner = "emil-e"; - repo = "rapidcheck"; - rev = "ff6af6fc683159deb51c543b065eba14dfcf329b"; - hash = "sha256-Ixz5RpY0n8Un/Pv4XoTfbs40+70iyMbkQUjDqoLaWOg="; - }; - - nativeBuildInputs = [ cmake ]; - - cmakeFlags = [ - (lib.cmakeBool "BUILD_SHARED_LIBS" (!stdenv.hostPlatform.isStatic)) - (lib.cmakeBool "RC_INSTALL_ALL_EXTRAS" true) - ]; - - passthru = { - updateScript = unstableGitUpdater { }; - tests.pkg-config = testers.testMetaPkgConfig finalAttrs.finalPackage; - }; - - meta = with lib; { - description = "A C++ framework for property based testing inspired by QuickCheck"; - inherit (finalAttrs.src.meta) homepage; - maintainers = with maintainers; [ ]; - license = licenses.bsd2; - pkgConfigModules = [ - "rapidcheck" - # Extras - "rapidcheck_boost" - "rapidcheck_boost_test" - "rapidcheck_catch" - "rapidcheck_doctest" - "rapidcheck_gtest" - ]; - platforms = platforms.all; - }; -}) diff --git a/.flake/pkgs/tokenizers-cpp.nix b/.flake/pkgs/tokenizers-cpp.nix deleted 
file mode 100644 index a705667ae6..0000000000 --- a/.flake/pkgs/tokenizers-cpp.nix +++ /dev/null @@ -1,43 +0,0 @@ -{ lib -, stdenv -, fetchFromGitHub -, cmake -, rustc -, cargo -}: - -stdenv.mkDerivation rec { - pname = "tokenizers-cpp"; - version = "2024-03-13"; - - src = fetchFromGitHub { - owner = "mlc-ai"; - repo = "tokenizers-cpp"; - rev = "4f42c9fa74946d70af86671a3804b6f2433e5dac"; - sha256 = "sha256-p7OYx9RVnKUAuMexy3WjW2zyfMJ/Q9ss4xFLsbQK7wA="; - fetchSubmodules = true; - }; - - nativeBuildInputs = [ - cmake - rustc - ]; - - # cmakeFlags = [ - # "-DLegion_USE_Python=1" - # "-DLegion_BUILD_BINDINGS=1" - # "-DLegion_USE_CUDA=1" - # "-DLegion_CUDA_ARCH=${lib.concatStringsSep "," cudaCapabilities}" - # ]; - - buildInputs = [ ]; - # python3 - # cudatoolkit - # ]; - - meta = with lib; { - description = "Universal cross-platform tokenizers binding to HF and sentencepiece"; - homepage = "https://github.com/mlc-ai/tokenizers-cpp"; - license = licenses.asl20; - }; -} diff --git a/.github/workflows/per-lib-check.yml b/.github/workflows/per-lib-check.yml index f1d069f252..874a298587 100644 --- a/.github/workflows/per-lib-check.yml +++ b/.github/workflows/per-lib-check.yml @@ -80,6 +80,10 @@ jobs: run: | build_libs.sh compiler + - name: Test utils + run: | + test_libs.sh utils + - name: Test substitutions run: | test_libs.sh substitutions From 0db60db6e5c6a53460a209ea72f9c70bd63caccb Mon Sep 17 00:00:00 2001 From: Colin Unger Date: Tue, 26 Mar 2024 16:46:21 -0700 Subject: [PATCH 29/37] Fix utils tests name, format --- .../test/src/test_labelled_open_graph.cc | 6 +++-- lib/compiler/test/src/test_optimal_cost.cc | 3 ++- lib/compiler/test/src/test_unity_algorithm.cc | 2 +- .../test/src/test_pattern_matches.cc | 6 +++-- .../test/src/test_substitution.cc | 25 +++++++++++-------- lib/utils/test/CMakeLists.txt | 2 +- lib/utils/test/src/test_algorithms.cc | 8 +++--- lib/utils/test/src/test_containers.cc | 4 ++- lib/utils/test/src/test_disjoint_set.cc | 5 ++-- 
lib/utils/test/src/test_multidigraph.cc | 10 +++++--- lib/utils/test/src/test_undirected_graph.cc | 8 +++--- lib/utils/test/src/test_variant.cc | 3 ++- 12 files changed, 48 insertions(+), 34 deletions(-) diff --git a/lib/compiler/test/src/test_labelled_open_graph.cc b/lib/compiler/test/src/test_labelled_open_graph.cc index e3498a769a..ccad7b19ff 100644 --- a/lib/compiler/test/src/test_labelled_open_graph.cc +++ b/lib/compiler/test/src/test_labelled_open_graph.cc @@ -43,7 +43,8 @@ TEST_SUITE(FF_TEST_SUITE) { auto subgraph0 = get_subgraph(g, node_set0); auto subgraph1 = get_subgraph(g, node_set0); - auto subgraph2 = get_subgraph(g, node_set0); + auto subgraph2 = + get_subgraph(g, node_set0); auto subgraph3 = get_subgraph(g, node_set0); CHECK(bool(get_nodes(subgraph0) == node_set0)); @@ -73,7 +74,8 @@ TEST_SUITE(FF_TEST_SUITE) { split_edge(e2).second, split_edge(e3).second, e4})); CHECK(bool(get_edges(subgraph2) == std::unordered_set{e4, e5})); - CHECK(bool(get_edges(subgraph3) == std::unordered_set{e4})); + CHECK( + bool(get_edges(subgraph3) == std::unordered_set{e4})); CHECK(bool(get_closed_sources(subgraph2) == std::unordered_set{n3})); } diff --git a/lib/compiler/test/src/test_optimal_cost.cc b/lib/compiler/test/src/test_optimal_cost.cc index da303e3ccc..91c7a11888 100644 --- a/lib/compiler/test/src/test_optimal_cost.cc +++ b/lib/compiler/test/src/test_optimal_cost.cc @@ -8,7 +8,8 @@ TEST_SUITE(FF_TEST_SUITE) { // Rapidcheck infrastructures for graphs does not work for now /* Tests whether optimal_cost can give a valid result given random PCG, trivial - allowed machine views, trivial cost estimator and random machine specification. + allowed machine views, trivial cost estimator and random machine + specification. 
*/ // TEST_CASE("optimal_cost") { // auto test_allowed_machine_views = [](Operator const &, diff --git a/lib/compiler/test/src/test_unity_algorithm.cc b/lib/compiler/test/src/test_unity_algorithm.cc index b8fde91c51..614e9bb182 100644 --- a/lib/compiler/test/src/test_unity_algorithm.cc +++ b/lib/compiler/test/src/test_unity_algorithm.cc @@ -1,7 +1,7 @@ #include "compiler/unity_algorithm.h" +#include "doctest/doctest.h" #include "test_cost_estimator.h" #include "test_generator.h" -#include "doctest/doctest.h" TEST_SUITE(FF_TEST_SUITE) { // Rapidcheck does not work for now diff --git a/lib/substitutions/test/src/test_pattern_matches.cc b/lib/substitutions/test/src/test_pattern_matches.cc index f1abd5c17e..5d72bbff7e 100644 --- a/lib/substitutions/test/src/test_pattern_matches.cc +++ b/lib/substitutions/test/src/test_pattern_matches.cc @@ -102,8 +102,10 @@ TEST_SUITE(FF_TEST_SUITE) { RC_ASSERT(matches.size() == 3); for (MultiDiGraphPatternMatch const &match : matches) { - RC_ASSERT(pattern_matches( - as_openmultidigraph(sg0), as_openmultidigraph(g), match, always_true)); + RC_ASSERT(pattern_matches(as_openmultidigraph(sg0), + as_openmultidigraph(g), + match, + always_true)); } } } diff --git a/lib/substitutions/test/src/test_substitution.cc b/lib/substitutions/test/src/test_substitution.cc index 86ee087a29..df22d8a620 100644 --- a/lib/substitutions/test/src/test_substitution.cc +++ b/lib/substitutions/test/src/test_substitution.cc @@ -12,18 +12,20 @@ TEST_SUITE(FF_TEST_SUITE) { ConstraintType::EQUAL, OperatorAttributeKey::OP_TYPE, Op::LINEAR}}}; ParallelTensorPattern tensor_pattern_e0{ - std::vector{TensorAttributeConstraint{ - ConstraintType::EQUAL, - ListIndexAccess{TensorAttributeKey::DIM_SIZES, 0}, - 2}}}; + std::vector{ + TensorAttributeConstraint{ConstraintType::EQUAL, + ListIndexAccess{ + TensorAttributeKey::DIM_SIZES, 0}, + 2}}}; ParallelTensorPattern tensor_pattern_empty{ std::vector{}}; - auto ig = OutputLabelledOpenMultiDiGraph:: - create>(); + auto ig = + 
OutputLabelledOpenMultiDiGraph:: + create>(); Node n0 = ig.add_node(operator_pattern_n0); NodePort p0 = ig.add_node_port(); InputMultiDiEdge e0{n0, p0, std::make_pair(p0.value(), p0.value())}; @@ -86,7 +88,8 @@ TEST_SUITE(FF_TEST_SUITE) { SubParallelComputationGraph pcg = OutputLabelledOpenMultiDiGraph::create< - UnorderedOutputLabelledOpenMultiDiGraph>(); + UnorderedOutputLabelledOpenMultiDiGraph>(); Node n4 = pcg.add_node(Operator{InputAttrs{}, "input"}); Node n5 = pcg.add_node(Operator{ @@ -109,8 +112,8 @@ TEST_SUITE(FF_TEST_SUITE) { }, [&](OpenMultiDiEdge const &pattern_edge, OpenMultiDiEdge const &graph_edge) { - return parallel_tensor_satisfies(pcg.at(graph_edge), - input_graph.value().at(pattern_edge)); + return parallel_tensor_satisfies( + pcg.at(graph_edge), input_graph.value().at(pattern_edge)); }}; RC_ASSERT(criterion.node_criterion(n0, n5)); diff --git a/lib/utils/test/CMakeLists.txt b/lib/utils/test/CMakeLists.txt index 97253b4ab7..40ff07285e 100644 --- a/lib/utils/test/CMakeLists.txt +++ b/lib/utils/test/CMakeLists.txt @@ -1,6 +1,6 @@ ff_add_test_executable( NAME - utils-test + utils-tests SRC_PATTERNS src/test_cow_ptr.cc PRIVATE_INCLUDE diff --git a/lib/utils/test/src/test_algorithms.cc b/lib/utils/test/src/test_algorithms.cc index d3236a7b1c..0fb258bf15 100644 --- a/lib/utils/test/src/test_algorithms.cc +++ b/lib/utils/test/src/test_algorithms.cc @@ -109,7 +109,8 @@ TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("traversal") { DiGraph g = DiGraph::create(); std::vector const n = add_nodes(g, 5); - std::vector edges = {{n[0], n[1]}, {n[1], n[2]}, {n[2], n[3]}}; + std::vector edges = { + {n[0], n[1]}, {n[1], n[2]}, {n[2], n[3]}}; add_edges(g, edges); CHECK(get_sources(g) == std::unordered_set{n[0], n[4]}); @@ -138,7 +139,7 @@ TEST_SUITE(FF_TEST_SUITE) { } SUBCASE("nonlinear") { g.add_edge({n[1], n[3]}); - CHECK(is_acyclic(g) == true); // TODO, maybe a bug about the unchecked_dfs + CHECK(is_acyclic(g) == true); // TODO, maybe a bug about the unchecked_dfs } 
SUBCASE("not connected") { @@ -168,7 +169,8 @@ TEST_SUITE(FF_TEST_SUITE) { auto CHECK_BEFORE = [&](int l, int r) { CHECK(index_of(ordering, n[l]).has_value()); CHECK(index_of(ordering, n[r]).has_value()); - CHECK(index_of(ordering, n[l]).value() < index_of(ordering, n[r]).value()); + CHECK(index_of(ordering, n[l]).value() < + index_of(ordering, n[r]).value()); }; CHECK(ordering.size() == n.size()); diff --git a/lib/utils/test/src/test_containers.cc b/lib/utils/test/src/test_containers.cc index f6ac6e2d42..a6776d492e 100644 --- a/lib/utils/test/src/test_containers.cc +++ b/lib/utils/test/src/test_containers.cc @@ -30,7 +30,9 @@ TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("sum with condition") { std::vector v = {1, 2, 3, 4, 5}; - auto condition = [](int x) { return x % 2 == 0; }; // Sum of even numbers only + auto condition = [](int x) { + return x % 2 == 0; + }; // Sum of even numbers only CHECK(sum_where(v, condition) == 6); } diff --git a/lib/utils/test/src/test_disjoint_set.cc b/lib/utils/test/src/test_disjoint_set.cc index 8bcf2e533f..80fcf87d6b 100644 --- a/lib/utils/test/src/test_disjoint_set.cc +++ b/lib/utils/test/src/test_disjoint_set.cc @@ -62,9 +62,8 @@ TEST_SUITE(FF_TEST_SUITE) { ds.get_mapping(); for (auto const &kv : mapping) { - CHECK( - *kv.second == - *expectedMapping[kv.first]); // Compare the values inside the optionals + CHECK(*kv.second == *expectedMapping[kv.first]); // Compare the values + // inside the optionals } } } diff --git a/lib/utils/test/src/test_multidigraph.cc b/lib/utils/test/src/test_multidigraph.cc index 91631f0391..90e1bb2187 100644 --- a/lib/utils/test/src/test_multidigraph.cc +++ b/lib/utils/test/src/test_multidigraph.cc @@ -41,10 +41,12 @@ TEST_SUITE(FF_TEST_SUITE) { {n[1], n[2]}))) == std::unordered_set{e[2], e[3]}); CHECK(g.query_edges(MultiDiEdgeQuery::all().with_dst_nodes(query_set( {n[0], n[2]}))) == std::unordered_set{e[1], e[2]}); - CHECK(g.query_edges(MultiDiEdgeQuery::all().with_src_idxs(query_set( - {p[1], p[2]}))) == 
std::unordered_set{e[2], e[3]}); - CHECK(g.query_edges(MultiDiEdgeQuery::all().with_dst_idxs(query_set( - {p[0], p[2]}))) == std::unordered_set{e[1], e[2]}); + CHECK(g.query_edges(MultiDiEdgeQuery::all().with_src_idxs( + query_set({p[1], p[2]}))) == + std::unordered_set{e[2], e[3]}); + CHECK(g.query_edges(MultiDiEdgeQuery::all().with_dst_idxs( + query_set({p[0], p[2]}))) == + std::unordered_set{e[1], e[2]}); CHECK(g.query_edges(MultiDiEdgeQuery::all() .with_src_nodes({n[1]}) .with_dst_nodes({n[2]}) diff --git a/lib/utils/test/src/test_undirected_graph.cc b/lib/utils/test/src/test_undirected_graph.cc index a60a330ad3..3616ee59aa 100644 --- a/lib/utils/test/src/test_undirected_graph.cc +++ b/lib/utils/test/src/test_undirected_graph.cc @@ -32,9 +32,8 @@ using namespace rc; /* static_assert(is_fmtable::value, ""); */ TEST_SUITE(FF_TEST_SUITE) { - TEST_CASE_TEMPLATE("UndirectedGraph implementations", - T, - HashmapUndirectedGraph) { + TEST_CASE_TEMPLATE( + "UndirectedGraph implementations", T, HashmapUndirectedGraph) { rc::dc_check("Full", [&]() { UndirectedGraph g = UndirectedGraph::create(); @@ -45,7 +44,8 @@ TEST_SUITE(FF_TEST_SUITE) { if (num_nodes > 0) { e = *gen::unique>( num_edges, - gen::construct(gen::elementOf(n), gen::elementOf(n))); + gen::construct(gen::elementOf(n), + gen::elementOf(n))); } for (UndirectedEdge const &edge : e) { g.add_edge(edge); diff --git a/lib/utils/test/src/test_variant.cc b/lib/utils/test/src/test_variant.cc index f7d08889de..0fef782c0e 100644 --- a/lib/utils/test/src/test_variant.cc +++ b/lib/utils/test/src/test_variant.cc @@ -63,7 +63,8 @@ TEST_SUITE(FF_TEST_SUITE) { REQUIRE(cast_result); // Ensure the cast was successful // Perform the widening operation - wider_variant = widen>(cast_result.value()); + wider_variant = + widen>(cast_result.value()); // Check the result CHECK(get(wider_variant) == 42); From dac253d47e96081c9b90b14161e92eb70eb2c424 Mon Sep 17 00:00:00 2001 From: Pietro Max Marsella Date: Mon, 15 Apr 2024 01:03:55 
-0700 Subject: [PATCH 30/37] Documentation for graph library --- lib/utils/include/utils/graph/README.md | 38 +++- .../utils/graph/docs/generate_diagram.py | 99 +++++++++ .../utils/graph/docs/graph_classes.puml | 193 ++++++++++++++++++ 3 files changed, 326 insertions(+), 4 deletions(-) create mode 100644 lib/utils/include/utils/graph/docs/generate_diagram.py create mode 100644 lib/utils/include/utils/graph/docs/graph_classes.puml diff --git a/lib/utils/include/utils/graph/README.md b/lib/utils/include/utils/graph/README.md index c62b2df294..4a11669b16 100644 --- a/lib/utils/include/utils/graph/README.md +++ b/lib/utils/include/utils/graph/README.md @@ -4,8 +4,8 @@ FlexFlow's graph library very intentionally attempts to balance performance and ease of use. The graph library aims to have a very simple external interface that is highly decoupled from the underlying representations, so performance and internal implementations can be tuned and modified over time without breaking the code that uses the library. -Because FlexFlow's graphs are not on the scale of machine memory or not so large that single traversals takes nontrivial time, the graph library intentially avoids performance opportunites that would expose many of these performance aspects to user code. -Of course, there are also some optimizations that simply have not been done due to time constraints: for example, algorithms currently are able to be specialized for the underlyign representation being used, but this could be added without modifying the user-side interface. +Because FlexFlow's graphs are not on the scale of machine memory or not so large that single traversals takes nontrivial time, the graph library intentionally avoids performance opportunities that would expose many of these performance aspects to user code. 
+Of course, there are also some optimizations that simply have not been done due to time constraints: for example, algorithms currently are able to be specialized for the underlying representation being used, but this could be added without modifying the user-side interface. ## Usage @@ -17,7 +17,7 @@ At their core, they are as follows: - `UndirectedGraph`: at most one edge allowed between every pair of nodes, edges are undirected - `DirectedGraph`: at most one edge allowed between every ordered pair of nodes, edges are directed (i.e., have a source node and a destination node) -- `MultiDiGraph`: arbitrary numbers of edges allowed between every pair of nodes, but each must have not only source/destination nodes but also _source/destination` indices_, which serve to disambiguate different edges between the same nodes. There can exist at most one edge for every ordered tuple of source node, destination node, source index, and destination index. +- `MultiDiGraph`: arbitrary numbers of edges allowed between every pair of nodes, but each must have not only source/destination nodes but also _source/destination indices_, which serve to disambiguate different edges between the same nodes. There can exist at most one edge for every ordered tuple of source node, destination node, source index, and destination index. Examples of the different graph variants are shown below. @@ -149,6 +149,7 @@ To add an edge between two nodes `Node n1` and `Node n2` to an `UndirectedGraph In `UndirectedGraph` the order of the arguments of `add_edge` doesn't matter as edges are undirected, but the order does matter for `DiGraph` and `MultiDiGraph`. `MultiDiGraph::add_edge` takes in two additional arguments of type `NodePort`, specifying the source and destination indices. Similar to `Node`s, `NodePort`s can be generated via `g.add_node_port()`. +`NodePort:` an opaque object used within `MultiDiGraph` to disambiguate between multiple edges. 
`MultiDiGraph` will be able to distinguish between 2 edges that share the same source and destination as long as at least one `NodePort` differs. Within the context of a PCG, `NodePorts` must be thought of as the various inputs and outputs of a single node. The last paragraph covered the base API used to write to graphs, but we also want to be able to read from graphs. Reading from graphs is implemented with the `query_nodes` and `query_edges` methods, which can be thought of as executing a database query over the nodes and edges of the target graph, respectively (where queries are restricted to an incredibly simple set of operations). @@ -179,6 +180,12 @@ Generally users will use underlying representations provided by the graph librar [^1]: At some point we will likely add actual runtime checks on this, but for now we rely on the user not to mess up. Currently the implementation will keep going silently until the incorrectness grows so large that something breaks/crashes. [^2]: See if you're not familiar with the term _type coercion_ +### Open, Upward, Downward + +`Open` is meant in a sense similar to the topological one: that is, a graph that contains some edges where one of the 2 nodes is not present in the graph itself. +We can further specify the "openness" of a **directed** graph by specifying whether it is `UpwardOpen` (so some of the incoming edges are open) or `DownwardOpen` (so some of the outgoing edges are open). + + ### Labelled Graphs As nice as all of the above is, graphs without labels are mostly useless--in practice, nodes and edges represent some other system and the properties of that system (or at least a way to map the result of graph algorithms back to the underlying system) are necessary.
@@ -193,4 +200,27 @@ As such, the labelled graph types provide the typical `at` method (as on `std::u ## Internals -TODO @lockshaw +Most of the major graph classes in the library come in sets of 4 (example considering `ClassName`) +- `ClassName` +- `ClassNameView` +- `IClassName` +- `IClassNameView` + +The rationale behind the `View` variants has been explained in previous sections. + +The rationale for the `I(nterface)` variations is derived from the way that C++ models polymorphism. +Inheritance within the library is almost exclusively virtual: such an inheritance model is demanded by the nested inheritance structure. +In the case of a diamond inheritance pattern, C++, unlike languages such as Python, will instantiate multiple copies of the base class whenever we instantiate a derived class. +To address this issue, we employ [Virtual Inheritance](https://en.wikipedia.org/wiki/Virtual_inheritance), which removes the ambiguity associated with the multiple copies. + +Furthermore, the use of virtual functions allows for runtime polymorphism, allowing for a single function defined on some superclass to also work correctly on its subclasses. + +C++ polymorphism is normally achieved with the following pattern: + +`std::shared_ptr = new DerivedClass();` + +This pattern however leaves the burden of memory management on the user. +To address this, graph classes within the library store as a member +### strong_typedef +`Node` inherits from `strong_typedef`: this is in order to ensure that distinct types that alias the same type are still considered distinct (and thus using one in place of the other will result in a compiler error).
+For more info, see https://www.foonathan.net/2016/10/strong-typedefs/ diff --git a/lib/utils/include/utils/graph/docs/generate_diagram.py b/lib/utils/include/utils/graph/docs/generate_diagram.py new file mode 100644 index 0000000000..dfb6d6cad0 --- /dev/null +++ b/lib/utils/include/utils/graph/docs/generate_diagram.py @@ -0,0 +1,99 @@ +'''Script to generate a PlantUML graph for the inheritance / dependency hierarchy between the graph classes''' + +import subprocess +import re +from dataclasses import dataclass +from collections import defaultdict + +@dataclass +class Component: + name: str + rawstring: str + +def clean_puml(puml : bytes) -> str: + puml = puml.decode().split('\n') + puml = filter(lambda string : all(not string.strip(' \t').startswith(char) for char in '+-#'), puml) #remove info related to class members + puml = (line.strip('\t') for line in puml) + puml = '\n'.join(puml) + puml = puml.replace(" {\n}", '') + return puml + +def remove_enum(puml): + return puml.replace('\nenum LRDirection {\nLEFT\nRIGHT\n}\n', '') + + +def remove_namespace(puml): + pattern = r'namespace FlexFlow {([^}]*)}' + puml = re.sub(pattern, lambda x: x.group(1).strip(), puml, flags=re.DOTALL) + puml = puml.replace('FlexFlow.', '') + return puml + +def get_components(puml): + components = [] + for line in puml.split('\n'): + if 'class' in line: + name = re.sub(r'\b(?:class|abstract\s+class)\b ', '', line) + components.append(Component(name, line)) + return components + +def get_additional_cowptr_connections(components): + extra_connections = [] + names = {c.name for c in components} + for name in names: + if 'I'+name in names: + extra_connections.append(f'I{name} *-- {name}') + return extra_connections + +def get_connections(puml, includeaggregation=False): + pattern = '--' if includeaggregation else '<|--' + connections = [] + for line in puml.split('\n'): + if pattern in line: + connections.append(line) + return connections + +def classify_component(name): + if 
name.endswith('Query'): + return 'Query' + if 'Labelled' in name: + return 'Labelled' + if 'Node' in name: + return 'Node' + if any(pattern in name for pattern in ('Edge', 'Input', 'Output')): + return 'Edge' + if name.endswith('Graph'): + if name.endswith('MultiDiGraph'): return 'Graph.MultiDiGraph_' + if name.endswith('UndirectedGraph'): return 'Graph.UndirectedGraph_' + return 'Graph.BasicGraph' + if name.endswith('View'): + if name.endswith('MultiDiGraphView'): return 'View.MultiDiGraphView_' + if name.endswith('SubgraphView'): return 'View.SubgraphView_' + return 'View.BasicView' + return 'Other' + +if __name__=='__main__': + cmd = 'hpp2plantuml -i "../*.h"' + puml : bytes = subprocess.check_output(cmd, shell=True) + print(puml) + puml = clean_puml(puml) + puml = remove_enum(puml) + puml = remove_namespace(puml) + + components = get_components(puml) + connections = get_connections(puml) + cowptr_connections = get_additional_cowptr_connections(components) + connections += cowptr_connections + packages = defaultdict(list) + for component in components: + packages[classify_component(component.name)].append(component) + + final_puml = "" + final_puml += "@startuml\n\n" + for packagename, components in packages.items(): + component_string = '\n'.join(f'\t{c.rawstring}' for c in components) + final_puml+=f'package {packagename} {{ \n{component_string} \n}}\n\n' + + final_puml+='\n'.join(connections) + final_puml+="\n\n@enduml" + with open('output.puml', 'w') as file: + file.write(final_puml) diff --git a/lib/utils/include/utils/graph/docs/graph_classes.puml b/lib/utils/include/utils/graph/docs/graph_classes.puml new file mode 100644 index 0000000000..18825ebe45 --- /dev/null +++ b/lib/utils/include/utils/graph/docs/graph_classes.puml @@ -0,0 +1,193 @@ +@startuml + +package Graph.BasicGraph { + class AdjacencyDiGraph + class DiGraph + class Graph + abstract class IDiGraph + abstract class IGraph + class ViewUndirectedGraphAsDiGraph +} + +package Edge { + class 
AdjacencyInputEdges + class AdjacencyOutputEdges + class AddDirectedEdgesView + class DiInput + class DiOutput + class DirectedEdge + class InputMultiDiEdge + class MultiDiEdge + class MultiDiInput + class MultiDiOutput + class OutputMultiDiEdge + class UndirectedEdge +} + +package Graph.MultiDiGraph_ { + class AdjacencyMultiDiGraph + class AdjacencyOpenMultiDiGraph + class DownwardOpenMultiDiGraph + abstract class IDownwardOpenMultiDiGraph + abstract class IMultiDiGraph + abstract class IOpenMultiDiGraph + abstract class IUpwardOpenMultiDiGraph + class MultiDiGraph + class OpenMultiDiGraph + class UpwardOpenMultiDiGraph + class ViewDiGraphAsMultiDiGraph + class ViewMultiDiGraphAsOpenMultiDiGraph +} + +package Graph.UndirectedGraph_ { + class HashmapUndirectedGraph + abstract class IUndirectedGraph + class UndirectedGraph + class ViewDiGraphAsUndirectedGraph +} + +package View.BasicView { + class BFSView + class CheckedDFSView + class DiGraphView + class FlippedView + class GraphView + abstract class IDiGraphView + abstract class IGraphView + abstract class IUndirectedGraphView + class JoinedDigraphView + class JoinedMultiDigraphView + class JoinedUndirectedGraphView + class UncheckedDFSView + class UndirectedGraphView +} + +package View.SubgraphView_ { + class ClosedMultiDiSubgraphView + class DiSubgraphView + class DownwardOpenMultiDiSubgraphView + class MultiDiSubgraphView + class OpenMultiDiSubgraphView + class UndirectedSubgraphView + class UpwardOpenMultiDiSubgraphView +} + +package Node { + class ContractNodeView + class GetDstNodeFunctor + class GetSrcNodeFunctor + class JoinNodeKey + class JoinedNodeView + class Node + class NodePort + class NodeSource + class SingleSourceNodeView +} + +package Query { + class DirectedEdgeQuery + class DownwardOpenMultiDiEdgeQuery + class InputMultiDiEdgeQuery + class MultiDiEdgeQuery + class NodeQuery + class OpenMultiDiEdgeQuery + class OutputMultiDiEdgeQuery + class UndirectedEdgeQuery + class UpwardOpenMultiDiEdgeQuery 
+} + +package View.MultiDiGraphView_ { + class DownwardOpenMultiDiGraphView + abstract class IDownwardOpenMultiDiGraphView + abstract class IMultiDiGraphView + abstract class IOpenMultiDiGraphView + abstract class IUpwardOpenMultiDiGraphView + class MultiDiGraphView + class OpenMultiDiGraphView + class UpwardOpenMultiDiGraphView +} + +package Other { + class GetDstIdxFunctor + class GetSrcIdxFunctor + class Parallel + class Serial + class bfs_iterator + class checked_dfs_iterator + class cow_ptr_t > + class query_set > + class unchecked_dfs_iterator +} + +DiGraphView <|-- DiGraph +DiGraphView <|-- MultiDiGraphView +DiInput <|-- DirectedEdge +DiInput <|-- MultiDiInput +DiOutput <|-- DirectedEdge +DiOutput <|-- MultiDiOutput +DownwardOpenMultiDiGraphView <|-- DownwardOpenMultiDiGraph +GraphView <|-- DiGraphView +GraphView <|-- Graph +GraphView <|-- UndirectedGraphView +IDiGraph <|-- AdjacencyDiGraph +IDiGraphView <|-- AddDirectedEdgesView +IDiGraphView <|-- ContractNodeView +IDiGraphView <|-- DiSubgraphView +IDiGraphView <|-- FlippedView +IDiGraphView <|-- IDiGraph +IDiGraphView <|-- IMultiDiGraphView +IDiGraphView <|-- JoinedDigraphView +IDiGraphView <|-- SingleSourceNodeView +IDiGraphView <|-- ViewUndirectedGraphAsDiGraph +IDownwardOpenMultiDiGraphView <|-- IDownwardOpenMultiDiGraph +IGraphView <|-- IDiGraphView +IGraphView <|-- IGraph +IGraphView <|-- IUndirectedGraphView +IMultiDiGraph <|-- AdjacencyMultiDiGraph +IMultiDiGraphView <|-- IMultiDiGraph +IMultiDiGraphView <|-- IOpenMultiDiGraphView +IMultiDiGraphView <|-- JoinedMultiDigraphView +IMultiDiGraphView <|-- MultiDiSubgraphView +IMultiDiGraphView <|-- ViewDiGraphAsMultiDiGraph +IOpenMultiDiGraph <|-- AdjacencyOpenMultiDiGraph +IOpenMultiDiGraphView <|-- ClosedMultiDiSubgraphView +IOpenMultiDiGraphView <|-- DownwardOpenMultiDiSubgraphView +IOpenMultiDiGraphView <|-- IDownwardOpenMultiDiGraphView +IOpenMultiDiGraphView <|-- IOpenMultiDiGraph +IOpenMultiDiGraphView <|-- IUpwardOpenMultiDiGraphView 
+IOpenMultiDiGraphView <|-- OpenMultiDiSubgraphView +IOpenMultiDiGraphView <|-- UpwardOpenMultiDiSubgraphView +IOpenMultiDiGraphView <|-- ViewMultiDiGraphAsOpenMultiDiGraph +IUndirectedGraph <|-- HashmapUndirectedGraph +IUndirectedGraphView <|-- IUndirectedGraph +IUndirectedGraphView <|-- JoinedUndirectedGraphView +IUndirectedGraphView <|-- UndirectedSubgraphView +IUndirectedGraphView <|-- ViewDiGraphAsUndirectedGraph +IUpwardOpenMultiDiGraphView <|-- IUpwardOpenMultiDiGraph +MultiDiGraphView <|-- DownwardOpenMultiDiGraphView +MultiDiGraphView <|-- MultiDiGraph +MultiDiGraphView <|-- OpenMultiDiGraphView +MultiDiGraphView <|-- UpwardOpenMultiDiGraphView +MultiDiInput <|-- InputMultiDiEdge +MultiDiInput <|-- MultiDiEdge +MultiDiOutput <|-- MultiDiEdge +MultiDiOutput <|-- OutputMultiDiEdge +OpenMultiDiGraphView <|-- OpenMultiDiGraph +UndirectedGraphView <|-- UndirectedGraph +UpwardOpenMultiDiGraphView <|-- UpwardOpenMultiDiGraph +IDiGraph *-- DiGraph +IOpenMultiDiGraph *-- OpenMultiDiGraph +IUpwardOpenMultiDiGraphView *-- UpwardOpenMultiDiGraphView +IDownwardOpenMultiDiGraphView *-- DownwardOpenMultiDiGraphView +IUndirectedGraphView *-- UndirectedGraphView +IMultiDiGraphView *-- MultiDiGraphView +IUpwardOpenMultiDiGraph *-- UpwardOpenMultiDiGraph +IMultiDiGraph *-- MultiDiGraph +IGraphView *-- GraphView +IDownwardOpenMultiDiGraph *-- DownwardOpenMultiDiGraph +IGraph *-- Graph +IUndirectedGraph *-- UndirectedGraph +IDiGraphView *-- DiGraphView +IOpenMultiDiGraphView *-- OpenMultiDiGraphView + +@enduml \ No newline at end of file From 7a1213b105a07b9f5204003d80d420a62e233f9b Mon Sep 17 00:00:00 2001 From: Pietro Max Marsella Date: Mon, 15 Apr 2024 01:14:55 -0700 Subject: [PATCH 31/37] Minor changes --- lib/utils/include/utils/graph/README.md | 2 +- lib/utils/include/utils/graph/docs/generate_diagram.py | 6 ++++-- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/lib/utils/include/utils/graph/README.md b/lib/utils/include/utils/graph/README.md index 
4a11669b16..79550b2c46 100644 --- a/lib/utils/include/utils/graph/README.md +++ b/lib/utils/include/utils/graph/README.md @@ -220,7 +220,7 @@ C++ polymorphism is normally achieved with the following pattern: `std::shared_ptr = new DerivedClass();` This pattern however leaves the burden of memory management on the user. -To address this, graph classes within the library store as a member +To address this, graph classes have a cow_ptr as a member (with type equal to their corresponding Interface class), to which all function calls are delegated. ### strong_typedef `Node` inherits from `strong_typedef`: this is in order to ensure that distinct types that alias the same type are still considered distinct (and thus using one in place of the other will result in a compiler error). For more info, see https://www.foonathan.net/2016/10/strong-typedefs/ diff --git a/lib/utils/include/utils/graph/docs/generate_diagram.py b/lib/utils/include/utils/graph/docs/generate_diagram.py index dfb6d6cad0..9ef5f946b6 100644 --- a/lib/utils/include/utils/graph/docs/generate_diagram.py +++ b/lib/utils/include/utils/graph/docs/generate_diagram.py @@ -1,4 +1,6 @@ -'''Script to generate a PlantUML graph for the inheritance / dependency hierarchy between the graph classes''' +''' +Script to generate a PlantUML graph for the inheritance / dependency hierarchy between the graph classes +''' import subprocess import re @@ -95,5 +97,5 @@ def classify_component(name): final_puml+='\n'.join(connections) final_puml+="\n\n@enduml" - with open('output.puml', 'w') as file: + with open('graph_diagram.puml', 'w') as file: file.write(final_puml) From fcf2d02115963d143c02e7b449d7b0ebfa0745ca Mon Sep 17 00:00:00 2001 From: Pietro Max Marsella Date: Thu, 18 Apr 2024 21:19:32 -0700 Subject: [PATCH 32/37] Updated diagram generator script for graph docs --- .../utils/graph/docs/generate_diagram.py | 80 ++++++++++++------- 1 file changed, 51 insertions(+), 29 deletions(-) diff --git 
a/lib/utils/include/utils/graph/docs/generate_diagram.py b/lib/utils/include/utils/graph/docs/generate_diagram.py index 9ef5f946b6..ad9e5aeb58 100644 --- a/lib/utils/include/utils/graph/docs/generate_diagram.py +++ b/lib/utils/include/utils/graph/docs/generate_diagram.py @@ -1,6 +1,4 @@ -''' -Script to generate a PlantUML graph for the inheritance / dependency hierarchy between the graph classes -''' +'''Script to generate a PlantUML graph for the inheritance / dependency hierarchy between the graph classes''' import subprocess import re @@ -18,6 +16,7 @@ def clean_puml(puml : bytes) -> str: puml = (line.strip('\t') for line in puml) puml = '\n'.join(puml) puml = puml.replace(" {\n}", '') + puml = re.sub(r' <.*?<.*?>>', '', puml) #remove the templates return puml def remove_enum(puml): @@ -54,29 +53,50 @@ def get_connections(puml, includeaggregation=False): connections.append(line) return connections -def classify_component(name): - if name.endswith('Query'): - return 'Query' - if 'Labelled' in name: - return 'Labelled' - if 'Node' in name: - return 'Node' - if any(pattern in name for pattern in ('Edge', 'Input', 'Output')): - return 'Edge' - if name.endswith('Graph'): - if name.endswith('MultiDiGraph'): return 'Graph.MultiDiGraph_' - if name.endswith('UndirectedGraph'): return 'Graph.UndirectedGraph_' - return 'Graph.BasicGraph' - if name.endswith('View'): - if name.endswith('MultiDiGraphView'): return 'View.MultiDiGraphView_' - if name.endswith('SubgraphView'): return 'View.SubgraphView_' - return 'View.BasicView' - return 'Other' +def filter_by_groups(groups, components): + component_classifications = defaultdict(list) + filtered_components = [] + for component in components: + for packagename in groups: + filtering_func = GROUPS[packagename] + if filtering_func(component.name): + component_classifications[packagename].append(component) + filtered_components.append(component) + break + return component_classifications, filtered_components + + +def 
filter_connections(connections, components): + filtered_connections = [] + component_names = {comp.name for comp in components} + for conn in connections: + parent, _, child = conn.split(' ') + if parent in component_names and child in component_names: + filtered_connections.append(conn) + return filtered_connections if __name__=='__main__': - cmd = 'hpp2plantuml -i "../*.h"' + cmd = 'hpp2plantuml -i "../labelled/*.h"' puml : bytes = subprocess.check_output(cmd, shell=True) - print(puml) + + GROUPS = { + 'Graph' : lambda comp : 'Graph' in comp, + 'Edges' : lambda comp : any(comp.endswith(pattern) for pattern in ('Input', 'Output', 'Edge')), + 'Open' : lambda comp : 'Open' in comp and 'Query' not in comp, # doesn't include Upwards or Downwards + 'Open.Upward' : lambda comp : 'Upward' in comp and 'Query' not in comp, + 'Open.Downward' : lambda comp : 'Downward' in comp and 'Query' not in comp, + 'DiGraphs.MultiDiGraphs' : lambda comp : 'MultiDiGraph' in comp, + 'DiGraphs' : lambda comp : 'DiGraph' in comp, + 'Undirected' : lambda comp : 'UndirectedGraph' in comp, + + 'Labelled' : lambda comp : 'Labelled' in comp, + 'Labelled.NodeLabelled' : lambda comp : 'NodeLabelled' in comp, + 'Labelled.OutputLabelled' : lambda comp : 'OutputLabelled' in comp + } + + selected_groups = ('Labelled','Labelled.NodeLabelled','Labelled.OutputLabelled') + selected_groups = sorted(selected_groups, reverse=True) #to ensure that classification for subcategories is given precedence + puml = clean_puml(puml) puml = remove_enum(puml) puml = remove_namespace(puml) @@ -85,17 +105,19 @@ def classify_component(name): connections = get_connections(puml) cowptr_connections = get_additional_cowptr_connections(components) connections += cowptr_connections - packages = defaultdict(list) - for component in components: - packages[classify_component(component.name)].append(component) + + packageclassification, components = filter_by_groups(selected_groups, components) + connections = 
filter_connections(connections, components) final_puml = "" - final_puml += "@startuml\n\n" - for packagename, components in packages.items(): + final_puml += "@startuml\nleft to right direction\n\n" + + for packagename, components in packageclassification.items(): component_string = '\n'.join(f'\t{c.rawstring}' for c in components) final_puml+=f'package {packagename} {{ \n{component_string} \n}}\n\n' final_puml+='\n'.join(connections) final_puml+="\n\n@enduml" - with open('graph_diagram.puml', 'w') as file: + print(final_puml) + with open('output.puml', 'w') as file: file.write(final_puml) From be171cd5c2dd41751e5428a9482adf7adaa31689 Mon Sep 17 00:00:00 2001 From: Pietro Max Marsella Date: Thu, 18 Apr 2024 21:20:43 -0700 Subject: [PATCH 33/37] Added svg files for graph documentation --- lib/utils/include/utils/graph/docs/edges.svg | 1 + lib/utils/include/utils/graph/docs/labelled.svg | 1 + lib/utils/include/utils/graph/docs/open.svg | 1 + lib/utils/include/utils/graph/docs/undirected.svg | 1 + 4 files changed, 4 insertions(+) create mode 100644 lib/utils/include/utils/graph/docs/edges.svg create mode 100644 lib/utils/include/utils/graph/docs/labelled.svg create mode 100644 lib/utils/include/utils/graph/docs/open.svg create mode 100644 lib/utils/include/utils/graph/docs/undirected.svg diff --git a/lib/utils/include/utils/graph/docs/edges.svg b/lib/utils/include/utils/graph/docs/edges.svg new file mode 100644 index 0000000000..0e01479dc2 --- /dev/null +++ b/lib/utils/include/utils/graph/docs/edges.svg @@ -0,0 +1 @@ +EdgesDiInputDiOutputDirectedEdgeInputMultiDiEdgeMultiDiEdgeMultiDiInputMultiDiOutputOutputMultiDiEdgeUndirectedEdge \ No newline at end of file diff --git a/lib/utils/include/utils/graph/docs/labelled.svg b/lib/utils/include/utils/graph/docs/labelled.svg new file mode 100644 index 0000000000..a439c85c04 --- /dev/null +++ b/lib/utils/include/utils/graph/docs/labelled.svg @@ -0,0 +1 @@ 
+LabelledNodeLabelledOutputLabelledILabelledMultiDiGraphILabelledMultiDiGraphViewLabelledMultiDiGraphLabelledMultiDiGraphViewLabelledMultiDiSubgraphViewINodeLabelledMultiDiGraphINodeLabelledMultiDiGraphViewINodeLabelledOpenMultiDiGraphINodeLabelledOpenMultiDiGraphViewNodeLabelledMultiDiGraphNodeLabelledMultiDiGraphViewNodeLabelledMultiDiSubgraphViewNodeLabelledOpenMultiDiGraphNodeLabelledOpenMultiDiGraphViewUnorderedNodeLabelledOpenMultiDiGraphIOutputLabelledMultiDiGraphIOutputLabelledMultiDiGraphViewIOutputLabelledOpenMultiDiGraphIOutputLabelledOpenMultiDiGraphViewOutputLabelledMultiDiGraphOutputLabelledMultiDiGraphViewOutputLabelledOpenMultiDiGraphOutputLabelledOpenMultiDiGraphViewOutputLabelledOpenMultiDiSubgraphViewUnorderedOutputLabelledMultiDiGraphUnorderedOutputLabelledOpenMultiDiGraphViewMultiDiGraphAsOutputLabelledViewOutputLabelledAsOutputLabelledOpen \ No newline at end of file diff --git a/lib/utils/include/utils/graph/docs/open.svg b/lib/utils/include/utils/graph/docs/open.svg new file mode 100644 index 0000000000..87766063f4 --- /dev/null +++ b/lib/utils/include/utils/graph/docs/open.svg @@ -0,0 +1 @@ +OpenDownwardUpwardAdjacencyOpenMultiDiGraphIOpenMultiDiGraphIOpenMultiDiGraphViewOpenMultiDiGraphOpenMultiDiGraphViewOpenMultiDiSubgraphViewViewMultiDiGraphAsOpenMultiDiGraphDownwardOpenMultiDiGraphDownwardOpenMultiDiGraphViewDownwardOpenMultiDiSubgraphViewIDownwardOpenMultiDiGraphIDownwardOpenMultiDiGraphViewIUpwardOpenMultiDiGraphIUpwardOpenMultiDiGraphViewUpwardOpenMultiDiGraphUpwardOpenMultiDiGraphViewUpwardOpenMultiDiSubgraphView \ No newline at end of file diff --git a/lib/utils/include/utils/graph/docs/undirected.svg b/lib/utils/include/utils/graph/docs/undirected.svg new file mode 100644 index 0000000000..f04d893a45 --- /dev/null +++ b/lib/utils/include/utils/graph/docs/undirected.svg @@ -0,0 +1 @@ 
+UndirectedHashmapUndirectedGraphIUndirectedGraphIUndirectedGraphViewJoinedUndirectedGraphViewUndirectedGraphUndirectedGraphViewViewDiGraphAsUndirectedGraphViewUndirectedGraphAsDiGraph \ No newline at end of file From 36448540957e5fc88cfd5161c21a369108c4d0b3 Mon Sep 17 00:00:00 2001 From: Pietro Max Marsella Date: Thu, 18 Apr 2024 21:34:39 -0700 Subject: [PATCH 34/37] Docs changes --- lib/utils/include/utils/graph/README.md | 6 + .../utils/graph/docs/graph_classes.puml | 193 ------------------ 2 files changed, 6 insertions(+), 193 deletions(-) delete mode 100644 lib/utils/include/utils/graph/docs/graph_classes.puml diff --git a/lib/utils/include/utils/graph/README.md b/lib/utils/include/utils/graph/README.md index 79550b2c46..28493e4182 100644 --- a/lib/utils/include/utils/graph/README.md +++ b/lib/utils/include/utils/graph/README.md @@ -185,6 +185,10 @@ Generally users will use underlying representations provided by the graph librar `Open` is to be intended similarly to the topological sense: that is, a graph that contains some edges where one of the 2 nodes is not present in the graph itself. We can further specify the "openeness" of a **directed** graph by specifying whether they are `UpwardOpen` (so some of the incoming edges are open) or `DownwardOpen` (so some of the outgoing edges are open). +[Open graphs inheritance diagram](docs/open.svg) + +Arrows with pointed tips indicate inheritance, while arrows with square tips indicate that the pointing class has a 'cow_ptr' of the type of the pointed class. (for more info on `cow_ptr`, see below). + ### Labelled Graphs @@ -198,6 +202,8 @@ As such, the labelled graph types provide the typical `at` method (as on `std::u [^3]: `operator[]` currently is not present because all nodes must have labels and we don't require label types to be default constructible, though some simple template programming could probably add `operator[]` support in the cases where the label types _are_ default constructible. 
+[Labelled Graphs Inheritance Diagram] + ## Internals Most of the major graph classes in the library come in sets of 4 (example considering `ClassName`) diff --git a/lib/utils/include/utils/graph/docs/graph_classes.puml b/lib/utils/include/utils/graph/docs/graph_classes.puml deleted file mode 100644 index 18825ebe45..0000000000 --- a/lib/utils/include/utils/graph/docs/graph_classes.puml +++ /dev/null @@ -1,193 +0,0 @@ -@startuml - -package Graph.BasicGraph { - class AdjacencyDiGraph - class DiGraph - class Graph - abstract class IDiGraph - abstract class IGraph - class ViewUndirectedGraphAsDiGraph -} - -package Edge { - class AdjacencyInputEdges - class AdjacencyOutputEdges - class AddDirectedEdgesView - class DiInput - class DiOutput - class DirectedEdge - class InputMultiDiEdge - class MultiDiEdge - class MultiDiInput - class MultiDiOutput - class OutputMultiDiEdge - class UndirectedEdge -} - -package Graph.MultiDiGraph_ { - class AdjacencyMultiDiGraph - class AdjacencyOpenMultiDiGraph - class DownwardOpenMultiDiGraph - abstract class IDownwardOpenMultiDiGraph - abstract class IMultiDiGraph - abstract class IOpenMultiDiGraph - abstract class IUpwardOpenMultiDiGraph - class MultiDiGraph - class OpenMultiDiGraph - class UpwardOpenMultiDiGraph - class ViewDiGraphAsMultiDiGraph - class ViewMultiDiGraphAsOpenMultiDiGraph -} - -package Graph.UndirectedGraph_ { - class HashmapUndirectedGraph - abstract class IUndirectedGraph - class UndirectedGraph - class ViewDiGraphAsUndirectedGraph -} - -package View.BasicView { - class BFSView - class CheckedDFSView - class DiGraphView - class FlippedView - class GraphView - abstract class IDiGraphView - abstract class IGraphView - abstract class IUndirectedGraphView - class JoinedDigraphView - class JoinedMultiDigraphView - class JoinedUndirectedGraphView - class UncheckedDFSView - class UndirectedGraphView -} - -package View.SubgraphView_ { - class ClosedMultiDiSubgraphView - class DiSubgraphView - class 
DownwardOpenMultiDiSubgraphView - class MultiDiSubgraphView - class OpenMultiDiSubgraphView - class UndirectedSubgraphView - class UpwardOpenMultiDiSubgraphView -} - -package Node { - class ContractNodeView - class GetDstNodeFunctor - class GetSrcNodeFunctor - class JoinNodeKey - class JoinedNodeView - class Node - class NodePort - class NodeSource - class SingleSourceNodeView -} - -package Query { - class DirectedEdgeQuery - class DownwardOpenMultiDiEdgeQuery - class InputMultiDiEdgeQuery - class MultiDiEdgeQuery - class NodeQuery - class OpenMultiDiEdgeQuery - class OutputMultiDiEdgeQuery - class UndirectedEdgeQuery - class UpwardOpenMultiDiEdgeQuery -} - -package View.MultiDiGraphView_ { - class DownwardOpenMultiDiGraphView - abstract class IDownwardOpenMultiDiGraphView - abstract class IMultiDiGraphView - abstract class IOpenMultiDiGraphView - abstract class IUpwardOpenMultiDiGraphView - class MultiDiGraphView - class OpenMultiDiGraphView - class UpwardOpenMultiDiGraphView -} - -package Other { - class GetDstIdxFunctor - class GetSrcIdxFunctor - class Parallel - class Serial - class bfs_iterator - class checked_dfs_iterator - class cow_ptr_t > - class query_set > - class unchecked_dfs_iterator -} - -DiGraphView <|-- DiGraph -DiGraphView <|-- MultiDiGraphView -DiInput <|-- DirectedEdge -DiInput <|-- MultiDiInput -DiOutput <|-- DirectedEdge -DiOutput <|-- MultiDiOutput -DownwardOpenMultiDiGraphView <|-- DownwardOpenMultiDiGraph -GraphView <|-- DiGraphView -GraphView <|-- Graph -GraphView <|-- UndirectedGraphView -IDiGraph <|-- AdjacencyDiGraph -IDiGraphView <|-- AddDirectedEdgesView -IDiGraphView <|-- ContractNodeView -IDiGraphView <|-- DiSubgraphView -IDiGraphView <|-- FlippedView -IDiGraphView <|-- IDiGraph -IDiGraphView <|-- IMultiDiGraphView -IDiGraphView <|-- JoinedDigraphView -IDiGraphView <|-- SingleSourceNodeView -IDiGraphView <|-- ViewUndirectedGraphAsDiGraph -IDownwardOpenMultiDiGraphView <|-- IDownwardOpenMultiDiGraph -IGraphView <|-- IDiGraphView 
-IGraphView <|-- IGraph -IGraphView <|-- IUndirectedGraphView -IMultiDiGraph <|-- AdjacencyMultiDiGraph -IMultiDiGraphView <|-- IMultiDiGraph -IMultiDiGraphView <|-- IOpenMultiDiGraphView -IMultiDiGraphView <|-- JoinedMultiDigraphView -IMultiDiGraphView <|-- MultiDiSubgraphView -IMultiDiGraphView <|-- ViewDiGraphAsMultiDiGraph -IOpenMultiDiGraph <|-- AdjacencyOpenMultiDiGraph -IOpenMultiDiGraphView <|-- ClosedMultiDiSubgraphView -IOpenMultiDiGraphView <|-- DownwardOpenMultiDiSubgraphView -IOpenMultiDiGraphView <|-- IDownwardOpenMultiDiGraphView -IOpenMultiDiGraphView <|-- IOpenMultiDiGraph -IOpenMultiDiGraphView <|-- IUpwardOpenMultiDiGraphView -IOpenMultiDiGraphView <|-- OpenMultiDiSubgraphView -IOpenMultiDiGraphView <|-- UpwardOpenMultiDiSubgraphView -IOpenMultiDiGraphView <|-- ViewMultiDiGraphAsOpenMultiDiGraph -IUndirectedGraph <|-- HashmapUndirectedGraph -IUndirectedGraphView <|-- IUndirectedGraph -IUndirectedGraphView <|-- JoinedUndirectedGraphView -IUndirectedGraphView <|-- UndirectedSubgraphView -IUndirectedGraphView <|-- ViewDiGraphAsUndirectedGraph -IUpwardOpenMultiDiGraphView <|-- IUpwardOpenMultiDiGraph -MultiDiGraphView <|-- DownwardOpenMultiDiGraphView -MultiDiGraphView <|-- MultiDiGraph -MultiDiGraphView <|-- OpenMultiDiGraphView -MultiDiGraphView <|-- UpwardOpenMultiDiGraphView -MultiDiInput <|-- InputMultiDiEdge -MultiDiInput <|-- MultiDiEdge -MultiDiOutput <|-- MultiDiEdge -MultiDiOutput <|-- OutputMultiDiEdge -OpenMultiDiGraphView <|-- OpenMultiDiGraph -UndirectedGraphView <|-- UndirectedGraph -UpwardOpenMultiDiGraphView <|-- UpwardOpenMultiDiGraph -IDiGraph *-- DiGraph -IOpenMultiDiGraph *-- OpenMultiDiGraph -IUpwardOpenMultiDiGraphView *-- UpwardOpenMultiDiGraphView -IDownwardOpenMultiDiGraphView *-- DownwardOpenMultiDiGraphView -IUndirectedGraphView *-- UndirectedGraphView -IMultiDiGraphView *-- MultiDiGraphView -IUpwardOpenMultiDiGraph *-- UpwardOpenMultiDiGraph -IMultiDiGraph *-- MultiDiGraph -IGraphView *-- GraphView 
-IDownwardOpenMultiDiGraph *-- DownwardOpenMultiDiGraph -IGraph *-- Graph -IUndirectedGraph *-- UndirectedGraph -IDiGraphView *-- DiGraphView -IOpenMultiDiGraphView *-- OpenMultiDiGraphView - -@enduml \ No newline at end of file From 45eeb9054da9e8baa4fb1d68b58ed05dffceed3c Mon Sep 17 00:00:00 2001 From: Pietro Max Marsella Date: Thu, 18 Apr 2024 21:36:59 -0700 Subject: [PATCH 35/37] README change --- lib/utils/include/utils/graph/README.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/lib/utils/include/utils/graph/README.md b/lib/utils/include/utils/graph/README.md index 28493e4182..d77a363ff5 100644 --- a/lib/utils/include/utils/graph/README.md +++ b/lib/utils/include/utils/graph/README.md @@ -185,7 +185,7 @@ Generally users will use underlying representations provided by the graph librar `Open` is to be intended similarly to the topological sense: that is, a graph that contains some edges where one of the 2 nodes is not present in the graph itself. We can further specify the "openeness" of a **directed** graph by specifying whether they are `UpwardOpen` (so some of the incoming edges are open) or `DownwardOpen` (so some of the outgoing edges are open). -[Open graphs inheritance diagram](docs/open.svg) +![Open graphs inheritance diagram](docs/open.svg) Arrows with pointed tips indicate inheritance, while arrows with square tips indicate that the pointing class has a 'cow_ptr' of the type of the pointed class. (for more info on `cow_ptr`, see below). @@ -202,7 +202,7 @@ As such, the labelled graph types provide the typical `at` method (as on `std::u [^3]: `operator[]` currently is not present because all nodes must have labels and we don't require label types to be default constructible, though some simple template programming could probably add `operator[]` support in the cases where the label types _are_ default constructible. 
-[Labelled Graphs Inheritance Diagram] +![Labelled Graphs Inheritance Diagram] ## Internals From 27d9ad6948c81a86d5dcced45a6abfd2331c9d94 Mon Sep 17 00:00:00 2001 From: Pietro Max Marsella Date: Fri, 19 Apr 2024 00:51:55 -0700 Subject: [PATCH 36/37] Updated README --- lib/utils/include/utils/graph/README.md | 52 +++++++++++++++++-------- 1 file changed, 35 insertions(+), 17 deletions(-) diff --git a/lib/utils/include/utils/graph/README.md b/lib/utils/include/utils/graph/README.md index d77a363ff5..11d4fdde2e 100644 --- a/lib/utils/include/utils/graph/README.md +++ b/lib/utils/include/utils/graph/README.md @@ -202,31 +202,49 @@ As such, the labelled graph types provide the typical `at` method (as on `std::u [^3]: `operator[]` currently is not present because all nodes must have labels and we don't require label types to be default constructible, though some simple template programming could probably add `operator[]` support in the cases where the label types _are_ default constructible. -![Labelled Graphs Inheritance Diagram] +![Labelled Graphs Inheritance Diagram](docs/labelled.svg) + + ## Internals -Most of the major graph classes in the library come in sets of 4 (example considering `ClassName`) -- `ClassName` -- `ClassNameView` -- `IClassName` -- `IClassNameView` +Most of the major graph classes in the library come in sets of 4. For a given class `ClassName` we have: +1. `ClassName` +2. `ClassNameView` +3. `IClassName` +4. `IClassNameView` -The rationale behind the `View` variants has been explained in previous sections. +General rules which apply to most classes: +- `ClassName` (virtually) inherits from `ClassNameView`. Similarly, `IClassName` (virtually) inherits from `IClassNameView`. +- `ClassName` has, as a member variable, a `cow_ptr` of type `IClassName`. Same holds for `ClassNameView`. +Thus, the bulk of the inheritance that actually extends functionality is present among `IClassNameView` classes.
-The rationale for the `I(nterface)` variations is derived from the way that C++ models polymorphism. -Inheritance within the library is almost exclusively virtual: such inheritance model is demanded by the nested inheritance structure. -In the case of a diamond inheritance pattern C++, unlike languages such as Python, will instantiate multiple copies of the base class whenever we instantiate a derived class. -To address this issue, we employ [Virtual Inheritance](https://en.wikipedia.org/wiki/Virtual_inheritance), which removes the ambiguity associated with the multiple copies. -Furthermore, the use of virtual functions allows for runtime polymorphism, allowing for a single function defined on some superclass to also work correctly on it's subclasses. +### cow_ptr and Interfaces + +The reason for the existence of the `View` variants has been explained in previous sections. +The existence of the `I(nterface)` variants stems from C++'s approach to modeling polymorphism. + +C++ polymorphism is achieved through the use of [virtual functions](https://www.learncpp.com/cpp-tutorial/virtual-functions/). +To create objects with polymorphic behaviour, we use the following syntax: +`BaseClass* obj = new DerivedClass(); //or alternatives such as std::shared_ptr<BaseClass> obj = std::make_shared<DerivedClass>();` +Calls to `obj`'s member functions are resolved at runtime (dynamic binding), with C++ calling the most derived implementation of the function. -C++ polymorphism is normally achieved with the following pattern: +While this pattern works nicely, the way instantiation is done leaves the burden of memory management on the user. +To address this, graph classes store a `cow_ptr` as a member variable, which points to an instance of their corresponding interface class.
-`std::shared_ptr = new DerivedClass();` +All member functions present in `ClassName` and `ClassNameView` delegate their calls to their corresponding interface classes (which implement the actual logic), meaning that these classes essentially act as wrappers to their interface counterparts. + +To create graphs within the library, we thus use the following syntax: +`BaseGraph obj = BaseGraph::create<DerivedGraph>();` + + +### Virtual Inheritance (Possibly superfluous) +Due to the complexity of the graph library, diamond-style inheritance patterns emerge. +In the case of a diamond inheritance pattern, C++ will instantiate multiple copies of the base class whenever we instantiate a derived class. +To address this issue, we employ [Virtual Inheritance](https://en.wikipedia.org/wiki/Virtual_inheritance), which removes the ambiguity associated with the multiple copies. +Furthermore, the use of virtual functions allows for runtime polymorphism, allowing for a single function defined on some superclass to also work correctly on its subclasses. -This pattern however leaves the burden of memory management on the user. -To address this, graph classes have a cow_ptr as a member (with type equal to their corresponding Interface class), to which all function calls are delegated. ### strong_typedef `Node` inherits from `strong_typedef`: this is in order to ensure that distinct types that alias the same type are still considered distinct (and thus using one in place of the other will result in a compiler error). -For more info, see https://www.foonathan.net/2016/10/strong-typedefs/ +For more info, see https://www.foonathan.net/2016/10/strong-typedefs/.
From 5a96eaeb4215a14e1b9b4b2c89c5f915286f4a92 Mon Sep 17 00:00:00 2001 From: Pietro Max Marsella Date: Sat, 20 Apr 2024 11:52:16 -0700 Subject: [PATCH 37/37] Updated Docs --- lib/utils/include/utils/graph/README.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/lib/utils/include/utils/graph/README.md b/lib/utils/include/utils/graph/README.md index 11d4fdde2e..a9b399e155 100644 --- a/lib/utils/include/utils/graph/README.md +++ b/lib/utils/include/utils/graph/README.md @@ -238,6 +238,8 @@ All member functions present in `ClassName` and `ClassNameView` delegate their c To create graphs within the library, we thus use the following syntax: `BaseGraph obj = BaseGraph::create<DerivedGraph>();` +This results in an object that, while of type `BaseGraph`, can access at runtime the member functions defined in `DerivedGraph`. + ### Virtual Inheritance (Possibly superfluous) Due to the complexity of the graph library, diamond-style inheritance patterns emerge.