From 2e98307e26e0c91e22375f650373b56e5afbcf16 Mon Sep 17 00:00:00 2001 From: wmdi Date: Thu, 20 Jul 2023 00:38:29 +0800 Subject: [PATCH 01/61] add unit tests for machine mapping and dp algorithm --- lib/compiler/test/test_cost_estimator.h | 23 +++++ lib/compiler/test/test_generator.h | 117 ++++++++++++++++++++++ lib/compiler/test/test_machine_mapping.cc | 38 +++++++ lib/compiler/test/test_optimal_cost.cc | 20 ++++ 4 files changed, 198 insertions(+) create mode 100644 lib/compiler/test/test_cost_estimator.h create mode 100644 lib/compiler/test/test_generator.h create mode 100644 lib/compiler/test/test_machine_mapping.cc create mode 100644 lib/compiler/test/test_optimal_cost.cc diff --git a/lib/compiler/test/test_cost_estimator.h b/lib/compiler/test/test_cost_estimator.h new file mode 100644 index 0000000000..9a4ea56156 --- /dev/null +++ b/lib/compiler/test/test_cost_estimator.h @@ -0,0 +1,23 @@ +#ifndef _FLEXFLOW_TEST_COST_ESTIMATOR_H +#define _FLEXFLOW_TEST_COST_ESTIMATOR_H + +#include "compiler/cost_estimate.h" + +namespace FlexFlow { + +struct TestCostEstimator : public ICostEstimator { + float estimate_cost(PCGOperatorAttrs const &op, + std::vector const &inputs, + MachineView const &mv) const override { + return 0.1; + } + float estimate_cost(ParallelTensorShape const &tensor_shape, + MachineView const &src, + MachineView const &dst) const override { + return 0.1; + } +}; + +} // namespace FlexFlow + +#endif diff --git a/lib/compiler/test/test_generator.h b/lib/compiler/test/test_generator.h new file mode 100644 index 0000000000..1942574969 --- /dev/null +++ b/lib/compiler/test/test_generator.h @@ -0,0 +1,117 @@ +#ifndef _FLEXFLOW_TEST_GENERATOR_H +#define _FLEXFLOW_TEST_GENERATOR_H + +#include "compiler/machine_mapping.h" +#include "compiler/sub_parallel_computation_graph.h" + +namespace rc { + +template <> +struct Arbitrary { + static Gen arbitrary() { + return gen::construct(gen::arbitrary()); + } +}; + +template <> +struct Arbitrary { + static Gen arbitrary() { + return gen::apply(make_1d_machine_view, + gen::construct(gen::nonZero()), + gen::construct(gen::nonZero()), + gen::inRange(1, 4)); + } +} + +template <> +struct Arbitrary { + static Gen arbitrary() { + return gen::build( + gen::set(&MachineMapping::runtime, gen::nonZero()); + gen::set(&MachineMapping::machine_views, + gen::container>( + gen::arbitrary(), gen::arbitrary()))); + } +} + +template <> +struct Arbitrary { + static Gen arbitrary() { + return gen::apply( + [](int num_nodes, + std::vector const &lhs, + std::vector const &rhs, + std::vector const &compn_type) { + auto g = OutputLabelledMultiDiGraph::create< + UnorderedOutputLabelledMultiDiGraph>(); + + std::vector nodes, source, sink; + + for (int i = 0; i < num_nodes; ++i) { + Node new_node = g.add_node(Operator(NoopAttrs{}, nullopt)); + nodes.push_back(new_node); + source.push_back(new_node); + sink.push_back(new_node); + } + + for (int i = 0; i < lhs.size(); ++i) { + if (i >= rhs.size()) { + break; + } + if (i >= compn_type.size()) { + break; + } + + int n0 = lhs[i] % num_nodes, n1 = rhs[i] % num_nodes, + t = compn_type[i] % 2; + + Node source0 = source[n0], source1 = source[n1]; + Node sink0 = sink[n0], sink1 = sink[n1]; + + if (source0 == source1 && sink0 == sink1) { + continue; + } + + RC_ASSERT(source0 != source1 && sink0 != sink1); + + if (source0 == sink0 || t == 0) { + // sequential composition + g.add_edge(MultiDiOutput{sink0, NodePort(0)}, + MultiDiInput{source1, NodePort(0)}); + for (int j = 0; j < nodes.size(); ++j) { + if (source[j] == source1) { + source[j] = source0; + } + if (sink[j] == sink0) { + sink[j] = sink1; + } + } + } else { + // parallel composition + g.add_edge(MultiDiOutput{source0, NodePort(0)}, + MultiDiInput(source1, NodePort(0))); + g.add_edge(MultiDiOutput{ + sink1, NodePort(0), MultiDiInput(sink0, NodePort(0))}); + for (int j = 0; j < nodes.size(); ++j) { + if (source[j] == source1) { + source[j] = source0; + } + if (sink[j] == sink1) { + sink[j] = sink0; + } + } + } + } + + return ParallelComputationGraph(g); + }, + gen::inRange(1, 200), + gen::arbitrary>(), + gen::arbitrary>(), + gen::arbitrary>()); + } +} + +} // namespace rc + +#endif diff --git a/lib/compiler/test/test_machine_mapping.cc b/lib/compiler/test/test_machine_mapping.cc new file mode 100644 index 0000000000..f36e5781bb --- /dev/null +++ b/lib/compiler/test/test_machine_mapping.cc @@ -0,0 +1,38 @@ +#include "doctest.h" +#include "test_generator.h" + +TEST_CASE("MachineMapping::sequential_combine") { + rc::check([](MachineMapping const &mp0, MachineMapping const &mp1) { + RC_PRE(are_disjoint(keys(mp0), keys(mp1))); + + MachineMapping comb = MachineMapping::sequential_combine(mp0, mp1); + + RC_ASSERT(comb.runtime == mp0.runtime + mp1.runtime); + RC_ASSERT(comb.machine_views.size() == + mp0.machine_views.size() + mp1.machine_views.size()); + for (auto p : mp0.machine_views) { + RC_ASSERT(p.second == comb.machine_views.at(p.first)); + } + }); +} + +TEST_CASE("MachineMapping::parallel_combine") { + rc::check([](MachineMapping const &mp0, MachineMapping const &mp1) { + RC_PRE(are_disjoint(keys(mp0), keys(mp1))); + + MachineMapping comb = MachineMapping::parallel_combine(mp0, mp1); + + RC_ASSERT(comb.runtime == std::max(mp0.runtime, mp1.runtime)); + RC_ASSERT(comb.machine_views.size() == + mp0.machine_views.size() + mp1.machine_views.size()); + for (auto p : mp0.machine_views) { + RC_ASSERT(p.second == comb.machine_views.at(p.first)); + } + }); +} + +TEST_CASE("MachieMapping::infinity") { + rc::check([](MachineMapping const &mp) { + RC_ASSERT(mp.runtime <= MachineMapping::infinity().runtime); + }); +} \ No newline at end of file diff --git a/lib/compiler/test/test_optimal_cost.cc b/lib/compiler/test/test_optimal_cost.cc new file mode 100644 index 0000000000..a1207ede00 --- /dev/null +++ b/lib/compiler/test/test_optimal_cost.cc @@ -0,0 +1,20 @@ +#include "test_cost_estimator.h" +#include "test_generator.h" + +TEST_CASE("optimal_cost") { + rc::check([](ParallelComputationGraph const &g) { + std::unordered_map cached_subgraph_costs; + MachineMapping machine_mapping = optimal_cost( + g, + [](Operator const &, MachineSpecification const &) { + return std::unordered_set{make_1d_machine_view(0, 1, 1)}; + }, + TestCostEstimator{}, + MachineSpecification{1, 1, 4, 0.1, 0.2}, + cached_subgraph_costs); + RC_ASSERT(machine_mapping.runtime > 0); + for (auto node : get_nodes(g)) { + RC_ASSERT(contains_key(machine_mapping.machine_views, node)); + } + }); +} \ No newline at end of file From 8304b3d776b8bebf8b81b473b72776267c84287e Mon Sep 17 00:00:00 2001 From: wmdi Date: Thu, 20 Jul 2023 00:58:24 +0800 Subject: [PATCH 02/61] add unit test for unity algorithm --- lib/compiler/test/test_generator.h | 85 ++++++++++++++++++++++- lib/compiler/test/test_unity_algorithm.cc | 25 +++++++ 2 files changed, 107 insertions(+), 3 deletions(-) create mode 100644 lib/compiler/test/test_unity_algorithm.cc diff --git a/lib/compiler/test/test_generator.h b/lib/compiler/test/test_generator.h index 1942574969..30fcd09d46 100644 --- a/lib/compiler/test/test_generator.h +++ b/lib/compiler/test/test_generator.h @@ -3,6 +3,7 @@ #include "compiler/machine_mapping.h" #include "compiler/sub_parallel_computation_graph.h" +#include "rapidcheck.h" namespace rc { @@ -34,6 +35,84 @@ struct Arbitrary { } } +template <> +struct Arbitrary { + static Gen arbitrary() { + return gen::apply( + [](int num_nodes, + std::vector const &lhs, + std::vector const &rhs, + std::vector const &compn_type) { + auto g = + ::create>(); + + std::vector nodes, source, sink; + + for (int i = 0; i < num_nodes; ++i) { + Node new_node = g.add_node(Layer(NoopAttrs{}, nullopt)); + nodes.push_back(new_node); + source.push_back(new_node); + sink.push_back(new_node); + } + + for (int i = 0; i < lhs.size(); ++i) { + if (i >= rhs.size()) { + break; + } + if (i >= compn_type.size()) { + break; + } + + int n0 = lhs[i] % num_nodes, n1 = rhs[i] % num_nodes, + t = compn_type[i] % 2; + + Node source0 = source[n0], source1 = source[n1]; + Node sink0 = sink[n0], sink1 = sink[n1]; + + if (source0 == source1 && sink0 == sink1) { + continue; + } + + RC_ASSERT(source0 != source1 && sink0 != sink1); + + if (source0 == sink0 || t == 0) { + // sequential composition + g.add_edge(MultiDiOutput{sink0, NodePort(0)}, + MultiDiInput{source1, NodePort(0)}); + for (int j = 0; j < nodes.size(); ++j) { + if (source[j] == source1) { + source[j] = source0; + } + if (sink[j] == sink0) { + sink[j] = sink1; + } + } + } else { + // parallel composition + g.add_edge(MultiDiOutput{source0, NodePort(0)}, + MultiDiInput{source1, NodePort(0)}); + g.add_edge(MultiDiOutput{sink1, NodePort(0)}, + MultiDiInput{sink0, NodePort(0)}); + for (int j = 0; j < nodes.size(); ++j) { + if (source[j] == source1) { + source[j] = source0; + } + if (sink[j] == sink1) { + sink[j] = sink0; + } + } + } + } + + return ComputationGraph(g); + }, + gen::inRange(1, 200), + gen::arbitrary>(), + gen::arbitrary>(), + gen::arbitrary>()); + } +} + template <> struct Arbitrary { static Gen arbitrary() { @@ -89,9 +168,9 @@ struct Arbitrary { } else { // parallel composition g.add_edge(MultiDiOutput{source0, NodePort(0)}, - MultiDiInput(source1, NodePort(0))); - g.add_edge(MultiDiOutput{ - sink1, NodePort(0), MultiDiInput(sink0, NodePort(0))}); + MultiDiInput{source1, NodePort(0)}); + g.add_edge(MultiDiOutput{sink1, NodePort(0)}, + MultiDiInput{sink0, NodePort(0)}); for (int j = 0; j < nodes.size(); ++j) { if (source[j] == source1) { source[j] = source0; diff --git a/lib/compiler/test/test_unity_algorithm.cc b/lib/compiler/test/test_unity_algorithm.cc new file mode 100644 index 0000000000..659a4d0a88 --- /dev/null +++ b/lib/compiler/test/test_unity_algorithm.cc @@ -0,0 +1,25 @@ +#include "compiler/unity_algorithm.h" +#include "test_cost_estimator.h" +#include "test_generator.h" + +TEST_CASE("graph_optimize") { + rc::check([](ComputationGraph const &g, + float alpha, + int budget, + float threshold, + int max_num_ops) { + Strategy s = graph_optimize( + g, + TestCostEstimator{}, + MachineSpecification{1, 1, 4, 0.1, 0.2}, + [](Operator const &, MachineSpecification const &) { + return std::unordered_set{make_1d_machine_view(0, 1, 1)}; + }, + OptimizerConfig{alpha, budget, threshold, max_num_ops}); + RC_ASSERT(get_nodes(s.pcg).size() > 0); + RC_ASSERT(s.machine_mapping.runtime > 0); + for (auto node : get_nodes(s.pcg)) { + RC_ASSERT(contains_key(s.machine_mapping.machine_views, node)); + } + }); +} \ No newline at end of file From dd33af668d7e763831a9fab85d2c99931bd613aa Mon Sep 17 00:00:00 2001 From: wmdi Date: Thu, 20 Jul 2023 10:50:12 +0800 Subject: [PATCH 03/61] fix compile errors from filter and support_interator_tag --- lib/utils/include/utils/containers.h | 13 ++++++++++++- lib/utils/include/utils/type_traits.h | 10 ---------- lib/utils/include/utils/type_traits_core.h | 10 ++++++++++ 3 files changed, 22 insertions(+), 11 deletions(-) diff --git a/lib/utils/include/utils/containers.h b/lib/utils/include/utils/containers.h index faaac41327..25fbfddf02 100644 --- a/lib/utils/include/utils/containers.h +++ b/lib/utils/include/utils/containers.h @@ -4,7 +4,7 @@ #include "bidict.h" #include "invoke.h" #include "optional.h" -#include "type_traits.h" +#include "type_traits_core.h" #include "required_core.h" #include #include @@ -540,6 +540,17 @@ C filter(C const &v, F const &f) { return result; } +template +std::unordered_set filter(std::unordered_set const &v, F const &f) { + std::unordered_set result; + for (T const &t : v) { + if (f(t)) { + result.insert(t); + } + } + return result; +} + template void inplace_filter(C &v, F const &f) { std::remove_if(v.begin(), v.end(), [&](Elem const &e) { return !f(e); }); diff --git a/lib/utils/include/utils/type_traits.h b/lib/utils/include/utils/type_traits.h index ee44f01983..78d995a8b4 100644 --- a/lib/utils/include/utils/type_traits.h +++ b/lib/utils/include/utils/type_traits.h @@ -113,16 +113,6 @@ struct elements_satisfy> : std::true_type {}; static_assert( elements_satisfy>::value, ""); -template -struct supports_iterator_tag - : std::is_base_of::iterator_category> {}; - -#define CHECK_SUPPORTS_ITERATOR_TAG(TAG, ...) \ - static_assert(supports_iterator_tag<__VA_ARGS__, TAG>::value, \ - #__VA_ARGS__ " does not support required iterator tag " #TAG); - template using is_default_constructible = std::is_default_constructible; diff --git a/lib/utils/include/utils/type_traits_core.h b/lib/utils/include/utils/type_traits_core.h index 0798974180..acda070ad8 100644 --- a/lib/utils/include/utils/type_traits_core.h +++ b/lib/utils/include/utils/type_traits_core.h @@ -113,6 +113,16 @@ struct is_static_castable< void_t(std::declval()))>> : std::true_type { }; +template +struct supports_iterator_tag + : std::is_base_of::iterator_category> {}; + +#define CHECK_SUPPORTS_ITERATOR_TAG(TAG, ...) \ + static_assert(supports_iterator_tag<__VA_ARGS__, TAG>::value, \ + #__VA_ARGS__ " does not support required iterator tag " #TAG); + } // namespace FlexFlow #endif From 766eafa7723588278ffd117ccecc634396dad1a5 Mon Sep 17 00:00:00 2001 From: wmdi Date: Mon, 24 Jul 2023 20:28:59 -0400 Subject: [PATCH 04/61] minor fixes for compiler --- lib/compiler/src/graph_utils.h | 11 +++++++- lib/compiler/src/machine_mapping.cc | 27 ++++++++++--------- lib/utils/include/utils/bidict.h | 1 + .../graph/labelled/labelled_open_interfaces.h | 2 ++ 4 files changed, 27 insertions(+), 14 deletions(-) diff --git a/lib/compiler/src/graph_utils.h b/lib/compiler/src/graph_utils.h index 91132da680..88515ef950 100644 --- a/lib/compiler/src/graph_utils.h +++ b/lib/compiler/src/graph_utils.h @@ -16,7 +16,16 @@ SubParallelComputationGraph pcg_to_subpcg(ParallelComputationGraph const &g); template void minimize(T &t, T const &v) { - t = std::min(t, v); + if (v < t) { + t = v; + } +} + +template +void minimize(T &t, T const &v, Compare comp) { + if (comp(v, t)) { + t = v; + } } } // namespace FlexFlow diff --git a/lib/compiler/src/machine_mapping.cc b/lib/compiler/src/machine_mapping.cc index 57832c34b9..ef7fd0b148 100644 --- a/lib/compiler/src/machine_mapping.cc +++ b/lib/compiler/src/machine_mapping.cc @@ -59,11 +59,12 @@ GraphSplit } std::pair - apply_split(SubParallelComputationGraph const &g, GraphSplit const &split) { + apply_split(SubParallelComputationGraph const &g, + GraphSplit const &split) { OpenMultiDiGraphView g1 = get_subgraph(g, split.first); OpenMultiDiGraphView g2 = get_subgraph(g, split.second); - if (get_cut(g, split).size() > 0) { + if (get_edge_splits(g, split).size() > 0) { // Sequential split if (get_open_sinks(g1).size() <= get_open_sources(g2).size()) { // get_open_sinks(*g1).size() should be 1 in perfect sp graphs @@ -93,7 +94,7 @@ struct OptimalCost { optional const &source_machine_view, // assume perfect SP optional const &sink_machine_view, std::function( - PCGOperatorAttrs const &, MachineSpecification const &)> const + Operator const &, MachineSpecification const &)> const &allowed_machine_views, std::unordered_map &cached_subgraph_costs) : g(g), cost_estimator(cost_estimator), resource(resource), @@ -141,15 +142,13 @@ struct OptimalCost { } MachineMapping optimal_cost(Serial const &serial) const { - // return sum(vector_transform([&](variant const &t) { - // return visit(*this, t); }, serial.children)); auto decomposed = decompose(serial); SerialParallelDecomposition pre_decompn = decomposed.first; SerialParallelDecomposition post_decompn = decomposed.second; auto subgraphs = apply_split(g, get_graph_split(pre_decompn, post_decompn)); SubParallelComputationGraph pre_graph = subgraphs.first, - post_graph = subgraphs.second; + post_graph = subgraphs.second; std::unordered_set pre_graph_sinks = get_closed_sinks(pre_graph); std::unordered_set post_graph_sources = @@ -187,7 +186,8 @@ struct OptimalCost { sink_machine_view, allowed_machine_views, cached_subgraph_costs), - post_decompn))); + post_decompn)), + MachineMappingRuntimeCmp{}); } return optimal_result; @@ -237,7 +237,8 @@ struct OptimalCost { sink_machine_view, allowed_machine_views, cached_subgraph_costs), - decompn2))); + decompn2)), + MachineMappingRuntimeCmp{}); } return optimal_result; @@ -262,9 +263,9 @@ struct OptimalCost { MachineMapping optimal_result = MachineMapping::infinity(); for (auto mv : allowed_machine_views(g.at(node), resource)) { std::unordered_map mv_map{{node, mv}}; - minimize( - optimal_result, - {estimate_cost(g, cost_estimator, mv_map), mv_map)}; + minimize(optimal_result, + {estimate_cost(g, cost_estimator, mv_map), mv_map}, + MachineMappingRuntimeCmp{}); } return optimal_result; } @@ -276,7 +277,7 @@ struct OptimalCost { optional const &source_machine_view; optional const &sink_machine_view; std::function( - PCGOperatorAttrs const &, MachineSpecification const &)> const + Operator const &, MachineSpecification const &)> const &allowed_machine_views; std::unordered_map &cached_subgraph_costs; }; @@ -284,7 +285,7 @@ struct OptimalCost { MachineMapping optimal_cost( ParallelComputationGraph const &g, std::function( - PCGOperatorAttrs const &, MachineSpecification const &)> const + Operator const &, MachineSpecification const &)> const &allowed_machine_views, ICostEstimator const &cost_estimator, MachineSpecification const &resources, diff --git a/lib/utils/include/utils/bidict.h b/lib/utils/include/utils/bidict.h index 870afd0448..c600f40790 100644 --- a/lib/utils/include/utils/bidict.h +++ b/lib/utils/include/utils/bidict.h @@ -2,6 +2,7 @@ #define _FLEXFLOW_UTILS_BIDICT_H #include +#include namespace FlexFlow { diff --git a/lib/utils/include/utils/graph/labelled/labelled_open_interfaces.h b/lib/utils/include/utils/graph/labelled/labelled_open_interfaces.h index a3ef390530..63af5e6d73 100644 --- a/lib/utils/include/utils/graph/labelled/labelled_open_interfaces.h +++ b/lib/utils/include/utils/graph/labelled/labelled_open_interfaces.h @@ -14,6 +14,8 @@ struct ILabelledOpenMultiDiGraphView : public IOpenMultiDiGraphView, public ILabelledMultiDiGraphView { public: + using INodeLabelledMultiDiGraphView::at; + virtual InputLabel const &at(InputMultiDiEdge const &e) const = 0; virtual OutputLabel const &at(OutputMultiDiEdge const &e) const = 0; virtual EdgeLabel const &at(MultiDiEdge const &e) const = 0; From 57994afe8e04c691b424488dbd2eb401743eaae3 Mon Sep 17 00:00:00 2001 From: wmdi Date: Wed, 2 Aug 2023 12:10:26 -0400 Subject: [PATCH 05/61] clean up generator codes and minor fix --- lib/compiler/src/machine_mapping.cc | 21 +- lib/compiler/test/test_generator.h | 269 +++++++++------------- lib/compiler/test/test_machine_mapping.cc | 20 +- lib/compiler/test/test_optimal_cost.cc | 8 +- lib/compiler/test/test_unity_algorithm.cc | 4 +- lib/utils/include/utils/containers.h | 5 + lib/utils/include/utils/graph/cow_ptr_t.h | 2 +- 7 files changed, 146 insertions(+), 183 deletions(-) diff --git a/lib/compiler/src/machine_mapping.cc b/lib/compiler/src/machine_mapping.cc index f6e7fea9d6..fd24ecdcf0 100644 --- a/lib/compiler/src/machine_mapping.cc +++ b/lib/compiler/src/machine_mapping.cc @@ -90,6 +90,10 @@ float estimate_cost( NOT_IMPLEMENTED(); } +void minimize_runtime(MachineMapping &m1, MachineMapping const &m2) { + minimize(m1, m2, MachineMappingRuntimeCmp{}); +} + struct OptimalCost { OptimalCost( SubParallelComputationGraph const &g, @@ -113,7 +117,7 @@ struct OptimalCost { optional const &source_machine_view; optional const &sink_machine_view; std::function( - PCGOperatorAttrs const &, MachineSpecification const &)> const + Operator const &, MachineSpecification const &)> const &allowed_machine_views; std::unordered_map &cached_subgraph_costs; @@ -183,7 +187,7 @@ struct OptimalCost { optional post_source_mv = contains(post_graph_sources, split_point) ? make_optional(mv) : nullopt; - minimize(optimal_result, + minimize_runtime(optimal_result, MachineMapping::sequential_combine( visit(OptimalCost(pre_graph, cost_estimator, @@ -200,8 +204,7 @@ struct OptimalCost { sink_machine_view, allowed_machine_views, cached_subgraph_costs), - post_decompn)), - MachineMappingRuntimeCmp{}); + post_decompn))); } return optimal_result; @@ -234,7 +237,7 @@ struct OptimalCost { decompn2)); for (auto const &resource_split : get_resource_split(resource)) { - minimize(optimal_result, + minimize_runtime(optimal_result, MachineMapping::parallel_combine( visit(OptimalCost(g1, cost_estimator, @@ -251,8 +254,7 @@ struct OptimalCost { sink_machine_view, allowed_machine_views, cached_subgraph_costs), - decompn2)), - MachineMappingRuntimeCmp{}); + decompn2))); } return optimal_result; @@ -277,9 +279,8 @@ struct OptimalCost { MachineMapping optimal_result = MachineMapping::infinity(); for (auto mv : allowed_machine_views(g.at(node), resource)) { std::unordered_map mv_map{{node, mv}}; - minimize(optimal_result, - {estimate_cost(g, cost_estimator, mv_map), mv_map}, - MachineMappingRuntimeCmp{}); + minimize_runtime(optimal_result, + {estimate_cost(g, cost_estimator, mv_map), mv_map}); } return optimal_result; } diff --git a/lib/compiler/test/test_generator.h b/lib/compiler/test/test_generator.h index 30fcd09d46..608cf4576f 100644 --- a/lib/compiler/test/test_generator.h +++ b/lib/compiler/test/test_generator.h @@ -5,8 +5,110 @@ #include "compiler/sub_parallel_computation_graph.h" #include "rapidcheck.h" +using namespace FlexFlow; + +enum class CompnType { SERIAL, PARALLEL }; + +struct Compn { + CompnType type; + int component1, component2; +}; + +int pop_component(std::set const &components, int value) { + value = value % components.size(); + auto it = components.begin(); + while (value--) { + it++; + } + int component = *it; + components.erase(it); + return component; +} + +/* + Generates a series-parallel graph according to the composition sequence + described by `composition`. A series-parallel graph can be generated as + follows: 1) Initially, we have E (E is the length of `composition`+1) + components, each containing a single edge; 2) In iteration `i`, we compose two + components (`composition[i].component1` and `composition[i].component2`): 2.1) + If `composition[i].type == SERIAL`, we merge the sink node of component1 and + the source node of component2; 2.2) If `composition[i].type == PARALLEL`, we + merge the source nodes and the sink nodes of two components. +*/ +MultiDiGraph generate_sp_graph(std::vector const &composition) { + std::set components; + disjoint_set node_id; // initially we have 2E nodes, and we will merge + // them during the iteration + std::vector src, + dst; // src and dst nodes for each edge before merging + std::vector srcIdx, + dstIdx; // src and dst node ports for each edge (I assume it is sufficient + // to make different edges have different NodePort. Correct me if + // I am wrong. @lockshaw) + AdjacencyMultiDiGraph g(0, 0, {}); + for (int i = 0; i <= composition.size(); ++i) { + components.insert(i); + src.push_back(g.add_node()); + dst.push_back(g.add_node()); + srcIdx.push_back(g.add_node_port()); + dstIdx.push_back(g.add_node_port()); + } + std::vector source_node = src, + sink_node = + dst; // initially each component has a single edge + + // We compute the src and dst nodes after merging for each edge before + // actually inserting the edges. + + for (Compn const &compn : composition) { + int c1 = pop_component(components, compn.component1); + int c2 = pop_component(components, compn.component2); + components.insert(c1); + if (compn.type == CompnType::SERIAL) { + node_id.m_union(sink_node[c1], source_node[c2]); + sink_node[c1] = sink_node[c2]; + } else { + node_id.m_union(source_node[c1], source_node[c2]); + node_id.m_union(sink_node[c1], sink_node[c2]); + } + } + + for (Node node : get_nodes(g)) { + if (node_id.find(node) != node) { + g.remove_node_unsafe(node); + } + } + + for (int i = 0; i < src.size(); ++i) { + g.add_edge(MultiDiEdge{src[i], dst[i], srcIdx[i], dstIdx[i]}); + } + + return g; +} + +template +OutputLabelledMultiDiGraph + generate_test_labelled_sp_graph() { + NOT_IMPLEMENTED(); + // Is there a way to construct a labelled graph from a MultiDiGraph and the + // labels? +} + +rc::Gen small_integer_generator() { + return gen::inRange(1, 4); +} + namespace rc { +template +struct Arbtrary { + static Gen< + std::enable_if, Tag>::value>::type> + arbitrary() { + return gen::construct(gen::arbitrary()); + } +}; + template <> struct Arbitrary { static Gen arbitrary() { @@ -18,9 +120,9 @@ template <> struct Arbitrary { static Gen arbitrary() { return gen::apply(make_1d_machine_view, - gen::construct(gen::nonZero()), - gen::construct(gen::nonZero()), - gen::inRange(1, 4)); + gen::arbitrary, + gen::arbitrary, + small_integer_generator()); } } @@ -36,158 +138,15 @@ struct Arbitrary { } template <> -struct Arbitrary { - static Gen arbitrary() { - return gen::apply( - [](int num_nodes, - std::vector const &lhs, - std::vector const &rhs, - std::vector const &compn_type) { - auto g = - ::create>(); - - std::vector nodes, source, sink; - - for (int i = 0; i < num_nodes; ++i) { - Node new_node = g.add_node(Layer(NoopAttrs{}, nullopt)); - nodes.push_back(new_node); - source.push_back(new_node); - sink.push_back(new_node); - } - - for (int i = 0; i < lhs.size(); ++i) { - if (i >= rhs.size()) { - break; - } - if (i >= compn_type.size()) { - break; - } - - int n0 = lhs[i] % num_nodes, n1 = rhs[i] % num_nodes, - t = compn_type[i] % 2; - - Node source0 = source[n0], source1 = source[n1]; - Node sink0 = sink[n0], sink1 = sink[n1]; - - if (source0 == source1 && sink0 == sink1) { - continue; - } - - RC_ASSERT(source0 != source1 && sink0 != sink1); - - if (source0 == sink0 || t == 0) { - // sequential composition - g.add_edge(MultiDiOutput{sink0, NodePort(0)}, - MultiDiInput{source1, NodePort(0)}); - for (int j = 0; j < nodes.size(); ++j) { - if (source[j] == source1) { - source[j] = source0; - } - if (sink[j] == sink0) { - sink[j] = sink1; - } - } - } else { - // parallel composition - g.add_edge(MultiDiOutput{source0, NodePort(0)}, - MultiDiInput{source1, NodePort(0)}); - g.add_edge(MultiDiOutput{sink1, NodePort(0)}, - MultiDiInput{sink0, NodePort(0)}); - for (int j = 0; j < nodes.size(); ++j) { - if (source[j] == source1) { - source[j] = source0; - } - if (sink[j] == sink1) { - sink[j] = sink0; - } - } - } - } - - return ComputationGraph(g); - }, - gen::inRange(1, 200), - gen::arbitrary>(), - gen::arbitrary>(), - gen::arbitrary>()); - } -} - -template <> -struct Arbitrary { - static Gen arbitrary() { - return gen::apply( - [](int num_nodes, - std::vector const &lhs, - std::vector const &rhs, - std::vector const &compn_type) { - auto g = OutputLabelledMultiDiGraph::create< - UnorderedOutputLabelledMultiDiGraph>(); - - std::vector nodes, source, sink; - - for (int i = 0; i < num_nodes; ++i) { - Node new_node = g.add_node(Operator(NoopAttrs{}, nullopt)); - nodes.push_back(new_node); - source.push_back(new_node); - sink.push_back(new_node); - } - - for (int i = 0; i < lhs.size(); ++i) { - if (i >= rhs.size()) { - break; - } - if (i >= compn_type.size()) { - break; - } - - int n0 = lhs[i] % num_nodes, n1 = rhs[i] % num_nodes, - t = compn_type[i] % 2; - - Node source0 = source[n0], source1 = source[n1]; - Node sink0 = sink[n0], sink1 = sink[n1]; - - if (source0 == source1 && sink0 == sink1) { - continue; - } - - RC_ASSERT(source0 != source1 && sink0 != sink1); - - if (source0 == sink0 || t == 0) { - // sequential composition - g.add_edge(MultiDiOutput{sink0, NodePort(0)}, - MultiDiInput{source1, NodePort(0)}); - for (int j = 0; j < nodes.size(); ++j) { - if (source[j] == source1) { - source[j] = source0; - } - if (sink[j] == sink0) { - sink[j] = sink1; - } - } - } else { - // parallel composition - g.add_edge(MultiDiOutput{source0, NodePort(0)}, - MultiDiInput{source1, NodePort(0)}); - g.add_edge(MultiDiOutput{sink1, NodePort(0)}, - MultiDiInput{sink0, NodePort(0)}); - for (int j = 0; j < nodes.size(); ++j) { - if (source[j] == source1) { - source[j] = source0; - } - if (sink[j] == sink1) { - sink[j] = sink0; - } - } - } - } - - return ParallelComputationGraph(g); - }, - gen::inRange(1, 200), - gen::arbitrary>(), - gen::arbitrary>(), - gen::arbitrary>()); +struct Arbitrary { + static Gen arbitrary() { + return gen::build( + gen::set(&MachineSpecification::num_nodes, gen::inRange(1, 64)), + gen::set(&MachineSpecification::num_cpus_per_node, gen::inRange(1, 64)), + gen::set(&MachineSpecification::num_gpus_per_node, gen::inRange(1, 16)), + gen::set(&MachineSpecification::inter_node_bandwidth, gen::nonZero()), + gen::set(&MachineSpecification::intra_node_bandwidth, gen::nonZero()) + ); } } diff --git a/lib/compiler/test/test_machine_mapping.cc b/lib/compiler/test/test_machine_mapping.cc index f36e5781bb..fe72d549cc 100644 --- a/lib/compiler/test/test_machine_mapping.cc +++ b/lib/compiler/test/test_machine_mapping.cc @@ -1,37 +1,39 @@ #include "doctest.h" #include "test_generator.h" +bool nodes_are_disjoint(MachineMapping const &m1, MachineMapping const &m2) { + return are_disjoint(keys(m1.machine_views), keys(m2.machine_views)); +} + TEST_CASE("MachineMapping::sequential_combine") { rc::check([](MachineMapping const &mp0, MachineMapping const &mp1) { - RC_PRE(are_disjoint(keys(mp0), keys(mp1))); + RC_PRE(nodes_are_disjoint(mp0, mp1)); MachineMapping comb = MachineMapping::sequential_combine(mp0, mp1); RC_ASSERT(comb.runtime == mp0.runtime + mp1.runtime); RC_ASSERT(comb.machine_views.size() == mp0.machine_views.size() + mp1.machine_views.size()); - for (auto p : mp0.machine_views) { - RC_ASSERT(p.second == comb.machine_views.at(p.first)); - } + RC_ASSERT(is_submap(comb.machine_views, mp0.machine_views)); + RC_ASSERT(is_submap(comb.machine_views, mp1.machine_views)); }); } TEST_CASE("MachineMapping::parallel_combine") { rc::check([](MachineMapping const &mp0, MachineMapping const &mp1) { - RC_PRE(are_disjoint(keys(mp0), keys(mp1))); + RC_PRE(nodes_are_disjoint(mp0, mp1)); MachineMapping comb = MachineMapping::parallel_combine(mp0, mp1); RC_ASSERT(comb.runtime == std::max(mp0.runtime, mp1.runtime)); RC_ASSERT(comb.machine_views.size() == mp0.machine_views.size() + mp1.machine_views.size()); - for (auto p : mp0.machine_views) { - RC_ASSERT(p.second == comb.machine_views.at(p.first)); - } + RC_ASSERT(is_submap(comb.machine_views, mp0.machine_views)); + RC_ASSERT(is_submap(comb.machine_views, mp1.machine_views)); }); } -TEST_CASE("MachieMapping::infinity") { +TEST_CASE("MachineMapping::infinity") { rc::check([](MachineMapping const &mp) { RC_ASSERT(mp.runtime <= MachineMapping::infinity().runtime); }); diff --git a/lib/compiler/test/test_optimal_cost.cc b/lib/compiler/test/test_optimal_cost.cc index a1207ede00..7b68433f99 100644 --- a/lib/compiler/test/test_optimal_cost.cc +++ b/lib/compiler/test/test_optimal_cost.cc @@ -2,7 +2,7 @@ #include "test_generator.h" TEST_CASE("optimal_cost") { - rc::check([](ParallelComputationGraph const &g) { + rc::check([](ParallelComputationGraph const &g, MachineSpecification const &machine_spec) { std::unordered_map cached_subgraph_costs; MachineMapping machine_mapping = optimal_cost( g, @@ -10,11 +10,9 @@ TEST_CASE("optimal_cost") { return std::unordered_set{make_1d_machine_view(0, 1, 1)}; }, TestCostEstimator{}, - MachineSpecification{1, 1, 4, 0.1, 0.2}, + machine_spec, cached_subgraph_costs); RC_ASSERT(machine_mapping.runtime > 0); - for (auto node : get_nodes(g)) { - RC_ASSERT(contains_key(machine_mapping.machine_views, node)); - } + RC_ASSERT(keys(machine_mapping.machine_views) == get_nodes(g)); }); } \ No newline at end of file diff --git a/lib/compiler/test/test_unity_algorithm.cc b/lib/compiler/test/test_unity_algorithm.cc index 659a4d0a88..1ecf96b662 100644 --- a/lib/compiler/test/test_unity_algorithm.cc +++ b/lib/compiler/test/test_unity_algorithm.cc @@ -18,8 +18,6 @@ TEST_CASE("graph_optimize") { OptimizerConfig{alpha, budget, threshold, max_num_ops}); RC_ASSERT(get_nodes(s.pcg).size() > 0); RC_ASSERT(s.machine_mapping.runtime > 0); - for (auto node : get_nodes(s.pcg)) { - RC_ASSERT(contains_key(s.machine_mapping.machine_views, node)); - } + RC_ASSERT(keys(s.machine_mapping.machine_views) == get_nodes(s.pcg)); }); } \ No newline at end of file diff --git a/lib/utils/include/utils/containers.h b/lib/utils/include/utils/containers.h index 05ef6ed0e0..366fd7af1d 100644 --- a/lib/utils/include/utils/containers.h +++ b/lib/utils/include/utils/containers.h @@ -206,6 +206,11 @@ std::unordered_map filter_values(std::unordered_map const &m, return result; } +template +bool is_submap(std::unordered_map const &m, std::unordered_map const &sub) { + return restrict_keys(m, keys(sub)) == sub; +} + template std::unordered_set keys(C const &c) { std::unordered_set result; diff --git a/lib/utils/include/utils/graph/cow_ptr_t.h b/lib/utils/include/utils/graph/cow_ptr_t.h index 92ef993cab..3dce6bc8ce 100644 --- a/lib/utils/include/utils/graph/cow_ptr_t.h +++ b/lib/utils/include/utils/graph/cow_ptr_t.h @@ -11,7 +11,7 @@ namespace FlexFlow { template struct cow_ptr_t { // static_assert(is_clonable::value, - // "cow_ptr_t requires the type to have a clone() method"); + // "cow_ptr_t requires the type to have a clone() method"); // TODO: https://github.com/flexflow/FlexFlow/issues/909#issue-1833470024 cow_ptr_t(std::shared_ptr ptr) : ptr(std::move(ptr)) {} cow_ptr_t(std::unique_ptr ptr) : ptr(std::move(ptr)) {} From 2be3f16f5151e91cc59112f14f3a34adc0b4d716 Mon Sep 17 00:00:00 2001 From: wmdi Date: Wed, 2 Aug 2023 12:17:14 -0400 Subject: [PATCH 06/61] format --- lib/compiler/src/machine_mapping.cc | 75 ++++++------ lib/compiler/test/test_generator.h | 15 +-- lib/compiler/test/test_optimal_cost.cc | 3 +- lib/compiler/test/test_unity_algorithm.cc | 2 +- lib/utils/include/utils/bidict.h | 2 +- lib/utils/include/utils/containers.h | 4 +- lib/utils/include/utils/graph/cow_ptr_t.h | 4 +- .../labelled_downward_open_interfaces.h | 3 +- .../utils/graph/labelled/open_algorithms.h | 35 +++++- .../include/utils/graph/labelled/open_views.h | 107 +++++++++++++----- .../utils/graph/open_graph_interfaces.h | 34 +++--- lib/utils/include/utils/graph/query_set.h | 5 +- lib/utils/include/utils/test_types.h | 83 +++++++------- lib/utils/include/utils/variant.h | 12 +- lib/utils/src/graph/algorithms.cc | 20 ++-- 15 files changed, 248 insertions(+), 156 deletions(-) diff --git a/lib/compiler/src/machine_mapping.cc b/lib/compiler/src/machine_mapping.cc index fd24ecdcf0..dd54beda2c 100644 --- a/lib/compiler/src/machine_mapping.cc +++ b/lib/compiler/src/machine_mapping.cc @@ -61,8 +61,7 @@ GraphSplit } std::pair - apply_split(SubParallelComputationGraph const &g, - GraphSplit const &split) { + apply_split(SubParallelComputationGraph const &g, GraphSplit const &split) { OpenMultiDiGraphView g1 = get_subgraph(g, split.first); OpenMultiDiGraphView g2 = get_subgraph(g, split.second); @@ -166,7 +165,7 @@ struct OptimalCost { auto subgraphs = apply_split(g, get_graph_split(pre_decompn, post_decompn)); SubParallelComputationGraph pre_graph = subgraphs.first, - post_graph = subgraphs.second; + post_graph = subgraphs.second; std::unordered_set pre_graph_sinks = get_closed_sinks(pre_graph); std::unordered_set post_graph_sources = @@ -188,23 +187,23 @@ struct OptimalCost { contains(post_graph_sources, split_point) ? make_optional(mv) : nullopt; minimize_runtime(optimal_result, - MachineMapping::sequential_combine( - visit(OptimalCost(pre_graph, - cost_estimator, - resource, - source_machine_view, - pre_sink_mv, - allowed_machine_views, - cached_subgraph_costs), - pre_decompn), - visit(OptimalCost(post_graph, - cost_estimator, - resource, - post_source_mv, - sink_machine_view, - allowed_machine_views, - cached_subgraph_costs), - post_decompn))); + MachineMapping::sequential_combine( + visit(OptimalCost(pre_graph, + cost_estimator, + resource, + source_machine_view, + pre_sink_mv, + allowed_machine_views, + cached_subgraph_costs), + pre_decompn), + visit(OptimalCost(post_graph, + cost_estimator, + resource, + post_source_mv, + sink_machine_view, + allowed_machine_views, + cached_subgraph_costs), + post_decompn))); } return optimal_result; @@ -238,23 +237,23 @@ struct OptimalCost { for (auto const &resource_split : get_resource_split(resource)) { minimize_runtime(optimal_result, - MachineMapping::parallel_combine( - visit(OptimalCost(g1, - cost_estimator, - resource_split.first, - source_machine_view, - sink_machine_view, - allowed_machine_views, - cached_subgraph_costs), - decompn1), - visit(OptimalCost(g2, - cost_estimator, - resource_split.second, - source_machine_view, - sink_machine_view, - allowed_machine_views, - cached_subgraph_costs), - decompn2))); + MachineMapping::parallel_combine( + visit(OptimalCost(g1, + cost_estimator, + resource_split.first, + source_machine_view, + sink_machine_view, + allowed_machine_views, + cached_subgraph_costs), + decompn1), + visit(OptimalCost(g2, + cost_estimator, + resource_split.second, + source_machine_view, + sink_machine_view, + allowed_machine_views, + cached_subgraph_costs), + decompn2))); } return optimal_result; @@ -280,7 +279,7 @@ struct OptimalCost { for (auto mv : allowed_machine_views(g.at(node), resource)) { std::unordered_map mv_map{{node, mv}}; minimize_runtime(optimal_result, - {estimate_cost(g, cost_estimator, mv_map), mv_map}); + {estimate_cost(g, cost_estimator, mv_map), mv_map}); } return optimal_result; } diff --git a/lib/compiler/test/test_generator.h b/lib/compiler/test/test_generator.h index 608cf4576f..83166c63c1 100644 --- a/lib/compiler/test/test_generator.h +++ b/lib/compiler/test/test_generator.h @@ -40,7 +40,7 @@ MultiDiGraph generate_sp_graph(std::vector const &composition) { disjoint_set node_id; // initially we have 2E nodes, and we will merge // them during the iteration std::vector src, - dst; // src and dst nodes for each edge before merging + dst; // src and dst nodes for each edge before merging std::vector srcIdx, dstIdx; // src and dst node ports for each edge (I assume it is sufficient // to make different edges have different NodePort. Correct me if @@ -141,12 +141,13 @@ template <> struct Arbitrary { static Gen arbitrary() { return gen::build( - gen::set(&MachineSpecification::num_nodes, gen::inRange(1, 64)), - gen::set(&MachineSpecification::num_cpus_per_node, gen::inRange(1, 64)), - gen::set(&MachineSpecification::num_gpus_per_node, gen::inRange(1, 16)), - gen::set(&MachineSpecification::inter_node_bandwidth, gen::nonZero()), - gen::set(&MachineSpecification::intra_node_bandwidth, gen::nonZero()) - ); + gen::set(&MachineSpecification::num_nodes, gen::inRange(1, 64)), + gen::set(&MachineSpecification::num_cpus_per_node, gen::inRange(1, 64)), + gen::set(&MachineSpecification::num_gpus_per_node, gen::inRange(1, 16)), + gen::set(&MachineSpecification::inter_node_bandwidth, + gen::nonZero()), + gen::set(&MachineSpecification::intra_node_bandwidth, + gen::nonZero())); } } diff --git a/lib/compiler/test/test_optimal_cost.cc b/lib/compiler/test/test_optimal_cost.cc index 7b68433f99..3a16ae52d9 100644 --- a/lib/compiler/test/test_optimal_cost.cc +++ b/lib/compiler/test/test_optimal_cost.cc @@ -2,7 +2,8 @@ #include "test_generator.h" TEST_CASE("optimal_cost") { - rc::check([](ParallelComputationGraph const &g, MachineSpecification const &machine_spec) { + rc::check([](ParallelComputationGraph const &g, + MachineSpecification const &machine_spec) { std::unordered_map cached_subgraph_costs; MachineMapping machine_mapping = optimal_cost( g, diff --git a/lib/compiler/test/test_unity_algorithm.cc b/lib/compiler/test/test_unity_algorithm.cc index 1ecf96b662..8be65eed94 100644 --- a/lib/compiler/test/test_unity_algorithm.cc +++ b/lib/compiler/test/test_unity_algorithm.cc @@ -19,5 +19,5 @@ TEST_CASE("graph_optimize") { RC_ASSERT(get_nodes(s.pcg).size() > 0); RC_ASSERT(s.machine_mapping.runtime > 0); RC_ASSERT(keys(s.machine_mapping.machine_views) == get_nodes(s.pcg)); - }); + }); } \ No newline at end of file diff --git a/lib/utils/include/utils/bidict.h b/lib/utils/include/utils/bidict.h index c600f40790..797d3c1758 100644 --- a/lib/utils/include/utils/bidict.h +++ b/lib/utils/include/utils/bidict.h @@ -1,8 +1,8 @@ #ifndef _FLEXFLOW_UTILS_BIDICT_H #define _FLEXFLOW_UTILS_BIDICT_H -#include #include +#include namespace FlexFlow { diff --git a/lib/utils/include/utils/containers.h b/lib/utils/include/utils/containers.h index 366fd7af1d..dcbaced8fb 100644 --- a/lib/utils/include/utils/containers.h +++ b/lib/utils/include/utils/containers.h @@ -4,7 +4,6 @@ #include "bidict.h" #include "invoke.h" #include "optional.h" -#include "type_traits_core.h" #include "required_core.h" #include "type_traits_core.h" #include @@ -207,7 +206,8 @@ std::unordered_map filter_values(std::unordered_map const &m, } template -bool is_submap(std::unordered_map const &m, std::unordered_map const &sub) { +bool is_submap(std::unordered_map const &m, + std::unordered_map const &sub) { return restrict_keys(m, keys(sub)) == sub; } diff --git a/lib/utils/include/utils/graph/cow_ptr_t.h b/lib/utils/include/utils/graph/cow_ptr_t.h index 3dce6bc8ce..7b888cf6e9 100644 --- a/lib/utils/include/utils/graph/cow_ptr_t.h +++ b/lib/utils/include/utils/graph/cow_ptr_t.h @@ -11,7 +11,9 @@ namespace FlexFlow { template struct cow_ptr_t { // static_assert(is_clonable::value, - // "cow_ptr_t requires the type to have a clone() method"); // TODO: https://github.com/flexflow/FlexFlow/issues/909#issue-1833470024 + // "cow_ptr_t requires the type to have a clone() method"); // + // TODO: + // https://github.com/flexflow/FlexFlow/issues/909#issue-1833470024 cow_ptr_t(std::shared_ptr ptr) : ptr(std::move(ptr)) {} cow_ptr_t(std::unique_ptr ptr) : ptr(std::move(ptr)) {} diff --git a/lib/utils/include/utils/graph/labelled/labelled_downward_open_interfaces.h b/lib/utils/include/utils/graph/labelled/labelled_downward_open_interfaces.h index e77787294d..98b579df13 100644 --- a/lib/utils/include/utils/graph/labelled/labelled_downward_open_interfaces.h +++ b/lib/utils/include/utils/graph/labelled/labelled_downward_open_interfaces.h @@ -14,7 +14,8 @@ struct ILabelledDownwardOpenMultiDiGraphView public IDownwardOpenMultiDiGraphView { virtual ~ILabelledDownwardOpenMultiDiGraphView() = default; - std::unordered_set query_edges(MultiDiEdgeQuery const &q) const final { + std::unordered_set + query_edges(MultiDiEdgeQuery const &q) const final { return this->query_edges(static_cast(q)); } diff --git a/lib/utils/include/utils/graph/labelled/open_algorithms.h b/lib/utils/include/utils/graph/labelled/open_algorithms.h index 89f8674373..4b54aeeece 100644 --- a/lib/utils/include/utils/graph/labelled/open_algorithms.h +++ b/lib/utils/include/utils/graph/labelled/open_algorithms.h @@ -150,16 +150,39 @@ ResultType get_subgraph(LabelledOpenMultiDiGraph(as_view(g), nodes); } -template +template LabelledUpwardOpenMultiDiGraphView -as_upward_open(LabelledOpenMultiDiGraphView const &g) { - return LabelledUpwardOpenMultiDiGraphView::template create>(g); + as_upward_open(LabelledOpenMultiDiGraphView const &g) { + return LabelledUpwardOpenMultiDiGraphView:: + template create>(g); } -template +template LabelledDownwardOpenMultiDiGraphView -as_downward_open(LabelledOpenMultiDiGraphView const &g) { - return LabelledDownwardOpenMultiDiGraphView::template create>(g); + as_downward_open(LabelledOpenMultiDiGraphView const &g) { + return LabelledDownwardOpenMultiDiGraphView:: + template create>( + g); } } // namespace FlexFlow diff --git a/lib/utils/include/utils/graph/labelled/open_views.h b/lib/utils/include/utils/graph/labelled/open_views.h index 4243e6ba43..4c58af9386 100644 --- a/lib/utils/include/utils/graph/labelled/open_views.h +++ b/lib/utils/include/utils/graph/labelled/open_views.h @@ -9,12 +9,12 @@ #include "labelled_upward_open_interfaces.h" #include "standard_labelled.h" #include "utils/exception.h" +#include "utils/graph/algorithms.h" #include "utils/graph/multidiedge.h" #include "utils/graph/open_graph_interfaces.h" #include "utils/graph/open_graphs.h" #include "utils/type_traits.h" #include "utils/visitable.h" -#include "utils/graph/algorithms.h" namespace FlexFlow { @@ -51,48 +51,97 @@ struct LabelledDownwardMultiDiSubgraphView { std::unordered_set const &); }; -template -struct ViewLabelledOpenMultiDiGraphAsUpwardOpen : public ILabelledUpwardOpenMultiDiGraphView { +template +struct ViewLabelledOpenMultiDiGraphAsUpwardOpen + : public ILabelledUpwardOpenMultiDiGraphView { public: - using InputType = LabelledOpenMultiDiGraphView; + using InputType = LabelledOpenMultiDiGraphView; - explicit ViewLabelledOpenMultiDiGraphAsUpwardOpen(InputType const &g) : g(g) { } + explicit ViewLabelledOpenMultiDiGraphAsUpwardOpen(InputType const &g) + : g(g) {} - std::unordered_set query_nodes(NodeQuery const &q) const override { return this->g.query_nodes(q); } + std::unordered_set query_nodes(NodeQuery const &q) const override { + return this->g.query_nodes(q); + } - std::unordered_set query_edges(UpwardOpenMultiDiEdgeQuery const &q) const override { - return value_all(narrow(this->g.query_edges(q))); + std::unordered_set + query_edges(UpwardOpenMultiDiEdgeQuery const &q) const override { + return value_all( + narrow(this->g.query_edges(q))); + } + + NodeLabel const &at(Node const &n) const override { + return this->g.at(n); + } + InputLabel const &at(InputMultiDiEdge const &e) const override { + return this->g.at(e); + } + EdgeLabel const &at(MultiDiEdge const &e) const override { + return this->g.at(e); } - NodeLabel const &at(Node const &n) const override { return this->g.at(n); } - InputLabel const &at(InputMultiDiEdge const &e) const override { return this->g.at(e); } - EdgeLabel const &at(MultiDiEdge const &e) const override { return this->g.at(e); } private: InputType g; }; -CHECK_NOT_ABSTRACT(ViewLabelledOpenMultiDiGraphAsUpwardOpen); +CHECK_NOT_ABSTRACT( + ViewLabelledOpenMultiDiGraphAsUpwardOpen); -template -struct ViewLabelledOpenMultiDiGraphAsDownwardOpen : public ILabelledDownwardOpenMultiDiGraphView { +template +struct ViewLabelledOpenMultiDiGraphAsDownwardOpen + : public ILabelledDownwardOpenMultiDiGraphView { public: - using InputType = LabelledOpenMultiDiGraphView; + using InputType = LabelledOpenMultiDiGraphView; - explicit ViewLabelledOpenMultiDiGraphAsDownwardOpen(InputType const &g) : g(g) { } + explicit ViewLabelledOpenMultiDiGraphAsDownwardOpen(InputType const &g) + : g(g) {} - std::unordered_set query_nodes(NodeQuery const &q) const override { return this->g.query_nodes(q); } + std::unordered_set query_nodes(NodeQuery const &q) const override { + return this->g.query_nodes(q); + } - std::unordered_set query_edges(DownwardOpenMultiDiEdgeQuery const &q) const override { - return value_all(narrow(this->g.query_edges(q))); + std::unordered_set + query_edges(DownwardOpenMultiDiEdgeQuery const &q) const override { + return value_all( + narrow(this->g.query_edges(q))); } - NodeLabel const &at(Node const &n) const override { return this->g.at(n); } - OutputLabel const &at(OutputMultiDiEdge const &e) const override { return this->g.at(e); } - EdgeLabel const &at(MultiDiEdge const &e) const override { return this->g.at(e); } + NodeLabel const &at(Node const &n) const override { + return this->g.at(n); + } + OutputLabel const &at(OutputMultiDiEdge const &e) const override { + return this->g.at(e); + } + EdgeLabel const &at(MultiDiEdge const &e) const override { + return this->g.at(e); + } private: InputType g; }; -CHECK_NOT_ABSTRACT(ViewLabelledOpenMultiDiGraphAsUpwardOpen); +CHECK_NOT_ABSTRACT( + ViewLabelledOpenMultiDiGraphAsUpwardOpen); template ; - LabelledOpenMultiDiSubgraphView(LabelledOpenMultiDiGraphView const &g, std::unordered_set const &nodes) - : g(g), nodes(nodes) - { } + LabelledOpenMultiDiSubgraphView( + LabelledOpenMultiDiGraphView const &g, + std::unordered_set const &nodes) + : g(g), nodes(nodes) {} std::unordered_set query_edges(UpwardOpenMultiDiEdgeQuery const &q) const override { @@ -186,7 +239,7 @@ struct LabelledOpenMultiDiSubgraphView query_nodes(NodeQuery const &q) const override { return static_cast(this->g).query_nodes(q); - + NOT_IMPLEMENTED(); } diff --git a/lib/utils/include/utils/graph/open_graph_interfaces.h b/lib/utils/include/utils/graph/open_graph_interfaces.h index 6d05c21f37..3d7e66119d 100644 --- a/lib/utils/include/utils/graph/open_graph_interfaces.h +++ b/lib/utils/include/utils/graph/open_graph_interfaces.h @@ -88,12 +88,15 @@ FF_VISITABLE_STRUCT(OpenMultiDiEdgeQuery, struct DownwardOpenMultiDiEdgeQuery { DownwardOpenMultiDiEdgeQuery() = delete; DownwardOpenMultiDiEdgeQuery(OutputMultiDiEdgeQuery const &output_edge_query, - MultiDiEdgeQuery const &standard_edge_query) - : output_edge_query(output_edge_query), standard_edge_query(standard_edge_query) { } - DownwardOpenMultiDiEdgeQuery(OutputMultiDiEdgeQuery const &output_edge_query) - : DownwardOpenMultiDiEdgeQuery(output_edge_query, MultiDiEdgeQuery::none()) { } - DownwardOpenMultiDiEdgeQuery(MultiDiEdgeQuery const &standard_edge_query) - : DownwardOpenMultiDiEdgeQuery(OutputMultiDiEdgeQuery::all(), standard_edge_query) { }; + MultiDiEdgeQuery const &standard_edge_query) + : output_edge_query(output_edge_query), + standard_edge_query(standard_edge_query) {} + DownwardOpenMultiDiEdgeQuery(OutputMultiDiEdgeQuery const &output_edge_query) + : DownwardOpenMultiDiEdgeQuery(output_edge_query, + MultiDiEdgeQuery::none()) {} + DownwardOpenMultiDiEdgeQuery(MultiDiEdgeQuery const &standard_edge_query) + : DownwardOpenMultiDiEdgeQuery(OutputMultiDiEdgeQuery::all(), + standard_edge_query){}; operator OpenMultiDiEdgeQuery() const { NOT_IMPLEMENTED(); @@ -103,8 +106,8 @@ struct DownwardOpenMultiDiEdgeQuery { MultiDiEdgeQuery standard_edge_query; }; FF_VISITABLE_STRUCT_NONSTANDARD_CONSTRUCTION(DownwardOpenMultiDiEdgeQuery, - output_edge_query, - standard_edge_query); + output_edge_query, + standard_edge_query); struct UpwardOpenMultiDiEdgeQuery { UpwardOpenMultiDiEdgeQuery() = delete; @@ -133,10 +136,11 @@ struct IDownwardOpenMultiDiGraphView : public IOpenMultiDiGraphView { virtual std::unordered_set query_edges(DownwardOpenMultiDiEdgeQuery const &) const = 0; - std::unordered_set query_edges(OpenMultiDiEdgeQuery const &q) const final { + std::unordered_set + query_edges(OpenMultiDiEdgeQuery const &q) const final { return widen( - this->query_edges(DownwardOpenMultiDiEdgeQuery{ q.output_edge_query, q.standard_edge_query }) - ); + this->query_edges(DownwardOpenMultiDiEdgeQuery{q.output_edge_query, + q.standard_edge_query})); } }; CHECK_RC_COPY_VIRTUAL_COMPLIANT(IDownwardOpenMultiDiGraphView); @@ -145,10 +149,10 @@ struct IUpwardOpenMultiDiGraphView : public IOpenMultiDiGraphView { virtual std::unordered_set query_edges(UpwardOpenMultiDiEdgeQuery const &) const = 0; - std::unordered_set query_edges(OpenMultiDiEdgeQuery const &q) const final { - return widen( - this->query_edges(UpwardOpenMultiDiEdgeQuery{ q.input_edge_query, q.standard_edge_query }) - ); + std::unordered_set + query_edges(OpenMultiDiEdgeQuery const &q) const final { + return widen(this->query_edges( + UpwardOpenMultiDiEdgeQuery{q.input_edge_query, q.standard_edge_query})); } }; CHECK_RC_COPY_VIRTUAL_COMPLIANT(IUpwardOpenMultiDiGraphView); diff --git a/lib/utils/include/utils/graph/query_set.h b/lib/utils/include/utils/graph/query_set.h index 656d64d531..e7f4f60812 100644 --- a/lib/utils/include/utils/graph/query_set.h +++ b/lib/utils/include/utils/graph/query_set.h @@ -20,9 +20,8 @@ struct query_set { query_set(optional> const &) { NOT_IMPLEMENTED(); } - query_set(std::initializer_list const &l) - : query_set(std::unordered_set{l}) - { } + query_set(std::initializer_list const &l) + : query_set(std::unordered_set{l}) {} friend bool operator==(query_set const &lhs, query_set const &rhs) { return lhs.value == rhs.value; diff --git a/lib/utils/include/utils/test_types.h b/lib/utils/include/utils/test_types.h index 2b7876496b..4fafb4a8fb 100644 --- a/lib/utils/include/utils/test_types.h +++ b/lib/utils/include/utils/test_types.h @@ -7,68 +7,66 @@ namespace FlexFlow { namespace test_types { -enum capability { - HASHABLE, - EQ, - CMP, - DEFAULT_CONSTRUCTIBLE, - COPYABLE -}; +enum capability { HASHABLE, EQ, CMP, DEFAULT_CONSTRUCTIBLE, COPYABLE }; -template -struct capability_implies : std::false_type { }; +template +struct capability_implies : std::false_type {}; template <> -struct capability_implies : std::true_type { }; +struct capability_implies : std::true_type {}; template -struct capability_implies : std::true_type { }; - +struct capability_implies : std::true_type {}; -template struct has_capability; +template +struct has_capability; -template -struct has_capability : disjunction< - capability_implies, - has_capability> { }; +template +struct has_capability + : disjunction, + has_capability> {}; template -struct has_capability : std::false_type { }; +struct has_capability : std::false_type {}; -template +template struct test_type_t { - template + template using supports = conjunction...>; - template::value, bool>::type = true> + template ::value, + bool>::type = true> test_type_t(); - template::value, bool>::type = true> + template ::value, + bool>::type = true> test_type_t() = delete; - template::value, bool>::type = true> + template < + typename std::enable_if::value, bool>::type = true> test_type_t(test_type_t const &); - template::value, bool>::type = true> + template < + typename std::enable_if::value, bool>::type = true> test_type_t(test_type_t const &) = delete; typename std::enable_if::value, bool>::type - operator==(test_type_t const &) const; + operator==(test_type_t const &) const; typename std::enable_if::value, bool>::type - operator!=(test_type_t const &) const; + operator!=(test_type_t const &) const; typename std::enable_if::value, bool>::type - operator<(test_type_t const &) const; + operator<(test_type_t const &) const; typename std::enable_if::value, bool>::type - operator>(test_type_t const &) const; + operator>(test_type_t const &) const; typename std::enable_if::value, bool>::type - operator<=(test_type_t const &) const; + operator<=(test_type_t const &) const; typename std::enable_if::value, bool>::type - operator>=(test_type_t const &) const; + operator>=(test_type_t const &) const; }; using no_eq = test_type_t<>; @@ -76,20 +74,27 @@ using eq = test_type_t; using cmp = test_type_t; using hash_cmp = test_type_t; -} -} +} // namespace test_types +} // namespace FlexFlow namespace std { -template <::FlexFlow::test_types::capability ...CAPABILITIES> //, typename = typename std::enable_if<::FlexFlow::test_types::has_capability<::FlexFlow::test_types::HASHABLE>::value, bool>::type> -struct hash<::FlexFlow::test_types::test_type_t> { +template < + ::FlexFlow::test_types:: + capability... CAPABILITIES> //, typename = typename + //std::enable_if<::FlexFlow::test_types::has_capability<::FlexFlow::test_types::HASHABLE>::value, + //bool>::type> + struct hash< + ::FlexFlow::test_types::test_type_t< + CAPABILITIES...>> { typename std::enable_if< - ::FlexFlow::test_types::has_capability<::FlexFlow::test_types::HASHABLE, CAPABILITIES...>::value, - size_t - >::type - operator()(::FlexFlow::test_types::test_type_t const &) const; + ::FlexFlow::test_types::has_capability<::FlexFlow::test_types::HASHABLE, + CAPABILITIES...>::value, + size_t>::type + operator()( + ::FlexFlow::test_types::test_type_t const &) const; }; -} +} // namespace std #endif diff --git a/lib/utils/include/utils/variant.h b/lib/utils/include/utils/variant.h index 61758c79d9..a03a781c75 100644 --- a/lib/utils/include/utils/variant.h +++ b/lib/utils/include/utils/variant.h @@ -174,9 +174,8 @@ template < typename Container, typename VariantIn = typename Container::value_type, typename = std::enable_if::value>> -auto widen(Container const &c) - -> decltype(transform(c, std::declval>())) -{ +auto widen(Container const &c) -> decltype(transform( + c, std::declval>())) { return transform(c, [](VariantIn const &i) { return widen(i); }); } @@ -193,13 +192,12 @@ template < typename Container, typename VariantIn = typename Container::value_type, typename = std::enable_if::value>> -auto narrow(Container const &c) - -> decltype(transform(c, std::declval(VariantIn const &)>>())) -{ +auto narrow(Container const &c) -> decltype(transform( + c, + std::declval(VariantIn const &)>>())) { return transform(c, [](VariantIn const &i) { return narrow(i); }); } - template get_cut(OpenMultiDiGraphView const &g, return keys(get_edge_splits(g, s)); } -Node get_src_node(MultiDiEdge const &) { NOT_IMPLEMENTED(); } -Node get_dst_node(MultiDiEdge const &) { NOT_IMPLEMENTED(); } -Node get_src_node(InputMultiDiEdge const &) { NOT_IMPLEMENTED(); } -Node get_dst_node(OutputMultiDiEdge const &) { NOT_IMPLEMENTED(); } +Node get_src_node(MultiDiEdge const &) { + NOT_IMPLEMENTED(); +} +Node get_dst_node(MultiDiEdge const &) { + NOT_IMPLEMENTED(); +} +Node get_src_node(InputMultiDiEdge const &) { + NOT_IMPLEMENTED(); +} +Node get_dst_node(OutputMultiDiEdge const &) { + NOT_IMPLEMENTED(); +} UndirectedGraphView get_subgraph(UndirectedGraphView const &g, std::unordered_set const &nodes) { From 4c7e56eee94e813ccc9afc7e3bc4b919728aeb53 Mon Sep 17 00:00:00 2001 From: wmdi Date: Wed, 2 Aug 2023 12:24:16 -0400 Subject: [PATCH 07/61] format --- lib/utils/include/utils/test_types.h | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/lib/utils/include/utils/test_types.h b/lib/utils/include/utils/test_types.h index 4fafb4a8fb..514d030b6f 100644 --- a/lib/utils/include/utils/test_types.h +++ b/lib/utils/include/utils/test_types.h @@ -82,11 +82,9 @@ namespace std { template < ::FlexFlow::test_types:: capability... CAPABILITIES> //, typename = typename - //std::enable_if<::FlexFlow::test_types::has_capability<::FlexFlow::test_types::HASHABLE>::value, - //bool>::type> - struct hash< - ::FlexFlow::test_types::test_type_t< - CAPABILITIES...>> { + // std::enable_if<::FlexFlow::test_types::has_capability<::FlexFlow::test_types::HASHABLE>::value, + // bool>::type> +struct hash<::FlexFlow::test_types::test_type_t> { typename std::enable_if< ::FlexFlow::test_types::has_capability<::FlexFlow::test_types::HASHABLE, CAPABILITIES...>::value, From edb1c588cab76ffe31147523c4676f77d1332702 Mon Sep 17 00:00:00 2001 From: wmdi Date: Wed, 9 Aug 2023 18:51:52 -0400 Subject: [PATCH 08/61] serial parallel composition --- lib/compiler/test/test_generator.h | 211 ++++++++++++++++--------- lib/compiler/test/test_optimal_cost.cc | 21 ++- 2 files changed, 148 insertions(+), 84 deletions(-) diff --git a/lib/compiler/test/test_generator.h b/lib/compiler/test/test_generator.h index 83166c63c1..c258ad46d0 100644 --- a/lib/compiler/test/test_generator.h +++ b/lib/compiler/test/test_generator.h @@ -7,82 +7,93 @@ using namespace FlexFlow; -enum class CompnType { SERIAL, PARALLEL }; - -struct Compn { - CompnType type; - int component1, component2; -}; +std::unordered_map parallel_extend(MultiDiGraph &g, + MultiDiGraph const &ext) { + std::unordered_map node_map; + std::unordered_map node_port_map; + for (Node const &node : get_nodes(ext)) { + node_map[node] = g.add_node(); + } + for (NodePort const &node_port : get_node_ports(ext)) { + node_port_map[node_port] = g.add_node_port(); + } + for (MultiDiEdge const &edge : get_edges(ext)) { + g.add_edge(MultiDiEdge{node_map[edge.src], + node_map[edge.dst], + node_map[edge.srcIdx], + node_map[edge.dstIdx]}); + } + return node_map; +} -int pop_component(std::set const &components, int value) { - value = value % components.size(); - auto it = components.begin(); - while (value--) { - it++; - } - int component = *it; - components.erase(it); - return component; -} - -/* - Generates a series-parallel graph according to the composition sequence - described by `composition`. A series-parallel graph can be generated as - follows: 1) Initially, we have E (E is the length of `composition`+1) - components, each containing a single edge; 2) In iteration `i`, we compose two - components (`composition[i].component1` and `composition[i].component2`): 2.1) - If `composition[i].type == SERIAL`, we merge the sink node of component1 and - the source node of component2; 2.2) If `composition[i].type == PARALLEL`, we - merge the source nodes and the sink nodes of two components. -*/ -MultiDiGraph generate_sp_graph(std::vector const &composition) { - std::set components; - disjoint_set node_id; // initially we have 2E nodes, and we will merge - // them during the iteration - std::vector src, - dst; // src and dst nodes for each edge before merging - std::vector srcIdx, - dstIdx; // src and dst node ports for each edge (I assume it is sufficient - // to make different edges have different NodePort. Correct me if - // I am wrong. @lockshaw) - AdjacencyMultiDiGraph g(0, 0, {}); - for (int i = 0; i <= composition.size(); ++i) { - components.insert(i); - src.push_back(g.add_node()); - dst.push_back(g.add_node()); - srcIdx.push_back(g.add_node_port()); - dstIdx.push_back(g.add_node_port()); - } - std::vector source_node = src, - sink_node = - dst; // initially each component has a single edge - - // We compute the src and dst nodes after merging for each edge before - // actually inserting the edges. - - for (Compn const &compn : composition) { - int c1 = pop_component(components, compn.component1); - int c2 = pop_component(components, compn.component2); - components.insert(c1); - if (compn.type == CompnType::SERIAL) { - node_id.m_union(sink_node[c1], source_node[c2]); - sink_node[c1] = sink_node[c2]; - } else { - node_id.m_union(source_node[c1], source_node[c2]); - node_id.m_union(sink_node[c1], sink_node[c2]); +std::unordered_map serial_extend(MultiDiGraph &g, + MultiDiGraph const &ext) { + std::unordered_set original_sinks = get_sinks(g); + std::unordered_map node_map = parallel_extend(g, ext); + for (Node const &node1 : original_sinks) { + for (Node const &node2 : get_sources(ext)) { + g.add_edge(MultiDiEdge{ + node1, node_map[node2], g.add_node_port(), g.add_node_port()}); } } + return node_map; +} - for (Node node : get_nodes(g)) { - if (node_id.find(node) != node) { - g.remove_node_unsafe(node); - } +MultiDiGraph serial_composition(MultiDiGraph const &g1, + MultiDiGraph const &g2) { + MultiDiGraph g = g1; + serial_extend(g, g2); + return g; +} + +MultiDiGraph parallel_composition(MultiDiGraph const &g1, + MultiDiGraph const &g2) { + MultiDiGraph g = g1; + parallel_extend(g, g2); + return g; +} + +struct MultiDiGraphFromSPDecomposition { + template + MultiDiGraph operator()(T const &t) { + return multidigraph_from_sp_decomposition(t); } +}; + +MultiDiGraph multidigraph_from_sp_decomposition( + SerialParallelDecomposition const &sp_decomposition) { + return visit(MultiDiGraphFromSPDecomposition{}, sp_decomposition); +} + +MultiDiGraph multidigraph_from_sp_decomposition( + variant const &sp_decomposition) { + return visit(MultiDiGraphFromSPDecomposition{}, sp_decomposition); +} + +MultiDiGraph multidigraph_from_sp_decomposition( + variant const &sp_decomposition) { + return visit(MultiDiGraphFromSPDecomposition{}, sp_decomposition); +} - for (int i = 0; i < src.size(); ++i) { - g.add_edge(MultiDiEdge{src[i], dst[i], srcIdx[i], dstIdx[i]}); +MultiDiGraph multidigraph_from_sp_decomposition(Serial const &serial) { + MultiDiGraph g = MultiDiGraph::create(); + for (auto child : serial.children) { + serial_extend(g, multidigraph_from_sp_decomposition(child)); } + return g; +} + +MultiDiGraph multidigraph_from_sp_decomposition(Parallel const ¶llel) { + MultiDiGraph g = MultiDiGraph::create(); + for (auto child : parallel.children) { + parallel_extend(g, multidigraph_from_sp_decomposition(child)); + } + return g; +} +MultiDiGraph multidigraph_from_sp_decomposition(Node const &Node) { + MultiDiGraph g = MultiDiGraph::create(); + g.add_node(); return g; } @@ -100,8 +111,63 @@ rc::Gen small_integer_generator() { namespace rc { +Gen serialParallelMultiDiGraph() { + return gen::map(gen::arbitrary(), + multidigraph_from_sp_decomposition); +} + +template <> +struct Arbitrary> { + static Gen> arbitrary() { + return gen::mapcat(gen::arbitrary(), [](bool is_node) { + return is_node ? gen::arbitrary() : gen::arbitrary(); + }); + } +}; + +template <> +struct Arbitrary> { + static Gen> arbitrary() { + return gen::mapcat(gen::arbitrary(), [](bool is_node) { + return is_node ? gen::arbitrary() : gen::arbitrary(); + }); + } +}; + +template <> +struct Arbitrary { + static Gen arbitrary() { + return gen::build( + gen::set(&Serial::children, + gen::container>>( + gen::arbitrary>()))); + } +}; + +template <> +struct Arbitrary { + static Gen arbitrary() { + return gen::build( + gen::set(&Parallel::children, + gen::container>>( + gen::arbitrary>()))); + } +}; + +template <> +struct Arbitrary { + static Gen arbitrary() { + return gen::mapcat(gen::arbitrary(), [](bool is_serial) { + return is_serial ? gen::construct( + gen::arbitrary()) + : gen::construct( + gen::arbitrary()); + }); + } +}; + template -struct Arbtrary { +struct Arbitrary { static Gen< std::enable_if, Tag>::value>::type> arbitrary() { @@ -109,13 +175,6 @@ struct Arbtrary { } }; -template <> -struct Arbitrary { - static Gen arbitrary() { - return gen::construct(gen::arbitrary()); - } -}; - template <> struct Arbitrary { static Gen arbitrary() { diff --git a/lib/compiler/test/test_optimal_cost.cc b/lib/compiler/test/test_optimal_cost.cc index 3a16ae52d9..8993eb4228 100644 --- a/lib/compiler/test/test_optimal_cost.cc +++ b/lib/compiler/test/test_optimal_cost.cc @@ -1,18 +1,23 @@ #include "test_cost_estimator.h" #include "test_generator.h" +/* +Tests whether optimal_cost can give a valid result given random PCG, trivial +allowed machine views, trivial cost estimator and random machine specification. +*/ TEST_CASE("optimal_cost") { + auto test_allowed_machine_views = [](Operator const &, + MachineSpecification const &) { + return std::unordered_set{make_1d_machine_view(0, 1, 1)}; + }; rc::check([](ParallelComputationGraph const &g, MachineSpecification const &machine_spec) { std::unordered_map cached_subgraph_costs; - MachineMapping machine_mapping = optimal_cost( - g, - [](Operator const &, MachineSpecification const &) { - return std::unordered_set{make_1d_machine_view(0, 1, 1)}; - }, - TestCostEstimator{}, - machine_spec, - cached_subgraph_costs); + MachineMapping machine_mapping = optimal_cost(g, + test_allowed_machine_views, + TestCostEstimator{}, + machine_spec, + cached_subgraph_costs); RC_ASSERT(machine_mapping.runtime > 0); RC_ASSERT(keys(machine_mapping.machine_views) == get_nodes(g)); }); From 70e2b49209865eb94d6e44642694648a3a97667b Mon Sep 17 00:00:00 2001 From: wmdi Date: Wed, 9 Aug 2023 21:14:25 -0400 Subject: [PATCH 09/61] remove commited out codes --- lib/utils/include/utils/test_types.h | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/lib/utils/include/utils/test_types.h b/lib/utils/include/utils/test_types.h index 514d030b6f..6002d763a6 100644 --- a/lib/utils/include/utils/test_types.h +++ b/lib/utils/include/utils/test_types.h @@ -79,11 +79,7 @@ using hash_cmp = test_type_t; namespace std { -template < - ::FlexFlow::test_types:: - capability... CAPABILITIES> //, typename = typename - // std::enable_if<::FlexFlow::test_types::has_capability<::FlexFlow::test_types::HASHABLE>::value, - // bool>::type> +template <::FlexFlow::test_types::capability... CAPABILITIES> struct hash<::FlexFlow::test_types::test_type_t> { typename std::enable_if< ::FlexFlow::test_types::has_capability<::FlexFlow::test_types::HASHABLE, From 6283528af178c0ac09b657618af2ab47402c921b Mon Sep 17 00:00:00 2001 From: wmdi Date: Thu, 10 Aug 2023 21:18:55 -0400 Subject: [PATCH 10/61] view MultiDiGraph as labelled --- .../include/utils/graph/labelled/views.h | 35 +++++++++++++++++++ 1 file changed, 35 insertions(+) diff --git a/lib/utils/include/utils/graph/labelled/views.h b/lib/utils/include/utils/graph/labelled/views.h index b0108c6e5d..b034454312 100644 --- a/lib/utils/include/utils/graph/labelled/views.h +++ b/lib/utils/include/utils/graph/labelled/views.h @@ -21,6 +21,41 @@ struct LabelledMultiDiSubgraphView std::unordered_set const &); }; +template +struct ViewMultiDiGraphAsOutputLabelled + : public IOutputLabelledMultiDiGraphView { +public: + ViewMultiDiGraphAsOutputLabelled() = delete; + explicit ViewMultiDiGraphAsOutputLabelled( + MultiDiGraphView const &g, + std::function const &node_label, + std::function const &output_label) + : g(g), node_label(node_label), output_label(output_label) {} + + virtual std::unordered_set query_nodes(NodeQuery const &q) const override { + return g.query_nodes(q); + } + + virtual std::unordered_set query_edges(MultiDiEdgeQuery const &q) const override { + return g.query_edges(q); + } + + virtual NodeLabel const &at(Node const &n) const override { + return node_label(n); + } + + virtual OutputLabel &at(MultiDiOutput const &o) override { + return output_label(o); + } + +private: + MultiDiGraphView g; + std::function node_label; + std::function output_label; +}; + +CHECK_NOT_ABSTRACT(ViewMultiDiGraphAsOutputLabelled); + } // namespace FlexFlow #endif From a24009244ff7cce541a018aceb593165f07d5569 Mon Sep 17 00:00:00 2001 From: wmdi Date: Thu, 10 Aug 2023 22:48:57 -0400 Subject: [PATCH 11/61] make machine mapping immutable --- .../include/compiler/machine_mapping.h | 47 +++++-- .../include/compiler/unity_algorithm.h | 3 +- lib/compiler/src/machine_mapping.cc | 120 +++++++++--------- lib/compiler/src/unity_algorithm.cc | 22 ++-- lib/compiler/test/test_generator.h | 3 +- 5 files changed, 109 insertions(+), 86 deletions(-) diff --git a/lib/compiler/include/compiler/machine_mapping.h b/lib/compiler/include/compiler/machine_mapping.h index aeec1362dd..c7beb2925c 100644 --- a/lib/compiler/include/compiler/machine_mapping.h +++ b/lib/compiler/include/compiler/machine_mapping.h @@ -5,33 +5,58 @@ #include "pcg/machine_specification.h" #include "pcg/machine_view.h" #include "pcg/parallel_computation_graph.h" +#include "sub_parallel_computation_graph.h" namespace FlexFlow { struct MachineMapping { - static MachineMapping sequential_combine(MachineMapping const &s1, - MachineMapping const &s2); - static MachineMapping parallel_combine(MachineMapping const &s1, - MachineMapping const &s2); - static MachineMapping infinity(); + static MachineMapping combine(MachineMapping const &, MachineMapping const &); - float runtime; req> machine_views; }; -FF_VISITABLE_STRUCT(MachineMapping, runtime, machine_views); +FF_VISITABLE_STRUCT(MachineMapping, machine_views); + +struct OptimalCostState { + SerialParallelDecomposition subgraph; + MachineSpecification resource; + req> source_machine_view, sink_machine_view; +}; +FF_VISITABLE_STRUCT(OptimalCostState, subgraph, resource, source_machine_view, sink_machine_view); + +struct OptimalCostResult { + static OptimalCostResult sequential_combine(OptimalCostResult const &s1, + OptimalCostResult const &s2); + static OptimalCostResult parallel_combine(OptimalCostResult const &s1, + OptimalCostResult const &s2); + static OptimalCostResult infinity(); + + float runtime; + MachineMapping machine_mapping; +}; +FF_VISITABLE_STRUCT(OptimalCostResult, runtime, machine_mapping); + +struct OptimalCostRuntimeCmp { + bool operator()(OptimalCostResult const &, OptimalCostResult const &); +}; + +class OptimalCostCache { +public: + OptimalCostCache() = default; -struct MachineMappingRuntimeCmp { - bool operator()(MachineMapping const &, MachineMapping const &); + optional load(OptimalCostState const &) const; + void save(OptimalCostState const &, OptimalCostResult const &); +private: + std::unordered_map cache; }; -MachineMapping optimal_cost( +OptimalCostResult optimal_cost( ParallelComputationGraph const &g, std::function( Operator const &, MachineSpecification const &)> const &allowed_machine_views, CostEstimator const &cost_estimator, MachineSpecification const &resources, - std::unordered_map &cached_subgraph_costs); + OptimalCostCache &cached_subgraph_costs); } // namespace FlexFlow diff --git a/lib/compiler/include/compiler/unity_algorithm.h b/lib/compiler/include/compiler/unity_algorithm.h index e8c67c38ae..57f1c8c063 100644 --- a/lib/compiler/include/compiler/unity_algorithm.h +++ b/lib/compiler/include/compiler/unity_algorithm.h @@ -13,8 +13,9 @@ struct Substitution {}; struct Strategy { ParallelComputationGraph pcg; MachineMapping machine_mapping; + req runtime; }; -FF_VISITABLE_STRUCT(Strategy, pcg, machine_mapping); +FF_VISITABLE_STRUCT(Strategy, pcg, machine_mapping, runtime); struct StrategyRuntimeCmp { bool operator()(Strategy const &, Strategy const &); diff --git a/lib/compiler/src/machine_mapping.cc b/lib/compiler/src/machine_mapping.cc index dd54beda2c..945794d38c 100644 --- a/lib/compiler/src/machine_mapping.cc +++ b/lib/compiler/src/machine_mapping.cc @@ -7,28 +7,50 @@ namespace FlexFlow { -MachineMapping MachineMapping::sequential_combine(MachineMapping const &s1, - MachineMapping const &s2) { - return {s1.runtime + s2.runtime, - merge_maps(s1.machine_views, s2.machine_views)}; +MachineMapping MachineMapping::combine(MachineMapping const &s1, + MachineMapping const &s2) { + return MachineMapping{merge_maps(s1.machine_views, s2.machine_views)}; } -MachineMapping MachineMapping::parallel_combine(MachineMapping const &s1, - MachineMapping const &s2) { - return {std::max(s1.runtime, s2.runtime), - merge_maps(s1.machine_views, s2.machine_views)}; +OptimalCostResult + OptimalCostResult::sequential_combine(OptimalCostResult const &s1, + OptimalCostResult const &s2) { + return OptimalCostResult{ + s1.runtime + s2.runtime, + MachineMapping::combine(s1.machine_mapping, s2.machine_mapping)}; } -MachineMapping MachineMapping::infinity() { - return {std::numeric_limits::infinity(), - std::unordered_map{}}; +OptimalCostResult + OptimalCostResult::parallel_combine(OptimalCostResult const &s1, + OptimalCostResult const &s2) { + return OptimalCostResult{ + std::max(s1.runtime, s2.runtime), + MachineMapping::combine(s1.machine_mapping, s2.machine_mapping)}; } -bool MachineMappingRuntimeCmp::operator()(MachineMapping const &lhs, - MachineMapping const &rhs) { +OptimalCostResult OptimalCostResult::infinity() { + return {std::numeric_limits::infinity(), MachineMapping{std::unordered_map{}}}; +} + +bool OptimalCostRuntimeCmp::operator()(OptimalCostResult const &lhs, + OptimalCostResult const &rhs) { return lhs.runtime < rhs.runtime; } +optional + OptimalCostCache::load(OptimalCostState const &state) const { + if (contains_key(cache, state)) { + return make_optional(cache.at(state)); + } + return nullopt; +} + +void OptimalCostCache::save(OptimalCostState const &state, + OptimalCostResult const &result) { + assert(!contains_key(cache, state)); + cache.emplace(state, result); +} + std::vector> get_resource_split(MachineSpecification const &resource) { std::vector> result; @@ -85,12 +107,12 @@ std::pair float estimate_cost( SubParallelComputationGraph const &g, CostEstimator const &estimator, - std::unordered_map const &device_mapping) { + MachineMapping const &device_mapping) { NOT_IMPLEMENTED(); } -void minimize_runtime(MachineMapping &m1, MachineMapping const &m2) { - minimize(m1, m2, MachineMappingRuntimeCmp{}); +void minimize_runtime(OptimalCostResult &m1, OptimalCostResult const &m2) { + minimize(m1, m2, OptimalCostRuntimeCmp{}); } struct OptimalCost { @@ -103,7 +125,7 @@ struct OptimalCost { std::function( Operator const &, MachineSpecification const &)> const &allowed_machine_views, - std::unordered_map &cached_subgraph_costs) + OptimalCostCache &cached_subgraph_costs) : g(g), cost_estimator(cost_estimator), resource(resource), source_machine_view(source_machine_view), sink_machine_view(sink_machine_view), @@ -118,47 +140,25 @@ struct OptimalCost { std::function( Operator const &, MachineSpecification const &)> const &allowed_machine_views; - std::unordered_map &cached_subgraph_costs; + OptimalCostCache &cached_subgraph_costs; - // TODO: move them out of the functor template - size_t hash_state(T const &sp_decomposition) const { - size_t h = std::hash{}(sp_decomposition); - hash_combine(h, resource); - hash_combine(h, source_machine_view); - hash_combine(h, sink_machine_view); - return h; - } - - optional load_result_from_cache(size_t hash_value) const { - if (contains_key(cached_subgraph_costs, hash_value)) { - return make_optional(cached_subgraph_costs.at(hash_value)); - } - return nullopt; - } + OptimalCostResult operator()(T const &t) const { + OptimalCostState state{g, resource, source_machine_view, sink_machine_view}; + optional cached_result = + cached_subgraph_costs.load(state); - void save_result_to_cache(size_t hash_value, - MachineMapping const &strategy) const { - assert(!contains_key(cached_subgraph_costs, hash_value)); - cached_subgraph_costs.emplace(hash_value, strategy); - } - - template - MachineMapping operator()(T const &t) const { - size_t state_hash_value = hash_state(t); - optional cached_result = - load_result_from_cache(state_hash_value); if (cached_result) { return cached_result.value(); } - MachineMapping result = this->optimal_cost(t); + OptimalCostResult result = this->optimal_cost(t); - save_result_to_cache(state_hash_value, result); + cached_subgraph_costs.save(state, result); return result; } - MachineMapping optimal_cost(Serial const &serial) const { + OptimalCostResult optimal_cost(Serial const &serial) const { auto decomposed = decompose(serial); SerialParallelDecomposition pre_decompn = decomposed.first; SerialParallelDecomposition post_decompn = decomposed.second; @@ -177,7 +177,7 @@ struct OptimalCost { Node const &split_point = get_only(set_union(pre_graph_sinks, post_graph_sources)); - MachineMapping optimal_result = MachineMapping::infinity(); + OptimalCostResult optimal_result = OptimalCostResult::infinity(); for (MachineView const &mv : allowed_machine_views(g.at(split_point), resource)) { @@ -187,7 +187,7 @@ struct OptimalCost { contains(post_graph_sources, split_point) ? make_optional(mv) : nullopt; minimize_runtime(optimal_result, - MachineMapping::sequential_combine( + OptimalCostResult::sequential_combine( visit(OptimalCost(pre_graph, cost_estimator, resource, @@ -209,7 +209,7 @@ struct OptimalCost { return optimal_result; } - MachineMapping optimal_cost(Parallel const ¶llel) const { + OptimalCostResult optimal_cost(Parallel const ¶llel) const { auto decomposed = decompose(parallel); SerialParallelDecomposition decompn1 = decomposed.first; SerialParallelDecomposition decompn2 = decomposed.second; @@ -217,7 +217,7 @@ struct OptimalCost { auto subgraphs = apply_split(g, get_graph_split(decompn1, decompn2)); SubParallelComputationGraph g1 = subgraphs.first, g2 = subgraphs.second; - MachineMapping optimal_result = MachineMapping::sequential_combine( + OptimalCostResult optimal_result = OptimalCostResult::sequential_combine( visit(OptimalCost(g1, cost_estimator, resource, @@ -237,7 +237,7 @@ struct OptimalCost { for (auto const &resource_split : get_resource_split(resource)) { minimize_runtime(optimal_result, - MachineMapping::parallel_combine( + OptimalCostResult::parallel_combine( visit(OptimalCost(g1, cost_estimator, resource_split.first, @@ -259,25 +259,23 @@ struct OptimalCost { return optimal_result; } - MachineMapping optimal_cost(Node const &node) const { + OptimalCostResult optimal_cost(Node const &node) const { if (source_machine_view) { assert(get_closed_sources(g).empty()); assert(contains(allowed_machine_views(g.at(node), resource), source_machine_view.value())); - std::unordered_map mv_map{ - {node, source_machine_view.value()}}; + MachineMapping mv_map{{{node, source_machine_view.value()}}}; return {estimate_cost(g, cost_estimator, mv_map), mv_map}; } else if (sink_machine_view) { assert(get_closed_sinks(g).empty()); assert(contains(allowed_machine_views(g.at(node), resource), sink_machine_view.value())); - std::unordered_map mv_map{ - {node, sink_machine_view.value()}}; + MachineMapping mv_map{{{node, sink_machine_view.value()}}}; return {estimate_cost(g, cost_estimator, mv_map), mv_map}; } else { - MachineMapping optimal_result = MachineMapping::infinity(); + OptimalCostResult optimal_result = OptimalCostResult::infinity(); for (auto mv : allowed_machine_views(g.at(node), resource)) { - std::unordered_map mv_map{{node, mv}}; + MachineMapping mv_map{{{node, mv}}}; minimize_runtime(optimal_result, {estimate_cost(g, cost_estimator, mv_map), mv_map}); } @@ -286,14 +284,14 @@ struct OptimalCost { } }; -MachineMapping optimal_cost( +OptimalCostResult optimal_cost( ParallelComputationGraph const &g, std::function( Operator const &, MachineSpecification const &)> const &allowed_machine_views, CostEstimator const &cost_estimator, MachineSpecification const &resources, - std::unordered_map &cached_subgraph_costs) { + OptimalCostCache &cached_subgraph_costs) { return visit(OptimalCost(pcg_to_subpcg(g), cost_estimator, resources, diff --git a/lib/compiler/src/unity_algorithm.cc b/lib/compiler/src/unity_algorithm.cc index ed62fd941d..f5747e2058 100644 --- a/lib/compiler/src/unity_algorithm.cc +++ b/lib/compiler/src/unity_algorithm.cc @@ -6,7 +6,7 @@ namespace FlexFlow { bool StrategyRuntimeCmp::operator()(Strategy const &lhs, Strategy const &rhs) { - return lhs.machine_mapping.runtime < rhs.machine_mapping.runtime; + return lhs.runtime < rhs.runtime; } std::unordered_set @@ -29,7 +29,7 @@ Strategy std::unordered_set subs = get_all_substitutions(pcg); - std::unordered_map cached_subgraph_costs; + OptimalCostCache cached_subgraph_costs; DeduplicatedPriorityQueue, StrategyRuntimeCmp> candidates; @@ -50,20 +50,20 @@ Strategy if (StrategyRuntimeCmp(current_result, best_result)) { best_result = current_result; - } else if (current_result.machine_mapping.runtime > - best_result.machine_mapping.runtime * opt_config.alpha) { + } else if (current_result.runtime > + best_result.runtime * opt_config.alpha) { continue; } for (auto const &sub : subs) { for (auto const &new_pcg : apply_substitution(current_result.pcg, sub)) { - Strategy new_result(new_pcg, - optimal_cost(new_pcg, - allowed_machine_views, - cost_estimator, - resources, - cached_subgraph_costs)); - if (new_result.machine_mapping.runtime <= opt_config.threshold && + OptimalCostResult c = optimal_cost(new_pcg, + allowed_machine_views, + cost_estimator, + resources, + cached_subgraph_costs); + Strategy new_result(new_pcg, c.machine_mapping, c.runtime); + if (new_result.runtime <= opt_config.threshold && new_result.pcg.query_nodes({}).size() <= opt_config.max_num_ops) { candidates.push(new_result); } diff --git a/lib/compiler/test/test_generator.h b/lib/compiler/test/test_generator.h index c258ad46d0..9e1844cc04 100644 --- a/lib/compiler/test/test_generator.h +++ b/lib/compiler/test/test_generator.h @@ -189,7 +189,6 @@ template <> struct Arbitrary { static Gen arbitrary() { return gen::build( - gen::set(&MachineMapping::runtime, gen::nonZero()); gen::set(&MachineMapping::machine_views, gen::container>( gen::arbitrary(), gen::arbitrary()))); @@ -199,7 +198,7 @@ struct Arbitrary { template <> struct Arbitrary { static Gen arbitrary() { - return gen::build( + return gen::build( gen::set(&MachineSpecification::num_nodes, gen::inRange(1, 64)), gen::set(&MachineSpecification::num_cpus_per_node, gen::inRange(1, 64)), gen::set(&MachineSpecification::num_gpus_per_node, gen::inRange(1, 16)), From a8988d9398f354b3a0b27a5b0755d70328cfb921 Mon Sep 17 00:00:00 2001 From: wmdi Date: Fri, 11 Aug 2023 09:01:35 -0400 Subject: [PATCH 12/61] minor fix & format --- .../include/compiler/machine_mapping.h | 23 +++++++----- lib/compiler/src/machine_mapping.cc | 26 +++++++------- lib/compiler/test/test_machine_mapping.cc | 35 ++++++------------- lib/compiler/test/test_optimal_cost.cc | 16 ++++----- .../include/utils/graph/labelled/views.h | 9 +++-- 5 files changed, 51 insertions(+), 58 deletions(-) diff --git a/lib/compiler/include/compiler/machine_mapping.h b/lib/compiler/include/compiler/machine_mapping.h index c7beb2925c..9d872fead2 100644 --- a/lib/compiler/include/compiler/machine_mapping.h +++ b/lib/compiler/include/compiler/machine_mapping.h @@ -21,7 +21,11 @@ struct OptimalCostState { MachineSpecification resource; req> source_machine_view, sink_machine_view; }; -FF_VISITABLE_STRUCT(OptimalCostState, subgraph, resource, source_machine_view, sink_machine_view); +FF_VISITABLE_STRUCT(OptimalCostState, + subgraph, + resource, + source_machine_view, + sink_machine_view); struct OptimalCostResult { static OptimalCostResult sequential_combine(OptimalCostResult const &s1, @@ -45,18 +49,19 @@ class OptimalCostCache { optional load(OptimalCostState const &) const; void save(OptimalCostState const &, OptimalCostResult const &); + private: std::unordered_map cache; }; -OptimalCostResult optimal_cost( - ParallelComputationGraph const &g, - std::function( - Operator const &, MachineSpecification const &)> const - &allowed_machine_views, - CostEstimator const &cost_estimator, - MachineSpecification const &resources, - OptimalCostCache &cached_subgraph_costs); +OptimalCostResult + optimal_cost(ParallelComputationGraph const &g, + std::function( + Operator const &, MachineSpecification const &)> const + &allowed_machine_views, + CostEstimator const &cost_estimator, + MachineSpecification const &resources, + OptimalCostCache &cached_subgraph_costs); } // namespace FlexFlow diff --git a/lib/compiler/src/machine_mapping.cc b/lib/compiler/src/machine_mapping.cc index 945794d38c..5f0df5a5a4 100644 --- a/lib/compiler/src/machine_mapping.cc +++ b/lib/compiler/src/machine_mapping.cc @@ -29,7 +29,8 @@ OptimalCostResult } OptimalCostResult OptimalCostResult::infinity() { - return {std::numeric_limits::infinity(), MachineMapping{std::unordered_map{}}}; + return {std::numeric_limits::infinity(), + MachineMapping{std::unordered_map{}}}; } bool OptimalCostRuntimeCmp::operator()(OptimalCostResult const &lhs, @@ -104,10 +105,9 @@ std::pair } } -float estimate_cost( - SubParallelComputationGraph const &g, - CostEstimator const &estimator, - MachineMapping const &device_mapping) { +float estimate_cost(SubParallelComputationGraph const &g, + CostEstimator const &estimator, + MachineMapping const &device_mapping) { NOT_IMPLEMENTED(); } @@ -284,14 +284,14 @@ struct OptimalCost { } }; -OptimalCostResult optimal_cost( - ParallelComputationGraph const &g, - std::function( - Operator const &, MachineSpecification const &)> const - &allowed_machine_views, - CostEstimator const &cost_estimator, - MachineSpecification const &resources, - OptimalCostCache &cached_subgraph_costs) { +OptimalCostResult + optimal_cost(ParallelComputationGraph const &g, + std::function( + Operator const &, MachineSpecification const &)> const + &allowed_machine_views, + CostEstimator const &cost_estimator, + MachineSpecification const &resources, + OptimalCostCache &cached_subgraph_costs) { return visit(OptimalCost(pcg_to_subpcg(g), cost_estimator, resources, diff --git a/lib/compiler/test/test_machine_mapping.cc b/lib/compiler/test/test_machine_mapping.cc index fe72d549cc..351eabb491 100644 --- a/lib/compiler/test/test_machine_mapping.cc +++ b/lib/compiler/test/test_machine_mapping.cc @@ -5,36 +5,21 @@ bool nodes_are_disjoint(MachineMapping const &m1, MachineMapping const &m2) { return are_disjoint(keys(m1.machine_views), keys(m2.machine_views)); } -TEST_CASE("MachineMapping::sequential_combine") { - rc::check([](MachineMapping const &mp0, MachineMapping const &mp1) { - RC_PRE(nodes_are_disjoint(mp0, mp1)); +TEST_CASE("MachineMapping::combine") { + rc::check([](MachineMapping const &m0, MachineMapping const &m1) { + RC_PRE(nodes_are_disjoint(m0, m1)); - MachineMapping comb = MachineMapping::sequential_combine(mp0, mp1); + MachineMapping comb = MachineMapping::combine(m0, m1); - RC_ASSERT(comb.runtime == mp0.runtime + mp1.runtime); RC_ASSERT(comb.machine_views.size() == - mp0.machine_views.size() + mp1.machine_views.size()); - RC_ASSERT(is_submap(comb.machine_views, mp0.machine_views)); - RC_ASSERT(is_submap(comb.machine_views, mp1.machine_views)); + m0.machine_views.size() + m1.machine_views.size()); + RC_ASSERT(is_submap(comb.machine_views, m0.machine_views)); + RC_ASSERT(is_submap(comb.machine_views, m1.machine_views)); }); } -TEST_CASE("MachineMapping::parallel_combine") { - rc::check([](MachineMapping const &mp0, MachineMapping const &mp1) { - RC_PRE(nodes_are_disjoint(mp0, mp1)); - - MachineMapping comb = MachineMapping::parallel_combine(mp0, mp1); - - RC_ASSERT(comb.runtime == std::max(mp0.runtime, mp1.runtime)); - RC_ASSERT(comb.machine_views.size() == - mp0.machine_views.size() + mp1.machine_views.size()); - RC_ASSERT(is_submap(comb.machine_views, mp0.machine_views)); - RC_ASSERT(is_submap(comb.machine_views, mp1.machine_views)); - }); -} - -TEST_CASE("MachineMapping::infinity") { - rc::check([](MachineMapping const &mp) { - RC_ASSERT(mp.runtime <= MachineMapping::infinity().runtime); +TEST_CASE("OptimalCostResult::infinity") { + rc::check([](OptimalCostResult const &c) { + RC_ASSERT(c.runtime <= OptimalCostResult::infinity().runtime); }); } \ No newline at end of file diff --git a/lib/compiler/test/test_optimal_cost.cc b/lib/compiler/test/test_optimal_cost.cc index 8993eb4228..0bb010068e 100644 --- a/lib/compiler/test/test_optimal_cost.cc +++ b/lib/compiler/test/test_optimal_cost.cc @@ -12,13 +12,13 @@ TEST_CASE("optimal_cost") { }; rc::check([](ParallelComputationGraph const &g, MachineSpecification const &machine_spec) { - std::unordered_map cached_subgraph_costs; - MachineMapping machine_mapping = optimal_cost(g, - test_allowed_machine_views, - TestCostEstimator{}, - machine_spec, - cached_subgraph_costs); - RC_ASSERT(machine_mapping.runtime > 0); - RC_ASSERT(keys(machine_mapping.machine_views) == get_nodes(g)); + OptimalCostCache cached_subgraph_costs; + OptimalCostResult result = optimal_cost(g, + test_allowed_machine_views, + TestCostEstimator{}, + machine_spec, + cached_subgraph_costs); + RC_ASSERT(result.runtime > 0); + RC_ASSERT(keys(result.machine_mapping.machine_views) == get_nodes(g)); }); } \ No newline at end of file diff --git a/lib/utils/include/utils/graph/labelled/views.h b/lib/utils/include/utils/graph/labelled/views.h index b034454312..1fe999bdfb 100644 --- a/lib/utils/include/utils/graph/labelled/views.h +++ b/lib/utils/include/utils/graph/labelled/views.h @@ -32,11 +32,13 @@ struct ViewMultiDiGraphAsOutputLabelled std::function const &output_label) : g(g), node_label(node_label), output_label(output_label) {} - virtual std::unordered_set query_nodes(NodeQuery const &q) const override { + virtual std::unordered_set + query_nodes(NodeQuery const &q) const override { return g.query_nodes(q); } - virtual std::unordered_set query_edges(MultiDiEdgeQuery const &q) const override { + virtual std::unordered_set + query_edges(MultiDiEdgeQuery const &q) const override { return g.query_edges(q); } @@ -54,7 +56,8 @@ struct ViewMultiDiGraphAsOutputLabelled std::function output_label; }; -CHECK_NOT_ABSTRACT(ViewMultiDiGraphAsOutputLabelled); +CHECK_NOT_ABSTRACT(ViewMultiDiGraphAsOutputLabelled); } // namespace FlexFlow From e6bc14a8b2e18ac95843bdbddf9fff6dd7ff2879 Mon Sep 17 00:00:00 2001 From: wmdi Date: Mon, 14 Aug 2023 10:41:28 -0400 Subject: [PATCH 13/61] move general codes into proper places --- .../include/compiler/machine_mapping.h | 4 +- lib/compiler/src/machine_mapping.cc | 4 + lib/compiler/test/test_generator.h | 90 ------------------- lib/compiler/test/test_machine_mapping.cc | 6 +- .../include/utils/graph/serialparallel.h | 14 +++ lib/utils/src/graph/serialparallel.cc | 90 +++++++++++++++++++ 6 files changed, 112 insertions(+), 96 deletions(-) diff --git a/lib/compiler/include/compiler/machine_mapping.h b/lib/compiler/include/compiler/machine_mapping.h index 9d872fead2..b1a9b7f384 100644 --- a/lib/compiler/include/compiler/machine_mapping.h +++ b/lib/compiler/include/compiler/machine_mapping.h @@ -11,11 +11,13 @@ namespace FlexFlow { struct MachineMapping { static MachineMapping combine(MachineMapping const &, MachineMapping const &); - + static bool nodes_are_disjoint(MachineMapping const &m1, MachineMapping const &m2); + req> machine_views; }; FF_VISITABLE_STRUCT(MachineMapping, machine_views); + struct OptimalCostState { SerialParallelDecomposition subgraph; MachineSpecification resource; diff --git a/lib/compiler/src/machine_mapping.cc b/lib/compiler/src/machine_mapping.cc index 5f0df5a5a4..55b6e38cfa 100644 --- a/lib/compiler/src/machine_mapping.cc +++ b/lib/compiler/src/machine_mapping.cc @@ -12,6 +12,10 @@ MachineMapping MachineMapping::combine(MachineMapping const &s1, return MachineMapping{merge_maps(s1.machine_views, s2.machine_views)}; } +bool MachineMapping::nodes_are_disjoint(MachineMapping const &m1, MachineMapping const &m2) { + return are_disjoint(keys(m1.machine_views), keys(m2.machine_views)); +} + OptimalCostResult OptimalCostResult::sequential_combine(OptimalCostResult const &s1, OptimalCostResult const &s2) { diff --git a/lib/compiler/test/test_generator.h b/lib/compiler/test/test_generator.h index 9e1844cc04..52dce39ad4 100644 --- a/lib/compiler/test/test_generator.h +++ b/lib/compiler/test/test_generator.h @@ -7,96 +7,6 @@ using namespace FlexFlow; -std::unordered_map parallel_extend(MultiDiGraph &g, - MultiDiGraph const &ext) { - std::unordered_map node_map; - std::unordered_map node_port_map; - for (Node const &node : get_nodes(ext)) { - node_map[node] = g.add_node(); - } - for (NodePort const &node_port : get_node_ports(ext)) { - node_port_map[node_port] = g.add_node_port(); - } - for (MultiDiEdge const &edge : get_edges(ext)) { - g.add_edge(MultiDiEdge{node_map[edge.src], - node_map[edge.dst], - node_map[edge.srcIdx], - node_map[edge.dstIdx]}); - } - return node_map; -} - -std::unordered_map serial_extend(MultiDiGraph &g, - MultiDiGraph const &ext) { - std::unordered_set original_sinks = get_sinks(g); - std::unordered_map node_map = parallel_extend(g, ext); - for (Node const &node1 : original_sinks) { - for (Node const &node2 : get_sources(ext)) { - g.add_edge(MultiDiEdge{ - node1, node_map[node2], g.add_node_port(), g.add_node_port()}); - } - } - return node_map; -} - -MultiDiGraph serial_composition(MultiDiGraph const &g1, - MultiDiGraph const &g2) { - MultiDiGraph g = g1; - serial_extend(g, g2); - return g; -} - -MultiDiGraph parallel_composition(MultiDiGraph const &g1, - MultiDiGraph const &g2) { - MultiDiGraph g = g1; - parallel_extend(g, g2); - return g; -} - -struct MultiDiGraphFromSPDecomposition { - template - MultiDiGraph operator()(T const &t) { - return multidigraph_from_sp_decomposition(t); - } -}; - -MultiDiGraph multidigraph_from_sp_decomposition( - SerialParallelDecomposition const &sp_decomposition) { - return visit(MultiDiGraphFromSPDecomposition{}, sp_decomposition); -} - -MultiDiGraph multidigraph_from_sp_decomposition( - variant const &sp_decomposition) { - return visit(MultiDiGraphFromSPDecomposition{}, sp_decomposition); -} - -MultiDiGraph multidigraph_from_sp_decomposition( - variant const &sp_decomposition) { - return visit(MultiDiGraphFromSPDecomposition{}, sp_decomposition); -} - -MultiDiGraph multidigraph_from_sp_decomposition(Serial const &serial) { - MultiDiGraph g = MultiDiGraph::create(); - for (auto child : serial.children) { - serial_extend(g, multidigraph_from_sp_decomposition(child)); - } - return g; -} - -MultiDiGraph multidigraph_from_sp_decomposition(Parallel const ¶llel) { - MultiDiGraph g = MultiDiGraph::create(); - for (auto child : parallel.children) { - parallel_extend(g, multidigraph_from_sp_decomposition(child)); - } - return g; -} - -MultiDiGraph multidigraph_from_sp_decomposition(Node const &Node) { - MultiDiGraph g = MultiDiGraph::create(); - g.add_node(); - return g; -} - template OutputLabelledMultiDiGraph generate_test_labelled_sp_graph() { diff --git a/lib/compiler/test/test_machine_mapping.cc b/lib/compiler/test/test_machine_mapping.cc index 351eabb491..940592e0e6 100644 --- a/lib/compiler/test/test_machine_mapping.cc +++ b/lib/compiler/test/test_machine_mapping.cc @@ -1,13 +1,9 @@ #include "doctest.h" #include "test_generator.h" -bool nodes_are_disjoint(MachineMapping const &m1, MachineMapping const &m2) { - return are_disjoint(keys(m1.machine_views), keys(m2.machine_views)); -} - TEST_CASE("MachineMapping::combine") { rc::check([](MachineMapping const &m0, MachineMapping const &m1) { - RC_PRE(nodes_are_disjoint(m0, m1)); + RC_PRE(MachineMapping::nodes_are_disjoint(m0, m1)); MachineMapping comb = MachineMapping::combine(m0, m1); diff --git a/lib/utils/include/utils/graph/serialparallel.h b/lib/utils/include/utils/graph/serialparallel.h index e7d8bd2ca9..d68cebac7a 100644 --- a/lib/utils/include/utils/graph/serialparallel.h +++ b/lib/utils/include/utils/graph/serialparallel.h @@ -35,6 +35,20 @@ SerialParallelDecomposition std::unordered_set get_nodes(SerialParallelDecomposition const &sp); +std::unordered_map parallel_extend(MultiDiGraph &g, + MultiDiGraph const &ext); + +std::unordered_map serial_extend(MultiDiGraph &g, + MultiDiGraph const &ext); + +MultiDiGraph serial_composition(MultiDiGraph const &g1, MultiDiGraph const &g2); + +MultiDiGraph parallel_composition(MultiDiGraph const &g1, + MultiDiGraph const &g2); + +MultiDiGraph multidigraph_from_sp_decomposition( + SerialParallelDecomposition const &sp_decomposition); + } // namespace FlexFlow #endif diff --git a/lib/utils/src/graph/serialparallel.cc b/lib/utils/src/graph/serialparallel.cc index 8a034ad809..06760b5f67 100644 --- a/lib/utils/src/graph/serialparallel.cc +++ b/lib/utils/src/graph/serialparallel.cc @@ -217,4 +217,94 @@ std::unordered_set get_nodes(Node const &node) { return {node}; } +std::unordered_map parallel_extend(MultiDiGraph &g, + MultiDiGraph const &ext) { + std::unordered_map node_map; + std::unordered_map node_port_map; + for (Node const &node : get_nodes(ext)) { + node_map[node] = g.add_node(); + } + for (NodePort const &node_port : get_node_ports(ext)) { + node_port_map[node_port] = g.add_node_port(); + } + for (MultiDiEdge const &edge : get_edges(ext)) { + g.add_edge(MultiDiEdge{node_map[edge.src], + node_map[edge.dst], + node_map[edge.srcIdx], + node_map[edge.dstIdx]}); + } + return node_map; +} + +std::unordered_map serial_extend(MultiDiGraph &g, + MultiDiGraph const &ext) { + std::unordered_set original_sinks = get_sinks(g); + std::unordered_map node_map = parallel_extend(g, ext); + for (Node const &node1 : original_sinks) { + for (Node const &node2 : get_sources(ext)) { + g.add_edge(MultiDiEdge{ + node1, node_map[node2], g.add_node_port(), g.add_node_port()}); + } + } + return node_map; +} + +MultiDiGraph serial_composition(MultiDiGraph const &g1, + MultiDiGraph const &g2) { + MultiDiGraph g = g1; + serial_extend(g, g2); + return g; +} + +MultiDiGraph parallel_composition(MultiDiGraph const &g1, + MultiDiGraph const &g2) { + MultiDiGraph g = g1; + parallel_extend(g, g2); + return g; +} + +struct MultiDiGraphFromSPDecomposition { + template + MultiDiGraph operator()(T const &t) { + return multidigraph_from_sp_decomposition(t); + } +}; + +MultiDiGraph multidigraph_from_sp_decomposition( + SerialParallelDecomposition const &sp_decomposition) { + return visit(MultiDiGraphFromSPDecomposition{}, sp_decomposition); +} + +MultiDiGraph multidigraph_from_sp_decomposition( + variant const &sp_decomposition) { + return visit(MultiDiGraphFromSPDecomposition{}, sp_decomposition); +} + +MultiDiGraph multidigraph_from_sp_decomposition( + variant const &sp_decomposition) { + return visit(MultiDiGraphFromSPDecomposition{}, sp_decomposition); +} + +MultiDiGraph multidigraph_from_sp_decomposition(Serial const &serial) { + MultiDiGraph g = MultiDiGraph::create(); + for (auto child : serial.children) { + serial_extend(g, multidigraph_from_sp_decomposition(child)); + } + return g; +} + +MultiDiGraph multidigraph_from_sp_decomposition(Parallel const ¶llel) { + MultiDiGraph g = MultiDiGraph::create(); + for (auto child : parallel.children) { + parallel_extend(g, multidigraph_from_sp_decomposition(child)); + } + return g; +} + +MultiDiGraph multidigraph_from_sp_decomposition(Node const &Node) { + MultiDiGraph g = MultiDiGraph::create(); + g.add_node(); + return g; +} + } // namespace FlexFlow From 04a9525d7e6e332791ad08ea776cd95498fe3c71 Mon Sep 17 00:00:00 2001 From: wmdi Date: Mon, 14 Aug 2023 10:41:54 -0400 Subject: [PATCH 14/61] format --- lib/compiler/include/compiler/machine_mapping.h | 6 +++--- lib/compiler/src/machine_mapping.cc | 3 ++- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/lib/compiler/include/compiler/machine_mapping.h b/lib/compiler/include/compiler/machine_mapping.h index b1a9b7f384..4089260735 100644 --- a/lib/compiler/include/compiler/machine_mapping.h +++ b/lib/compiler/include/compiler/machine_mapping.h @@ -11,13 +11,13 @@ namespace FlexFlow { struct MachineMapping { static MachineMapping combine(MachineMapping const &, MachineMapping const &); - static bool nodes_are_disjoint(MachineMapping const &m1, MachineMapping const &m2); - + static bool nodes_are_disjoint(MachineMapping const &m1, + MachineMapping const &m2); + req> machine_views; }; FF_VISITABLE_STRUCT(MachineMapping, machine_views); - struct OptimalCostState { SerialParallelDecomposition subgraph; MachineSpecification resource; diff --git a/lib/compiler/src/machine_mapping.cc b/lib/compiler/src/machine_mapping.cc index 55b6e38cfa..2f6af8a62b 100644 --- a/lib/compiler/src/machine_mapping.cc +++ b/lib/compiler/src/machine_mapping.cc @@ -12,7 +12,8 @@ MachineMapping MachineMapping::combine(MachineMapping const &s1, return MachineMapping{merge_maps(s1.machine_views, s2.machine_views)}; } -bool MachineMapping::nodes_are_disjoint(MachineMapping const &m1, MachineMapping const &m2) { +bool MachineMapping::nodes_are_disjoint(MachineMapping const &m1, + MachineMapping const &m2) { return are_disjoint(keys(m1.machine_views), keys(m2.machine_views)); } From 60b6f59e5e357d23ba38325403eaa57d590e0aa0 Mon Sep 17 00:00:00 2001 From: wmdi Date: Wed, 16 Aug 2023 22:55:34 -0400 Subject: [PATCH 15/61] minor fix & format --- lib/compiler/test/test_machine_mapping.cc | 2 +- lib/compiler/test/test_optimal_cost.cc | 2 +- lib/compiler/test/test_unity_algorithm.cc | 2 +- lib/utils/src/graph/serialparallel.cc | 17 +++++++++-------- 4 files changed, 12 insertions(+), 11 deletions(-) diff --git a/lib/compiler/test/test_machine_mapping.cc b/lib/compiler/test/test_machine_mapping.cc index 940592e0e6..4436a992d3 100644 --- a/lib/compiler/test/test_machine_mapping.cc +++ b/lib/compiler/test/test_machine_mapping.cc @@ -18,4 +18,4 @@ TEST_CASE("OptimalCostResult::infinity") { rc::check([](OptimalCostResult const &c) { RC_ASSERT(c.runtime <= OptimalCostResult::infinity().runtime); }); -} \ No newline at end of file +} diff --git a/lib/compiler/test/test_optimal_cost.cc b/lib/compiler/test/test_optimal_cost.cc index 0bb010068e..2d9414ba27 100644 --- a/lib/compiler/test/test_optimal_cost.cc +++ b/lib/compiler/test/test_optimal_cost.cc @@ -21,4 +21,4 @@ TEST_CASE("optimal_cost") { RC_ASSERT(result.runtime > 0); RC_ASSERT(keys(result.machine_mapping.machine_views) == get_nodes(g)); }); -} \ No newline at end of file +} diff --git a/lib/compiler/test/test_unity_algorithm.cc b/lib/compiler/test/test_unity_algorithm.cc index 8be65eed94..6a0131dd77 100644 --- a/lib/compiler/test/test_unity_algorithm.cc +++ b/lib/compiler/test/test_unity_algorithm.cc @@ -20,4 +20,4 @@ TEST_CASE("graph_optimize") { RC_ASSERT(s.machine_mapping.runtime > 0); RC_ASSERT(keys(s.machine_mapping.machine_views) == get_nodes(s.pcg)); }); -} \ No newline at end of file +} diff --git a/lib/utils/src/graph/serialparallel.cc b/lib/utils/src/graph/serialparallel.cc index 06760b5f67..a5cdd44f12 100644 --- a/lib/utils/src/graph/serialparallel.cc +++ b/lib/utils/src/graph/serialparallel.cc @@ -1,6 +1,7 @@ #include "utils/graph/serialparallel.h" #include "serialparallel_internal.h" #include "utils/containers.h" +#include "utils/graph/adjacency_multidigraph.h" #include "utils/graph/algorithms.h" #include "utils/graph/digraph.h" @@ -221,17 +222,17 @@ std::unordered_map parallel_extend(MultiDiGraph &g, MultiDiGraph const &ext) { std::unordered_map node_map; std::unordered_map node_port_map; - for (Node const &node : get_nodes(ext)) { - node_map[node] = g.add_node(); + for (Node const &node : get_nodes(MultiDiGraphView(ext))) { + node_map.emplace(node, g.add_node()); } for (NodePort const &node_port : get_node_ports(ext)) { - node_port_map[node_port] = g.add_node_port(); + node_port_map.emplace(node_port, g.add_node_port()); } for (MultiDiEdge const &edge : get_edges(ext)) { - g.add_edge(MultiDiEdge{node_map[edge.src], - node_map[edge.dst], - node_map[edge.srcIdx], - node_map[edge.dstIdx]}); + g.add_edge(MultiDiEdge{node_map.at(edge.src), + node_map.at(edge.dst), + node_port_map.at(edge.srcIdx), + node_port_map.at(edge.dstIdx)}); } return node_map; } @@ -243,7 +244,7 @@ std::unordered_map serial_extend(MultiDiGraph &g, for (Node const &node1 : original_sinks) { for (Node const &node2 : get_sources(ext)) { g.add_edge(MultiDiEdge{ - node1, node_map[node2], g.add_node_port(), g.add_node_port()}); + node1, node_map.at(node2), g.add_node_port(), g.add_node_port()}); } } return node_map; From 8fd7ef0923c31a955511eb72f7f24e12676af0a1 Mon Sep 17 00:00:00 2001 From: wmdi Date: Thu, 17 Aug 2023 21:30:47 -0400 Subject: [PATCH 16/61] minor fix --- lib/compiler/test/test_generator.h | 49 ++++++++++++++++--- .../include/utils/graph/labelled/views.h | 17 +++++++ lib/utils/src/graph/serialparallel.cc | 12 ++--- 3 files changed, 66 insertions(+), 12 deletions(-) diff --git a/lib/compiler/test/test_generator.h b/lib/compiler/test/test_generator.h index 52dce39ad4..b3453b014c 100644 --- a/lib/compiler/test/test_generator.h +++ b/lib/compiler/test/test_generator.h @@ -7,12 +7,34 @@ using namespace FlexFlow; -template -OutputLabelledMultiDiGraph - generate_test_labelled_sp_graph() { - NOT_IMPLEMENTED(); - // Is there a way to construct a labelled graph from a MultiDiGraph and the - // labels? +/* + Generates computation graphs with trivial layers and tensors, which are used + for tests focusing on graph structures. +*/ +ComputationGraph test_computataion_graph(MultiDiGraphView const &g) { + return materialize_output_labelled_multidigraph_view( + ViewMultiDiGraphAsOutputLabelled( + g, + [](Layer(Node const &)) { return Layer(NoopAttrs{}); }, + [](Tensor(MultiDiOutput const &)) { + return Tensor{0, DataType::FLOAT, nullopt, false, nullopt}; + })); +} + +/* + Generates parallel computation graphs with trivial layers and tensors, which + are used for tests focusing on graph structures. +*/ +ParallelComputationGraph + test_parallel_computation_graph(MultiDiGraphView const &g) { + return materialize_output_labelled_multidigraph_view( + ViewMultiDiGraphAsOutputLabelled( + g, + [](Operator(Node const &)) { return ParallelTensor(NoopAttrs{}); }, + [](Operator(MultiDiOutput const &)) { + return ParallelTensor(ParallelTensorDims(TensorDims({})), + DataType::FLOAT); + })); } rc::Gen small_integer_generator() { @@ -26,6 +48,21 @@ Gen serialParallelMultiDiGraph() { multidigraph_from_sp_decomposition); } +template <> +struct Arbitrary { + static Gen arbitrary() { + return gen::map(serialParallelMultiDiGraph, test_computataion_graph); + } +}; + +template <> +struct Arbitrary { + static Gen arbitrary() { + return gen::map(serialParallelMultiDiGraph, + test_parallel_computataion_graph); + } +}; + template <> struct Arbitrary> { static Gen> arbitrary() { diff --git a/lib/utils/include/utils/graph/labelled/views.h b/lib/utils/include/utils/graph/labelled/views.h index 1fe999bdfb..85b5d3ef5c 100644 --- a/lib/utils/include/utils/graph/labelled/views.h +++ b/lib/utils/include/utils/graph/labelled/views.h @@ -59,6 +59,23 @@ struct ViewMultiDiGraphAsOutputLabelled CHECK_NOT_ABSTRACT(ViewMultiDiGraphAsOutputLabelled); +template +Impl materialize_output_labelled_multidigraph_view( + IOutputLabelledMultiDiGraphView const &g) { + Impl result; + for (Node const &n : get_nodes(g)) { + result.add_node_unsafe(n); + result.at(n) = g.at(n); + } + for (auto const &e : get_edges(g)) { + result.add_edge(e); + } + for (MultiDiOutput const &o : get_outputs(g)) { + result.add_output(o, g.at(o)); + } + return result; +} + } // namespace FlexFlow #endif diff --git a/lib/utils/src/graph/serialparallel.cc b/lib/utils/src/graph/serialparallel.cc index a5cdd44f12..5484171a20 100644 --- a/lib/utils/src/graph/serialparallel.cc +++ b/lib/utils/src/graph/serialparallel.cc @@ -264,7 +264,7 @@ MultiDiGraph parallel_composition(MultiDiGraph const &g1, return g; } -struct MultiDiGraphFromSPDecomposition { +struct MultiDiGraphFromSPDecompositionFunctor { template MultiDiGraph operator()(T const &t) { return multidigraph_from_sp_decomposition(t); @@ -273,22 +273,22 @@ struct MultiDiGraphFromSPDecomposition { MultiDiGraph multidigraph_from_sp_decomposition( SerialParallelDecomposition const &sp_decomposition) { - return visit(MultiDiGraphFromSPDecomposition{}, sp_decomposition); + return visit(MultiDiGraphFromSPDecompositionFunctor{}, sp_decomposition); } MultiDiGraph multidigraph_from_sp_decomposition( variant const &sp_decomposition) { - return visit(MultiDiGraphFromSPDecomposition{}, sp_decomposition); + return visit(MultiDiGraphFromSPDecompositionFunctor{}, sp_decomposition); } MultiDiGraph multidigraph_from_sp_decomposition( variant const &sp_decomposition) { - return visit(MultiDiGraphFromSPDecomposition{}, sp_decomposition); + return visit(MultiDiGraphFromSPDecompositionFunctor{}, sp_decomposition); } MultiDiGraph multidigraph_from_sp_decomposition(Serial const &serial) { MultiDiGraph g = MultiDiGraph::create(); - for (auto child : serial.children) { + for (variant const &child : serial.children) { serial_extend(g, multidigraph_from_sp_decomposition(child)); } return g; @@ -296,7 +296,7 @@ MultiDiGraph multidigraph_from_sp_decomposition(Serial const &serial) { MultiDiGraph multidigraph_from_sp_decomposition(Parallel const ¶llel) { MultiDiGraph g = MultiDiGraph::create(); - for (auto child : parallel.children) { + for (variant const &child : parallel.children) { parallel_extend(g, multidigraph_from_sp_decomposition(child)); } return g; From 60e3945af977b28dd1ff10a8fe9e39811e787a51 Mon Sep 17 00:00:00 2001 From: wmdi Date: Sat, 19 Aug 2023 15:36:42 -0400 Subject: [PATCH 17/61] update substitutions to align with latest changes --- lib/compiler/src/unity_algorithm.cc | 2 +- .../include/substitutions/get_attribute.h | 128 +++++------ .../include/substitutions/graph_pattern.h | 30 +-- .../include/substitutions/substitutions.h | 162 +++++++------ .../include/substitutions/substitutions_old.h | 101 ++++++++ .../include/substitutions/substitutions_v2.h | 115 ---------- lib/substitutions/src/graph_pattern.cc | 70 ++---- lib/substitutions/src/substitutions.cc | 215 ++++++++++++++++++ lib/utils/include/utils/graph/algorithms.h | 2 + lib/utils/src/graph/algorithms.cc | 38 ++++ 10 files changed, 537 insertions(+), 326 deletions(-) create mode 100644 lib/substitutions/include/substitutions/substitutions_old.h delete mode 100644 lib/substitutions/include/substitutions/substitutions_v2.h create mode 100644 lib/substitutions/src/substitutions.cc diff --git a/lib/compiler/src/unity_algorithm.cc b/lib/compiler/src/unity_algorithm.cc index f5747e2058..ef093fc11e 100644 --- a/lib/compiler/src/unity_algorithm.cc +++ b/lib/compiler/src/unity_algorithm.cc @@ -14,7 +14,7 @@ std::unordered_set std::unordered_set apply_substitution(ParallelComputationGraph const &pcg, - Substitution const &); + Substitution const &) {} Strategy graph_optimize(ComputationGraph &cg, diff --git a/lib/substitutions/include/substitutions/get_attribute.h b/lib/substitutions/include/substitutions/get_attribute.h index 55068a8c62..0a4d5b99fd 100644 --- a/lib/substitutions/include/substitutions/get_attribute.h +++ b/lib/substitutions/include/substitutions/get_attribute.h @@ -2,76 +2,74 @@ #define _FLEXFLOW_SUBSTITUTIONS_OPERATOR_ATTRIBUTES_H #include "op-attrs/operator_attrs.h" -#include "substitutions/substitutions_v2.h" +#include "substitutions/substitutions.h" #include "tl/optional.hpp" namespace FlexFlow { -namespace substitutions { -tl::optional get_attribute(PCGOperatorAttrs const &, - OperatorAttributeKey); -tl::optional get_attribute(AggregateAttrs const &p, - OperatorAttributeKey); -tl::optional get_attribute(AggregateSpecAttrs const &p, - OperatorAttributeKey); -tl::optional get_attribute(BatchMatmulAttrs const &p, - OperatorAttributeKey); -tl::optional get_attribute(CastAttrs const &p, - OperatorAttributeKey); -tl::optional get_attribute(CombineAttrs const &p, - OperatorAttributeKey); -tl::optional get_attribute(ConcatAttrs const &p, - OperatorAttributeKey); -tl::optional get_attribute(Conv2DAttrs const &p, - OperatorAttributeKey); -tl::optional get_attribute(ElementBinaryAttrs const &p, - OperatorAttributeKey); -tl::optional get_attribute(ElementUnaryAttrs const &p, - OperatorAttributeKey); -tl::optional get_attribute(DropoutAttrs const &p, - OperatorAttributeKey); -tl::optional get_attribute(ElementBinaryAttrs const &p, - OperatorAttributeKey); -tl::optional get_attribute(ElementUnaryAttrs const &p, - OperatorAttributeKey); -tl::optional get_attribute(EmbeddingAttrs const &p, - OperatorAttributeKey); -tl::optional get_attribute(FlatAttrs const &p, - OperatorAttributeKey); -tl::optional get_attribute(GatherAttrs const &p, - OperatorAttributeKey); -tl::optional get_attribute(Group_byAttrs const &p, - OperatorAttributeKey); -tl::optional get_attribute(LayerNormAttrs const &p, - OperatorAttributeKey); -tl::optional get_attribute(LinearAttrs const &p, - OperatorAttributeKey); -tl::optional - get_attribute(MultiHeadAttentionAttrs const &p, OperatorAttributeKey); -tl::optional get_attribute(Pool2DAttrs const &p, - OperatorAttributeKey); -tl::optional get_attribute(ReduceAttrs const &p, - OperatorAttributeKey); -tl::optional get_attribute(ReductionAttrs const &p, - OperatorAttributeKey); -tl::optional get_attribute(RepartitionAttrs const &p, - OperatorAttributeKey); -tl::optional get_attribute(ReplicateAttrs const &p, - OperatorAttributeKey); -tl::optional get_attribute(ReshapeAttrs const &p, - OperatorAttributeKey); -tl::optional get_attribute(SplitAttrs const &p, - OperatorAttributeKey); -tl::optional get_attribute(SoftmaxAttrs const &p, - OperatorAttributeKey); -tl::optional get_attribute(TopKAttrs const &p, - OperatorAttributeKey); -tl::optional get_attribute(TransposeAttrs const &p, - OperatorAttributeKey); -tl::optional - get_attribute(FusedParallelOpAttrs const &p, OperatorAttributeKey); +optional get_attribute(PCGOperatorAttrs const &, + OperatorAttributeKey); +optional get_attribute(AggregateAttrs const &p, + OperatorAttributeKey); +optional get_attribute(AggregateSpecAttrs const &p, + OperatorAttributeKey); +optional get_attribute(BatchMatmulAttrs const &p, + OperatorAttributeKey); +optional get_attribute(CastAttrs const &p, + OperatorAttributeKey); +optional get_attribute(CombineAttrs const &p, + OperatorAttributeKey); +optional get_attribute(ConcatAttrs const &p, + OperatorAttributeKey); +optional get_attribute(Conv2DAttrs const &p, + OperatorAttributeKey); +optional get_attribute(ElementBinaryAttrs const &p, + OperatorAttributeKey); +optional get_attribute(ElementUnaryAttrs const &p, + OperatorAttributeKey); +optional get_attribute(DropoutAttrs const &p, + OperatorAttributeKey); +optional get_attribute(ElementBinaryAttrs const &p, + OperatorAttributeKey); +optional get_attribute(ElementUnaryAttrs const &p, + OperatorAttributeKey); +optional get_attribute(EmbeddingAttrs const &p, + OperatorAttributeKey); +optional get_attribute(FlatAttrs const &p, + OperatorAttributeKey); +optional get_attribute(GatherAttrs const &p, + OperatorAttributeKey); +optional get_attribute(Group_byAttrs const &p, + OperatorAttributeKey); +optional get_attribute(LayerNormAttrs const &p, + OperatorAttributeKey); +optional get_attribute(LinearAttrs const &p, + OperatorAttributeKey); +optional get_attribute(MultiHeadAttentionAttrs const &p, + OperatorAttributeKey); +optional get_attribute(Pool2DAttrs const &p, + OperatorAttributeKey); +optional get_attribute(ReduceAttrs const &p, + OperatorAttributeKey); +optional get_attribute(ReductionAttrs const &p, + OperatorAttributeKey); +optional get_attribute(RepartitionAttrs const &p, + OperatorAttributeKey); +optional get_attribute(ReplicateAttrs const &p, + OperatorAttributeKey); +optional get_attribute(ReshapeAttrs const &p, + OperatorAttributeKey); +optional get_attribute(SplitAttrs const &p, + OperatorAttributeKey); +optional get_attribute(SoftmaxAttrs const &p, + OperatorAttributeKey); +optional get_attribute(TopKAttrs const &p, + OperatorAttributeKey); +optional get_attribute(TransposeAttrs const &p, + OperatorAttributeKey); +optional get_attribute(FusedParallelOpAttrs const &p, + OperatorAttributeKey); -} // namespace substitutions } // namespace FlexFlow #endif diff --git a/lib/substitutions/include/substitutions/graph_pattern.h b/lib/substitutions/include/substitutions/graph_pattern.h index 70e67c6ad1..d7654e19ce 100644 --- a/lib/substitutions/include/substitutions/graph_pattern.h +++ b/lib/substitutions/include/substitutions/graph_pattern.h @@ -4,36 +4,30 @@ #include "utils/graph.h" namespace FlexFlow { -namespace substitutions { struct DiGraphPatternMatch { bidict nodeAssignment; - bidict edgeAssignment; + req> edgeAssignment; }; +FF_VISITABLE_STRUCT(DiGraphPatternMatch, nodeAssignment, edgeAssignment); + struct MatchSplit { DiGraphPatternMatch prefix_submatch; - DiGraphPatternMatch postfix_submatch; + req postfix_submatch; }; -GraphSplit split_pattern(IOpenMultiDiGraph const &pattern); +FF_VISITABLE_STRUCT(MatchSplit, prefix_submatch, postfix_submatch); -bool pattern_matches(IOpenMultiDiGraphView const &, - IMultiDiGraph const &, - DiGraphPatternMatch const &); -bool is_singleton_pattern(IOpenMultiDiGraphView const &); +GraphSplit split_pattern(OpenMultiDiGraphView const &pattern); -} // namespace substitutions -} // namespace FlexFlow +bool pattern_matches(OpenMultiDiGraphView const &, + MultiDiGraphView const &, + DiGraphPatternMatch const &, + F const &additional_criterion); -namespace std { +bool is_singleton_pattern(OpenMultiDiGraphView const &); -template <> -struct hash<::FlexFlow::substitutions::DiGraphPatternMatch> { - size_t - operator()(::FlexFlow::substitutions::DiGraphPatternMatch const &) const; -}; - -} // namespace std +} // namespace FlexFlow #endif diff --git a/lib/substitutions/include/substitutions/substitutions.h b/lib/substitutions/include/substitutions/substitutions.h index 33ffc9704d..528f93b355 100644 --- a/lib/substitutions/include/substitutions/substitutions.h +++ b/lib/substitutions/include/substitutions/substitutions.h @@ -1,101 +1,117 @@ -#ifndef _FLEXFLOW_SUBSTITUTION_LOADER_H -#define _FLEXFLOW_SUBSTITUTION_LOADER_H +#ifndef _FLEXFLOW_SUBSTITUTIONS_SUBSTITUTIONS_V2_H +#define _FLEXFLOW_SUBSTITUTIONS_SUBSTITUTIONS_V2_H -#include "op-meta/op-meta.h" -#include "tl/optional.hpp" -#include -#include +#include "graph_pattern.h" +#include "mpark/variant.hpp" +#include "utils/bidict.h" +#include "utils/graph.h" namespace FlexFlow { -namespace substitutions { - -enum class ParameterAttribute { - OP_TYPE, // AnyOp - NUM_INPUTS, // AnyOp - NUM_OUTPUTS, // AnyOp - GROUP, // Conv2D - KERNEL_H, // Conv2D, Pool2D - KERNEL_W, // Conv2D, Pool2D - STRIDE_H, // Conv2D, Pool2D - STRIDE_W, // Conv2D, Pool2D - PADDING_H, // Conv2D, Pool2D - PADDING_W, // Conv2D, Pool2D - ACTIVATION, // Conv2D, Pool2D - NUMDIM, // Concat, Transpose - AXIS, // Concat, Split - PERM, // Transpose - OUTSHUFFLE, // Transpose - MERGE_GCONV_COUNT, // MergeGConv - AXES, // Squeeze, Unsqueeze, Reduce* - KEEP_DIMS, // Reduce* - EPSILON, // BatchNorm - REPARTITION_DIM, // Repartition - REPARTITION_DEGREE, // Repartition - REPLICATE_DIM, // Replicate - REPLICATE_DEGREE, // Replicate - COMBINE_DIM, // Combine - COMBINE_DEGREE, // Combine - REDUCTION_DIM, // Reduction - REDUCTION_DEGREE, // Reduction - SOFTMAX_DIM, // Softmax - NUM_HEADS, // MultiHeadAttention - INVALID, + +enum class ConstraintType { EQUAL }; + +enum class OperatorAttributeKey { + OP_TYPE, // AnyOp + USE_BIAS, + GROUPS, + POOL_TYPE, + KERNEL_H, + KERNEL_W, + DATA_TYPE, + SCALAR, + STRIDE_H, + STRIDE_W, + PADDING_H, + PADDING_W, + AGGR_MODE, + NUM_ENTRIES, + OUT_CHANNELS, + ACTIVATION, + NUMDIM, + AXIS, + PERMUTATION, + OUTSHUFFLE, + MERGE_GCONV_COUNT, + AXES, + KEEP_DIMS, + EPSILON, + PARALLEL_OP_DIM, + PARALLEL_OP_DEGREE, + SOFTMAX_DIM, + NUM_HEADS, PARALLEL_DIM, PARALLEL_DEGREE, PAD, }; -enum class ConstraintType { - Equal, - NotEqual, - LessThan, - LessThanEqual, - GreaterThan, - GreaterThanEqual, +template +struct ListIndexAccess { + T attribute_key; + int index; }; -struct OperatorAttributeConstraint { - ParameterAttribute key; - ConstraintType constraint; - int value; +template +struct ListSize { + T attribute_key; }; -struct TensorConstraint {}; +template +using AttributeExpr = variant, ListSize>; -struct Tensor { - int opId; - int tsId; +enum class TensorDimensionAttribute { SIZE, DEGREE }; - std::vector constraints; +struct TensorNumDimensionsConstraint { + int value; +}; +struct TensorDimensionAttributeConstraint { + TensorDimensionAttribute attribute; + int index; }; -struct OperatorConstraint { - OperatorType op_type; - std::vector inputs; - std::vector constraints; +enum class TensorAttributeKey { DIM_SIZES, DIM_DEGREES }; - tl::optional at(ParameterAttribute key) const; +using OperatorAttributeValue = variant>; +using TensorAttributeValue = variant>; + +template +struct AttributeConstraint { + ConstraintType constraint_type; + AttributeExpr attribute_expr; + V attribute_value; }; -struct MapOutput { - int dstOpId; - int dstTsId; - int srcOpId; - int srcTsId; +using TensorAttributeConstraint = + AttributeConstraint; +using OperatorAttributeConstraint = + AttributeConstraint; + +struct OperatorPattern { + std::unordered_set attribute_constraints; }; -struct Substitution { - std::string name; - std::vector srcOp; - std::vector dstOp; - std::vector mappedOutput; +struct ParallelTensorPattern { + std::unordered_set attribute_constraints; }; -struct SubstitutionCollection { - std::vector substitutions; +struct SubstitutionPattern + : public strong_typedef< + SubstitutionPattern, + LabelledOpenMultiDiGraph> { + using strong_typedef::strong_typedef; }; -} // namespace substitutions +// struct SubstitutionPattern { +// OperatorPattern at(utils::Node) const; +// ParallelTensorPattern at(PatternEdge) const; + +// MultiDiGraphPattern graph; +// utils::bidict node_map; +// utils::bidict edge_map; +// }; + +bool assignment_satisfies(SubstitutionPattern const &, + DiGraphPatternMatch const &); + } // namespace FlexFlow #endif diff --git a/lib/substitutions/include/substitutions/substitutions_old.h b/lib/substitutions/include/substitutions/substitutions_old.h new file mode 100644 index 0000000000..33ffc9704d --- /dev/null +++ b/lib/substitutions/include/substitutions/substitutions_old.h @@ -0,0 +1,101 @@ +#ifndef _FLEXFLOW_SUBSTITUTION_LOADER_H +#define _FLEXFLOW_SUBSTITUTION_LOADER_H + +#include "op-meta/op-meta.h" +#include "tl/optional.hpp" +#include +#include + +namespace FlexFlow { +namespace substitutions { + +enum class ParameterAttribute { + OP_TYPE, // AnyOp + NUM_INPUTS, // AnyOp + NUM_OUTPUTS, // AnyOp + GROUP, // Conv2D + KERNEL_H, // Conv2D, Pool2D + KERNEL_W, // Conv2D, Pool2D + STRIDE_H, // Conv2D, Pool2D + STRIDE_W, // Conv2D, Pool2D + PADDING_H, // Conv2D, Pool2D + PADDING_W, // Conv2D, Pool2D + ACTIVATION, // Conv2D, Pool2D + NUMDIM, // Concat, Transpose + AXIS, // Concat, Split + PERM, // Transpose + OUTSHUFFLE, // Transpose + MERGE_GCONV_COUNT, // MergeGConv + AXES, // Squeeze, Unsqueeze, Reduce* + KEEP_DIMS, // Reduce* + EPSILON, // BatchNorm + REPARTITION_DIM, // Repartition + REPARTITION_DEGREE, // Repartition + REPLICATE_DIM, // Replicate + REPLICATE_DEGREE, // Replicate + COMBINE_DIM, // Combine + COMBINE_DEGREE, // Combine + REDUCTION_DIM, // Reduction + REDUCTION_DEGREE, // Reduction + SOFTMAX_DIM, // Softmax + NUM_HEADS, // MultiHeadAttention + INVALID, + PARALLEL_DIM, + PARALLEL_DEGREE, + PAD, +}; + +enum class ConstraintType { + Equal, + NotEqual, + LessThan, + LessThanEqual, + GreaterThan, + GreaterThanEqual, +}; + +struct OperatorAttributeConstraint { + ParameterAttribute key; + ConstraintType constraint; + int value; +}; + +struct TensorConstraint {}; + +struct Tensor { + int opId; + int tsId; + + std::vector constraints; +}; + +struct OperatorConstraint { + OperatorType op_type; + std::vector inputs; + std::vector constraints; + + tl::optional at(ParameterAttribute key) const; +}; + +struct MapOutput { + int dstOpId; + int dstTsId; + int srcOpId; + int srcTsId; +}; + +struct Substitution { + std::string name; + std::vector srcOp; + std::vector dstOp; + std::vector mappedOutput; +}; + +struct SubstitutionCollection { + std::vector substitutions; +}; + +} // namespace substitutions +} // namespace FlexFlow + +#endif diff --git a/lib/substitutions/include/substitutions/substitutions_v2.h b/lib/substitutions/include/substitutions/substitutions_v2.h deleted file mode 100644 index 77331c19ec..0000000000 --- a/lib/substitutions/include/substitutions/substitutions_v2.h +++ /dev/null @@ -1,115 +0,0 @@ -#ifndef _FLEXFLOW_SUBSTITUTIONS_SUBSTITUTIONS_V2_H -#define _FLEXFLOW_SUBSTITUTIONS_SUBSTITUTIONS_V2_H - -#include "graph_pattern.h" -#include "mpark/variant.hpp" -#include "utils/bidict.h" -#include "utils/graph.h" - -namespace FlexFlow { -namespace substitutions { - -enum class ConstraintType { EQUAL }; - -enum class OperatorAttributeKey { - OP_TYPE, // AnyOp - USE_BIAS, - GROUPS, - POOL_TYPE, - KERNEL_H, - KERNEL_W, - DATA_TYPE, - SCALAR, - STRIDE_H, - STRIDE_W, - PADDING_H, - PADDING_W, - AGGR_MODE, - NUM_ENTRIES, - OUT_CHANNELS, - ACTIVATION, - NUMDIM, - AXIS, - PERMUTATION, - OUTSHUFFLE, - MERGE_GCONV_COUNT, - AXES, - KEEP_DIMS, - EPSILON, - PARALLEL_OP_DIM, - PARALLEL_OP_DEGREE, - SOFTMAX_DIM, - NUM_HEADS, - PARALLEL_DIM, - PARALLEL_DEGREE, - PAD, -}; - -template -struct ListIndexAccess { - T attribute_key; - int index; -}; - -template -struct ListSize { - T attribute_key; -}; - -template -using AttributeExpr = mpark::variant, ListSize>; - -enum class TensorDimensionAttribute { SIZE, DEGREE }; - -struct TensorNumDimensionsConstraint { - int value; -}; -struct TensorDimensionAttributeConstraint { - TensorDimensionAttribute attribute; - int index; -}; - -enum class TensorAttributeKey { DIM_SIZES, DIM_DEGREES }; - -using OperatorAttributeValue = - mpark::variant>; -using TensorAttributeValue = mpark::variant>; - -template -struct AttributeConstraint { - ConstraintType constraint_type; - AttributeExpr attribute_expr; - V attribute_value; -}; - -using TensorAttributeConstraint = - AttributeConstraint; -using OperatorAttributeConstraint = - AttributeConstraint; - -struct OperatorPattern { - std::unordered_set attribute_constraints; -}; - -struct ParallelTensorPattern { - std::unordered_set attribute_constraints; -}; - -struct SubstitutionPattern { - OperatorPattern at(utils::Node) const; - ParallelTensorPattern at(PatternEdge) const; - - std::unique_ptr graph; - utils::bidict node_map; - utils::bidict edge_map; -}; - -bool assignment_satisfies( - SubstitutionPattern const &, - std::unordered_map const &nodeAssignment, - std::unordered_map const &edgeAssignment); - -} // namespace substitutions -} // namespace FlexFlow - -#endif diff --git a/lib/substitutions/src/graph_pattern.cc b/lib/substitutions/src/graph_pattern.cc index 378783115a..e29bae0a92 100644 --- a/lib/substitutions/src/graph_pattern.cc +++ b/lib/substitutions/src/graph_pattern.cc @@ -3,10 +3,9 @@ #include namespace FlexFlow { -namespace substitutions { DiGraphPatternMatch narrow_match(DiGraphPatternMatch const &match, - IOpenMultiDiGraphView const &pattern) { + OpenMultiDiGraphView const &pattern) { DiGraphPatternMatch result; std::unordered_set nodes = get_nodes(pattern); for (auto const &kv : match.nodeAssignment) { @@ -27,7 +26,7 @@ DiGraphPatternMatch narrow_match(DiGraphPatternMatch const &match, return result; } -GraphSplit split_pattern(IOpenMultiDiGraphView const &pattern) { +GraphSplit split_pattern(OpenMultiDiGraphView const &pattern) { std::vector topological_ordering = get_topological_ordering(pattern); assert(topological_ordering.size() >= 2); @@ -38,52 +37,16 @@ GraphSplit split_pattern(IOpenMultiDiGraphView const &pattern) { return {prefix, postfix}; } -std::pair, - std::unique_ptr> - apply_split(IOpenMultiDiGraphView const &pattern, GraphSplit const &split) { - return {unsafe_view_as_subgraph(pattern, split.first), - unsafe_view_as_subgraph(pattern, split.second)}; +std::pair + apply_split(OpenMultiDiGraphView const &pattern, GraphSplit const &split) { + return {get_subgraph(pattern, split.first), + get_subgraph(pattern, split.second)}; } -std::unordered_set get_nodes(OpenMultiDiEdge const &pattern_edge) { - if (is_input_edge(pattern_edge)) { - return {mpark::get(pattern_edge).dst}; - } else if (is_output_edge(pattern_edge)) { - return {mpark::get(pattern_edge).src}; - } else { - assert(is_standard_edge(pattern_edge)); - auto standard_edge = mpark::get(pattern_edge); - return {standard_edge.src, standard_edge.dst}; - } -} - -bidict> - get_edge_splits(IOpenMultiDiGraphView const &pattern, - GraphSplit const &split) { - auto prefix = split.first; - auto postfix = split.second; - - bidict> result; - - for (OpenMultiDiEdge const &pattern_edge : get_edges(pattern)) { - if (!is_standard_edge(pattern_edge)) { - continue; - } - - auto standard_edge = mpark::get(pattern_edge); - if (is_subseteq_of(get_nodes(standard_edge), prefix) || - is_subseteq_of(get_nodes(standard_edge), postfix)) { - continue; - } - - auto divided = split_edge(standard_edge); - result.equate(standard_edge, divided); - } - - return result; -} - -MatchSplit apply_split(IOpenMultiDiGraphView const &pattern, +/* +Given a match and a pattern split, gets the submatches in subpatterns. +*/ +MatchSplit apply_split(OpenMultiDiGraphView const &pattern, DiGraphPatternMatch const &match, GraphSplit const &split) { auto prefix = split.first; @@ -129,13 +92,13 @@ MatchSplit apply_split(IOpenMultiDiGraphView const &pattern, return result; } -bool is_singleton_pattern(IOpenMultiDiGraphView const &pattern) { +bool is_singleton_pattern(OpenMultiDiGraphView const &pattern) { return num_nodes(pattern) == 1; } template -bool pattern_matches(IOpenMultiDiGraphView const &pattern, - IMultiDiGraph const &graph, +bool pattern_matches(OpenMultiDiGraphView const &pattern, + MultiDiGraphView const &graph, DiGraphPatternMatch const &match, F const &additional_criterion) { if (is_singleton_pattern(pattern)) { @@ -186,9 +149,9 @@ bool pattern_matches(IOpenMultiDiGraphView const &pattern, additional_criterion); } -tl::optional - get_candidate_singleton_match(IOpenMultiDiGraphView const &pattern, - IMultiDiGraphView const &graph, +optional + get_candidate_singleton_match(OpenMultiDiGraphView const &pattern, + MultiDiGraphView const &graph, Node const &graph_node) { assert(is_singleton_pattern(pattern)); @@ -295,5 +258,4 @@ std::unordered_set return matches; } -} // namespace substitutions } // namespace FlexFlow diff --git a/lib/substitutions/src/substitutions.cc b/lib/substitutions/src/substitutions.cc new file mode 100644 index 0000000000..28ee9f545a --- /dev/null +++ b/lib/substitutions/src/substitutions.cc @@ -0,0 +1,215 @@ +#include "op-attrs/operator_attrs.h" +#include "op-attrs/parallel_tensor_shape.h" +#include "substitutions/get_attribute.h" +#include "substitutions/substitutions.h" + +namespace FlexFlow { + +bool satisfies(Operator const &, + std::vector const &, + OperatorAttributeConstraint const &); + +template +optional + evaluate_list_index_access(ListIndexAccess const &index_access, + optional const &v) { + if (!v.has_value() || + !holds_alternative>(v.value())) { + return nullopt; + } + + auto vec = get>(v.value()); + if (index_access.index >= vec.size()) { + return nullopt; + } + + return vec.at(index_access.index); +} + +template +optional evaluate_list_size(optional const &v) { + if (!v.has_value() || + !holds_alternative>(v.value())) { + return nullopt; + } + + return (int)get>(v.value()).size(); +} + +struct EvaluateOperatorAttributeExpr { + EvaluateOperatorAttributeExpr(Operator const &attrs) : attrs(attrs) {} + + optional operator()(OperatorAttributeKey key) { + return get_attribute(this->attrs, key); + } + + optional + operator()(ListIndexAccess const &index_access) { + optional v = + get_attribute(this->attrs, index_access.attribute_key); + return evaluate_list_index_access(index_access, v); + } + + optional + operator()(ListSize const &list_size) { + optional v = + get_attribute(this->attrs, list_size.attribute_key); + return evaluate_list_size(v); + } + +private: + Operator attrs; +}; + +optional + evaluate_tensor_attribute_expr(ParallelTensor const &, + AttributeExpr const &); + +struct EvaluateTensorAttributeExpr { + EvaluateTensorAttributeExpr(ParallelTensor const &tensor_shape) + : tensor_shape(tensor_shape) {} + + template + optional evaluate(T const &t) { + return this->operator()(t); + } + + optional operator()(TensorAttributeKey key) { + switch (key) { + case TensorAttributeKey::DIM_SIZES: { + std::vector result; + for (ParallelDim const &dim : this->tensor_shape) { + result.push_back(dim.size); + } + return result; + } + case TensorAttributeKey::DIM_DEGREES: { + std::vector result; + for (ParallelDim const &dim : this->tensor_shape) { + result.push_back(dim.degree); + } + return result; + } + default: + throw std::runtime_error("Unknown TensorAttributeKey"); + } + } + + optional + operator()(ListIndexAccess const &index_access) { + auto v = this->evaluate(index_access.attribute_key); + return evaluate_list_index_access(index_access, v); + } + + optional + operator()(ListSize const &list_size) { + return evaluate_list_size(this->evaluate(list_size.attribute_key)); + } + +private: + ParallelTensor tensor_shape; +}; + +optional + evaluate_attribute_expr(ParallelTensor const &tensor_shape, + AttributeExpr const &expr) { + return visit(EvaluateTensorAttributeExpr(tensor_shape), expr); +} + +optional + evaluate_attribute_expr(Operator const &attrs, + AttributeExpr const &expr) { + return visit(EvaluateOperatorAttributeExpr(attrs), expr); +} + +template +optional satisfies(ConstraintType constraint_type, + V const &constraint_value, + optional const &maybe_attribute_value) { + if (!maybe_attribute_value.has_value()) { + return nullopt; + } + V attr_val = maybe_attribute_value.value(); + + if (attr_val.index() != constraint_value.index()) { + return nullopt; + } + + if (constraint_type == ConstraintType::EQUAL) { + return attr_val == constraint_value; + } else { + throw std::runtime_error("Unknown constraint_type"); + } +} + +optional satisfies(ParallelTensor const &tensor_shape, + TensorAttributeConstraint const &constraint) { + auto value = evaluate_attribute_expr(tensor_shape, constraint.attribute_expr); + return satisfies( + constraint.constraint_type, constraint.attribute_value, value); +} + +optional satisfies(Operator const ¶ms, + OperatorAttributeConstraint const &constraint) { + auto value = evaluate_attribute_expr(params, constraint.attribute_expr); + return satisfies( + constraint.constraint_type, constraint.attribute_value, value); +} + +template +optional optional_all_of(Container const &container, + Function const &func) { + for (auto const &element : container) { + optional condition = func(element); + if (!condition.has_value()) { + return nullopt; + } + + if (!condition.value()) { + return false; + } + } + return true; +} + +optional satisfies(Operator const ¶ms, + OperatorPattern const &pattern) { + return optional_all_of(pattern.attribute_constraints, + [&](OperatorAttributeConstraint const &c) { + return satisfies(params, c); + }); +} + +optional satisfies(ParallelTensor const ¶ms, + ParallelTensorPattern const &pattern) { + return optional_all_of( + pattern.attribute_constraints, + [&](TensorAttributeConstraint const &c) { return satisfies(params, c); }); +} + +bool assignment_satisfies( + ParallelComputationGraph const &pcg, + SubstitutionPattern const &pattern, + DiGraphPatternMatch const &patternMatch) { + bool result = true; + for (auto const &kv : patternMatch.nodeAssignment) { + auto patternNode = kv.first; + auto pcgNode = kv.second; + optional constraintResult = + satisfies(pcg.at(pcgNode), pattern.at(patternNode)); + result &= constraintResult.value_or(false); + } + + for (auto const &kv : patternMatch.edgeAssignment) { + auto patternEdge = kv.first; + auto pcgEdge = kv.second; + optional constraintResult = + satisfies(pcg.at(pcgEdge), pattern.at(patternEdge)); + result &= constraintResult.value_or(false); + } + + result &= pattern_matches(OpenMultiDiGraphView(pattern), MultiDiGraphView(pcg), patternMatch); + + return result; +} +} // namespace FlexFlow diff --git a/lib/utils/include/utils/graph/algorithms.h b/lib/utils/include/utils/graph/algorithms.h index 9a36b2c601..e08ef30ccd 100644 --- a/lib/utils/include/utils/graph/algorithms.h +++ b/lib/utils/include/utils/graph/algorithms.h @@ -22,6 +22,8 @@ std::vector add_nodes(Graph &, int); std::unordered_set get_nodes(GraphView const &); std::unordered_set get_node_ports(MultiDiGraphView const &); +std::unordered_set get_nodes(OpenMultiDiEdge const &); + std::unordered_set query_nodes(GraphView const &, std::unordered_set const &); diff --git a/lib/utils/src/graph/algorithms.cc b/lib/utils/src/graph/algorithms.cc index 28bfaa77b7..44335babcf 100644 --- a/lib/utils/src/graph/algorithms.cc +++ b/lib/utils/src/graph/algorithms.cc @@ -22,6 +22,18 @@ std::unordered_set get_nodes(GraphView const &g) { return g.query_nodes(NodeQuery::all()); } +std::unordered_set get_nodes(OpenMultiDiEdge const &pattern_edge) { + if (is_input_edge(pattern_edge)) { + return {mpark::get(pattern_edge).dst}; + } else if (is_output_edge(pattern_edge)) { + return {mpark::get(pattern_edge).src}; + } else { + assert(is_standard_edge(pattern_edge)); + auto standard_edge = mpark::get(pattern_edge); + return {standard_edge.src, standard_edge.dst}; + } +} + std::unordered_set query_nodes(IGraphView const &g, std::unordered_set const &nodes) { return g.query_nodes({nodes}); @@ -468,6 +480,32 @@ MultiDiEdge unsplit_edge(OutputMultiDiEdge const &output_edge, output_edge.src, input_edge.dst, output_edge.srcIdx, input_edge.dstIdx}; } +bidict> + get_edge_splits(IOpenMultiDiGraphView const &pattern, + GraphSplit const &split) { + auto prefix = split.first; + auto postfix = split.second; + + bidict> result; + + for (OpenMultiDiEdge const &pattern_edge : get_edges(pattern)) { + if (!is_standard_edge(pattern_edge)) { + continue; + } + + auto standard_edge = mpark::get(pattern_edge); + if (is_subseteq_of(get_nodes(standard_edge), prefix) || + is_subseteq_of(get_nodes(standard_edge), postfix)) { + continue; + } + + auto divided = split_edge(standard_edge); + result.equate(standard_edge, divided); + } + + return result; +} + std::unordered_set get_cut(OpenMultiDiGraphView const &g, GraphSplit const &s) { return keys(get_edge_splits(g, s)); From 6bb76df9f72a4597d1865fc269dccac080940dc9 Mon Sep 17 00:00:00 2001 From: wmdi Date: Sat, 19 Aug 2023 15:37:28 -0400 Subject: [PATCH 18/61] format --- lib/substitutions/src/substitutions.cc | 37 ++++++++++++-------------- 1 file changed, 17 insertions(+), 20 deletions(-) diff --git a/lib/substitutions/src/substitutions.cc b/lib/substitutions/src/substitutions.cc index 28ee9f545a..cb44684148 100644 --- a/lib/substitutions/src/substitutions.cc +++ b/lib/substitutions/src/substitutions.cc @@ -1,7 +1,7 @@ +#include "substitutions/substitutions.h" #include "op-attrs/operator_attrs.h" #include "op-attrs/parallel_tensor_shape.h" #include "substitutions/get_attribute.h" -#include "substitutions/substitutions.h" namespace FlexFlow { @@ -10,11 +10,9 @@ bool satisfies(Operator const &, OperatorAttributeConstraint const &); template -optional - evaluate_list_index_access(ListIndexAccess const &index_access, - optional const &v) { - if (!v.has_value() || - !holds_alternative>(v.value())) { +optional evaluate_list_index_access(ListIndexAccess const &index_access, + optional const &v) { + if (!v.has_value() || !holds_alternative>(v.value())) { return nullopt; } @@ -28,8 +26,7 @@ optional template optional evaluate_list_size(optional const &v) { - if (!v.has_value() || - !holds_alternative>(v.value())) { + if (!v.has_value() || !holds_alternative>(v.value())) { return nullopt; } @@ -124,8 +121,8 @@ optional template optional satisfies(ConstraintType constraint_type, - V const &constraint_value, - optional const &maybe_attribute_value) { + V const &constraint_value, + optional const &maybe_attribute_value) { if (!maybe_attribute_value.has_value()) { return nullopt; } @@ -143,14 +140,14 @@ optional satisfies(ConstraintType constraint_type, } optional satisfies(ParallelTensor const &tensor_shape, - TensorAttributeConstraint const &constraint) { + TensorAttributeConstraint const &constraint) { auto value = evaluate_attribute_expr(tensor_shape, constraint.attribute_expr); return satisfies( constraint.constraint_type, constraint.attribute_value, value); } optional satisfies(Operator const ¶ms, - OperatorAttributeConstraint const &constraint) { + OperatorAttributeConstraint const &constraint) { auto value = evaluate_attribute_expr(params, constraint.attribute_expr); return satisfies( constraint.constraint_type, constraint.attribute_value, value); @@ -158,7 +155,7 @@ optional satisfies(Operator const ¶ms, template optional optional_all_of(Container const &container, - Function const &func) { + Function const &func) { for (auto const &element : container) { optional condition = func(element); if (!condition.has_value()) { @@ -173,7 +170,7 @@ optional optional_all_of(Container const &container, } optional satisfies(Operator const ¶ms, - OperatorPattern const &pattern) { + OperatorPattern const &pattern) { return optional_all_of(pattern.attribute_constraints, [&](OperatorAttributeConstraint const &c) { return satisfies(params, c); @@ -181,16 +178,15 @@ optional satisfies(Operator const ¶ms, } optional satisfies(ParallelTensor const ¶ms, - ParallelTensorPattern const &pattern) { + ParallelTensorPattern const &pattern) { return optional_all_of( pattern.attribute_constraints, [&](TensorAttributeConstraint const &c) { return satisfies(params, c); }); } -bool assignment_satisfies( - ParallelComputationGraph const &pcg, - SubstitutionPattern const &pattern, - DiGraphPatternMatch const &patternMatch) { +bool assignment_satisfies(ParallelComputationGraph const &pcg, + SubstitutionPattern const &pattern, + DiGraphPatternMatch const &patternMatch) { bool result = true; for (auto const &kv : patternMatch.nodeAssignment) { auto patternNode = kv.first; @@ -208,7 +204,8 @@ bool assignment_satisfies( result &= constraintResult.value_or(false); } - result &= pattern_matches(OpenMultiDiGraphView(pattern), MultiDiGraphView(pcg), patternMatch); + result &= pattern_matches( + OpenMultiDiGraphView(pattern), MultiDiGraphView(pcg), patternMatch); return result; } From 4d5d8de45c6997a9e3259e392dcf42d0525efffc Mon Sep 17 00:00:00 2001 From: wmdi Date: Sun, 20 Aug 2023 23:29:49 -0400 Subject: [PATCH 19/61] draft substitutions --- lib/compiler/src/substitution_logic.cc | 250 ------------ .../include/substitutions/attribute_expr.h | 33 ++ .../substitutions/attribute_expr_binding.h | 27 ++ .../include/substitutions/graph_pattern.h | 32 +- .../substitutions/graph_pattern_match.h | 29 ++ .../include/substitutions/operator_pattern.h | 53 +++ .../include/substitutions/output_graph.h | 60 +++ .../substitutions/parallel_tensor_pattern.h | 32 ++ .../include/substitutions/substitution.h | 18 + .../include/substitutions/substitutions.h | 117 ------ .../include/substitutions/substitutions_old.h | 101 ----- lib/substitutions/src/graph_pattern.cc | 383 ++++++++---------- lib/substitutions/src/graph_pattern_match.cc | 261 ++++++++++++ lib/substitutions/src/operator_attributes.cc | 179 ++++---- lib/substitutions/src/substitutions.cc | 212 ---------- 15 files changed, 780 insertions(+), 1007 deletions(-) delete mode 100644 lib/compiler/src/substitution_logic.cc create mode 100644 lib/substitutions/include/substitutions/attribute_expr.h create mode 100644 lib/substitutions/include/substitutions/attribute_expr_binding.h create mode 100644 lib/substitutions/include/substitutions/graph_pattern_match.h create mode 100644 lib/substitutions/include/substitutions/operator_pattern.h create mode 100644 lib/substitutions/include/substitutions/output_graph.h create mode 100644 lib/substitutions/include/substitutions/parallel_tensor_pattern.h create mode 100644 lib/substitutions/include/substitutions/substitution.h delete mode 100644 lib/substitutions/include/substitutions/substitutions.h delete mode 100644 lib/substitutions/include/substitutions/substitutions_old.h create mode 100644 lib/substitutions/src/graph_pattern_match.cc delete mode 100644 lib/substitutions/src/substitutions.cc diff --git a/lib/compiler/src/substitution_logic.cc b/lib/compiler/src/substitution_logic.cc deleted file mode 100644 index b1554dfe5a..0000000000 --- a/lib/compiler/src/substitution_logic.cc +++ /dev/null @@ -1,250 +0,0 @@ -#include "op-attrs/operator_attrs.h" -#include "op-attrs/parallel_tensor_shape.h" -#include "substitutions/get_attribute.h" -#include "substitutions/substitutions_v2.h" - -// using namespace ::FlexFlow::opmeta; -using namespace ::FlexFlow::substitutions; - -namespace FlexFlow { -namespace ffc { - -bool satisfies(PCGOperatorAttrs const &, - std::vector const &, - OperatorAttributeConstraint const &); - -/* tl::optional satisfies(OperatorParameters const ¶ms, - * OperatorConstraint const &constraint) { */ - -/* } */ - -/* struct SatisfiesAttributeConstraint { */ -/* SatisfiesAttributeConstraint(OperatorAttributeConstraint const &constraint) - */ -/* : constraint(constraint) */ -/* { } */ - -/* tl::optional satisfies(OperatorParameters const &attrs) */ - -/* private: */ -/* OperatorAttributeConstraint constraint; */ -/* }; */ - -/* tl::optional< */ - -/* tl::optional get_attribute( */ - -/* tl::optional get_attribute(Conv2D */ - -template -tl::optional - evaluate_list_index_access(ListIndexAccess const &index_access, - tl::optional const &v) { - if (!v.has_value() || - !mpark::holds_alternative>(v.value())) { - return tl::nullopt; - } - - auto vec = mpark::get>(v.value()); - if (index_access.index >= vec.size()) { - return tl::nullopt; - } - - return vec.at(index_access.index); -} - -template -tl::optional evaluate_list_size(tl::optional const &v) { - if (!v.has_value() || - !mpark::holds_alternative>(v.value())) { - return tl::nullopt; - } - - return (int)mpark::get>(v.value()).size(); -} - -struct EvaluateOperatorAttributeExpr { - EvaluateOperatorAttributeExpr(PCGOperatorAttrs const &attrs) : attrs(attrs) {} - - tl::optional operator()(OperatorAttributeKey key) { - return get_attribute(this->attrs, key); - } - - tl::optional - operator()(ListIndexAccess const &index_access) { - tl::optional v = - get_attribute(this->attrs, index_access.attribute_key); - return evaluate_list_index_access(index_access, v); - } - - tl::optional - operator()(ListSize const &list_size) { - tl::optional v = - get_attribute(this->attrs, list_size.attribute_key); - return evaluate_list_size(v); - } - -private: - PCGOperatorAttrs attrs; -}; - -tl::optional - evaluate_tensor_attribute_expr(ParallelTensorShape const &, - AttributeExpr const &); - -struct EvaluateTensorAttributeExpr { - EvaluateTensorAttributeExpr(ParallelTensorShape const &tensor_shape) - : tensor_shape(tensor_shape) {} - - template - tl::optional evaluate(T const &t) { - return this->operator()(t); - } - - tl::optional operator()(TensorAttributeKey key) { - switch (key) { - case TensorAttributeKey::DIM_SIZES: { - std::vector result; - for (ParallelDim const &dim : this->tensor_shape) { - result.push_back(dim.size); - } - return result; - } - case TensorAttributeKey::DIM_DEGREES: { - std::vector result; - for (ParallelDim const &dim : this->tensor_shape) { - result.push_back(dim.degree); - } - return result; - } - default: - throw std::runtime_error("Unknown TensorAttributeKey"); - } - } - - tl::optional - operator()(ListIndexAccess const &index_access) { - auto v = this->evaluate(index_access.attribute_key); - return evaluate_list_index_access(index_access, v); - } - - tl::optional - operator()(ListSize const &list_size) { - return evaluate_list_size(this->evaluate(list_size.attribute_key)); - } - -private: - ParallelTensorShape tensor_shape; -}; - -tl::optional - evaluate_attribute_expr(ParallelTensorShape const &tensor_shape, - AttributeExpr const &expr) { - return mpark::visit(EvaluateTensorAttributeExpr(tensor_shape), expr); -} - -tl::optional - evaluate_attribute_expr(PCGOperatorAttrs const &attrs, - AttributeExpr const &expr) { - return mpark::visit(EvaluateOperatorAttributeExpr(attrs), expr); -} - -template -tl::optional satisfies(ConstraintType constraint_type, - V const &constraint_value, - tl::optional const &maybe_attribute_value) { - /* tl::optional maybe_attr_val = evalute_attribute_expr(attrs, - * constraint.attribute_expr); */ - - if (!maybe_attribute_value.has_value()) { - return tl::nullopt; - } - V attr_val = maybe_attribute_value.value(); - - if (attr_val.index() != constraint_value.index()) { - return tl::nullopt; - } - - if (constraint_type == ConstraintType::EQUAL) { - return attr_val == constraint_value; - } else { - throw std::runtime_error("Unknown constraint_type"); - } -} - -tl::optional satisfies(ParallelTensorShape const &tensor_shape, - TensorAttributeConstraint const &constraint) { - auto value = evaluate_attribute_expr(tensor_shape, constraint.attribute_expr); - return satisfies( - constraint.constraint_type, constraint.attribute_value, value); -} - -tl::optional satisfies(PCGOperatorAttrs const ¶ms, - OperatorAttributeConstraint const &constraint) { - auto value = evaluate_attribute_expr(params, constraint.attribute_expr); - return satisfies( - constraint.constraint_type, constraint.attribute_value, value); -} - -template -tl::optional optional_all_of(Container const &container, - Function const &func) { - for (auto const &element : container) { - tl::optional condition = func(element); - if (!condition.has_value()) { - return tl::nullopt; - } - - if (!condition.value()) { - return false; - } - } - return true; -} - -tl::optional satisfies(PCGOperatorAttrs const ¶ms, - OperatorPattern const &pattern) { - return optional_all_of(pattern.attribute_constraints, - [&](OperatorAttributeConstraint const &c) { - return satisfies(params, c); - }); -} - -tl::optional satisfies(ParallelTensorShape const ¶ms, - ParallelTensorPattern const &pattern) { - return optional_all_of( - pattern.attribute_constraints, - [&](TensorAttributeConstraint const &c) { return satisfies(params, c); }); -} - -bool assignment_satisfies( - IMultiDiGraph const &pcg, - SubstitutionPattern const &pattern, - DiGraphPatternMatch const &patternMatch, - std::unordered_map const &pcgNodeParams, - std::unordered_map const - &pcgTensorShapes) { - bool result = true; - for (auto const &kv : patternMatch.nodeAssignment) { - auto patternNode = kv.first; - auto pcgNode = kv.second; - tl::optional constraintResult = - satisfies(pcgNodeParams.at(pcgNode), pattern.at(patternNode)); - result &= constraintResult.value_or(false); - } - - for (auto const &kv : patternMatch.edgeAssignment) { - auto patternEdge = kv.first; - auto pcgEdge = kv.second; - tl::optional constraintResult = - satisfies(pcgTensorShapes.at(pcgEdge), pattern.at(patternEdge)); - result &= constraintResult.value_or(false); - } - - result &= pattern_matches(*pattern.graph, pcg, patternMatch); - - return result; -} - -} // namespace ffc -} // namespace FlexFlow diff --git a/lib/substitutions/include/substitutions/attribute_expr.h b/lib/substitutions/include/substitutions/attribute_expr.h new file mode 100644 index 0000000000..52dd6558af --- /dev/null +++ b/lib/substitutions/include/substitutions/attribute_expr.h @@ -0,0 +1,33 @@ +#ifndef _FLEXFLOW_SUBSTITUTIONS_CONSTRAINT_H +#define _FLEXFLOW_SUBSTITUTIONS_CONSTRAINT_H + +#include "mpark/variant.hpp" + +namespace FlexFlow { + +enum class ConstraintType { EQUAL }; + +template +struct ListIndexAccess { + T attribute_key; + int index; +}; + +template +struct ListSize { + T attribute_key; +}; + +template +using AttributeExpr = variant, ListSize>; + +template +struct AttributeConstraint { + ConstraintType constraint_type; + AttributeExpr attribute_expr; + V attribute_value; +}; + +} // namespace FlexFlow + +#endif diff --git a/lib/substitutions/include/substitutions/attribute_expr_binding.h b/lib/substitutions/include/substitutions/attribute_expr_binding.h new file mode 100644 index 0000000000..f719cfaf9d --- /dev/null +++ b/lib/substitutions/include/substitutions/attribute_expr_binding.h @@ -0,0 +1,27 @@ +#ifndef _FLEXFLOW_SUBSTITUTIONS_ATTRIBUTE_EXPR_BINDING_H +#define _FLEXFLOW_SUBSTITUTIONS_ATTRIBUTE_EXPR_BINDING_H + +#include "attribute_expr.h" + +namespace FlexFlow { + +struct attr_expr_id : public strong_typedef { + using strong_typedef::strong_typedef; +}; + +template +struct AttributeExprBinding { + void add_expr(attr_expr_id const &id, AttributeExpr const &expr) { + binding.emplace(id, expr); + } + + GraphAttributeExpr get_expr(attr_expr_id const &id) const { + return binding.at(id); + } +private: + std::unordered_map> binding; +}; + +} + +#endif diff --git a/lib/substitutions/include/substitutions/graph_pattern.h b/lib/substitutions/include/substitutions/graph_pattern.h index d7654e19ce..8a2bd015bc 100644 --- a/lib/substitutions/include/substitutions/graph_pattern.h +++ b/lib/substitutions/include/substitutions/graph_pattern.h @@ -1,33 +1,25 @@ -#ifndef _FLEXFLOW_SUBSTITUTIONS_GRAPH_PATTERN_H -#define _FLEXFLOW_SUBSTITUTIONS_GRAPH_PATTERN_H +#ifndef _FLEXFLOW_SUBSTITUTIONS_SUBSTITUTIONS_H +#define _FLEXFLOW_SUBSTITUTIONS_SUBSTITUTIONS_H -#include "utils/graph.h" +#include "graph_pattern_match.h" namespace FlexFlow { -struct DiGraphPatternMatch { - bidict nodeAssignment; - req> edgeAssignment; +struct GraphPattern + : public strong_typedef< + GraphPattern, + LabelledOpenMultiDiGraph> { + using strong_typedef::strong_typedef; }; -FF_VISITABLE_STRUCT(DiGraphPatternMatch, nodeAssignment, edgeAssignment); - -struct MatchSplit { - DiGraphPatternMatch prefix_submatch; - req postfix_submatch; -}; - -FF_VISITABLE_STRUCT(MatchSplit, prefix_submatch, postfix_submatch); - GraphSplit split_pattern(OpenMultiDiGraphView const &pattern); -bool pattern_matches(OpenMultiDiGraphView const &, - MultiDiGraphView const &, - DiGraphPatternMatch const &, - F const &additional_criterion); - bool is_singleton_pattern(OpenMultiDiGraphView const &); +bool assignment_satisfies(ParallelComputationGraph const &, + GraphPattern const &, + DiGraphPatternMatch const &); + } // namespace FlexFlow #endif diff --git a/lib/substitutions/include/substitutions/graph_pattern_match.h b/lib/substitutions/include/substitutions/graph_pattern_match.h new file mode 100644 index 0000000000..2e7ddd852c --- /dev/null +++ b/lib/substitutions/include/substitutions/graph_pattern_match.h @@ -0,0 +1,29 @@ +#ifndef _FLEXFLOW_SUBSTITUTIONS_GRAPH_PATTERN_H +#define _FLEXFLOW_SUBSTITUTIONS_GRAPH_PATTERN_H + +#include "utils/graph.h" + +namespace FlexFlow { + +struct DiGraphPatternMatch { + bidict nodeAssignment; + req> edgeAssignment; +}; + +FF_VISITABLE_STRUCT(DiGraphPatternMatch, nodeAssignment, edgeAssignment); + +struct MatchSplit { + DiGraphPatternMatch prefix_submatch; + req postfix_submatch; +}; + +FF_VISITABLE_STRUCT(MatchSplit, prefix_submatch, postfix_submatch); + +bool pattern_matches(OpenMultiDiGraphView const &, + MultiDiGraphView const &, + DiGraphPatternMatch const &, + F const &additional_criterion); + +} // namespace FlexFlow + +#endif diff --git a/lib/substitutions/include/substitutions/operator_pattern.h b/lib/substitutions/include/substitutions/operator_pattern.h new file mode 100644 index 0000000000..288f01970b --- /dev/null +++ b/lib/substitutions/include/substitutions/operator_pattern.h @@ -0,0 +1,53 @@ +#ifndef _FLEXFLOW_SUBSTITUTIONS_OPERATOR_PATTERN_H +#define _FLEXFLOW_SUBSTITUTIONS_OPERATOR_PATTERN_H + +#include "constraint.h" + +namespace FlexFlow { + +enum class OperatorAttributeKey { + OP_TYPE, // AnyOp + USE_BIAS, + GROUPS, + POOL_TYPE, + KERNEL_H, + KERNEL_W, + DATA_TYPE, + SCALAR, + STRIDE_H, + STRIDE_W, + PADDING_H, + PADDING_W, + AGGR_MODE, + NUM_ENTRIES, + OUT_CHANNELS, + ACTIVATION, + NUMDIM, + AXIS, + PERMUTATION, + OUTSHUFFLE, + MERGE_GCONV_COUNT, + AXES, + KEEP_DIMS, + EPSILON, + PARALLEL_OP_DIM, + PARALLEL_OP_DEGREE, + SOFTMAX_DIM, + NUM_HEADS, + PARALLEL_DIM, + PARALLEL_DEGREE, + PAD, +}; + +using OperatorAttributeValue = variant, OperatorType, Activation>; + +using OperatorAttributeConstraint = + AttributeConstraint; + +struct OperatorPattern { + std::unordered_set attribute_constraints; +}; + +} + +#endif diff --git a/lib/substitutions/include/substitutions/output_graph.h b/lib/substitutions/include/substitutions/output_graph.h new file mode 100644 index 0000000000..146c3002cf --- /dev/null +++ b/lib/substitutions/include/substitutions/output_graph.h @@ -0,0 +1,60 @@ +#ifndef _FLEXFLOW_SUBSTITUTIONS_OUTPUT_GRAPH_H +#define _FLEXFLOW_SUBSTITUTIONS_OUTPUT_GRAPH_H + +#include "utils/graph.h" + +namespace FlexFlow { + +// NOTE(@wmdi) I am not sure whether these should be part of attribute expr. +template +struct NodeAttrAccess { + Node node; + T attr_expr; +}; + +template +struct EdgeAttrAccess { + OpenMultiDiEdge edge; + T attr_expr; +}; + +enum class AttrBinaryOpType { + ADD, + SUB, + MUL, + DIV +}; + +template +struct AttrBinary { + AttrBinaryOpType op_type; + L lhs; + R rhs; +}; + +template +using GraphAttributeExpr = variant, EdgeAttrAccess>; + +template +using GraphAttributeExpr = AttrBinary; + +using GraphAttributeValue = variant, OperatorType, Activation>; + +// NOTE(@wmdi): Not sure if it aligns with other design. Or alternatively we can define the assignment for each operator type. +template +struct OperatorAttrAssignment { + std::vector>> assignment; +}; + +template +struct ParallelTensorAttrAssignment { + std::vector>> assignment; +}; + +struct OutputGraph : public strong_typedef { + using strong_typedef::strong_typedef; +}; + +} + +#endif diff --git a/lib/substitutions/include/substitutions/parallel_tensor_pattern.h b/lib/substitutions/include/substitutions/parallel_tensor_pattern.h new file mode 100644 index 0000000000..f04abfd441 --- /dev/null +++ b/lib/substitutions/include/substitutions/parallel_tensor_pattern.h @@ -0,0 +1,32 @@ +#ifndef _FLEXFLOW_SUBSTITUTIONS_TENSOR_PATTERN_H +#define _FLEXFLOW_SUBSTITUTIONS_TENSOR_PATTERN_H + +#include "constraint.h" + +namespace FlexFlow { + +enum class TensorDimensionAttribute { SIZE, DEGREE }; + +struct TensorNumDimensionsConstraint { + int value; +}; + +struct TensorDimensionAttributeConstraint { + TensorDimensionAttribute attribute; + int index; +}; + +enum class TensorAttributeKey { DIM_SIZES, DIM_DEGREES }; + +using TensorAttributeValue = variant>; + +using TensorAttributeConstraint = + AttributeConstraint; + +struct ParallelTensorPattern { + std::unordered_set attribute_constraints; +}; + +} + +#endif diff --git a/lib/substitutions/include/substitutions/substitution.h b/lib/substitutions/include/substitutions/substitution.h new file mode 100644 index 0000000000..ae68063d35 --- /dev/null +++ b/lib/substitutions/include/substitutions/substitution.h @@ -0,0 +1,18 @@ +#ifndef _FLEXFLOW_SUBSTITUTIONS_SUBSTITUTION_H +#define _FLEXFLOW_SUBSTITUTIONS_SUBSTITUTION_H + +#include "graph_pattern.h" +#include "output_graph.h" + +namespace FlexFlow { + +struct Substitution { + GraphPattern input_graph; + OutputGraph output_graph; +}; + +ParallelComputationGraph apply_substitution(ParallelComputationGraph const &, Substitution const &, DiGraphPatternMatch const &); + +} + +#endif diff --git a/lib/substitutions/include/substitutions/substitutions.h b/lib/substitutions/include/substitutions/substitutions.h deleted file mode 100644 index 528f93b355..0000000000 --- a/lib/substitutions/include/substitutions/substitutions.h +++ /dev/null @@ -1,117 +0,0 @@ -#ifndef _FLEXFLOW_SUBSTITUTIONS_SUBSTITUTIONS_V2_H -#define _FLEXFLOW_SUBSTITUTIONS_SUBSTITUTIONS_V2_H - -#include "graph_pattern.h" -#include "mpark/variant.hpp" -#include "utils/bidict.h" -#include "utils/graph.h" - -namespace FlexFlow { - -enum class ConstraintType { EQUAL }; - -enum class OperatorAttributeKey { - OP_TYPE, // AnyOp - USE_BIAS, - GROUPS, - POOL_TYPE, - KERNEL_H, - KERNEL_W, - DATA_TYPE, - SCALAR, - STRIDE_H, - STRIDE_W, - PADDING_H, - PADDING_W, - AGGR_MODE, - NUM_ENTRIES, - OUT_CHANNELS, - ACTIVATION, - NUMDIM, - AXIS, - PERMUTATION, - OUTSHUFFLE, - MERGE_GCONV_COUNT, - AXES, - KEEP_DIMS, - EPSILON, - PARALLEL_OP_DIM, - PARALLEL_OP_DEGREE, - SOFTMAX_DIM, - NUM_HEADS, - PARALLEL_DIM, - PARALLEL_DEGREE, - PAD, -}; - -template -struct ListIndexAccess { - T attribute_key; - int index; -}; - -template -struct ListSize { - T attribute_key; -}; - -template -using AttributeExpr = variant, ListSize>; - -enum class TensorDimensionAttribute { SIZE, DEGREE }; - -struct TensorNumDimensionsConstraint { - int value; -}; -struct TensorDimensionAttributeConstraint { - TensorDimensionAttribute attribute; - int index; -}; - -enum class TensorAttributeKey { DIM_SIZES, DIM_DEGREES }; - -using OperatorAttributeValue = variant>; -using TensorAttributeValue = variant>; - -template -struct AttributeConstraint { - ConstraintType constraint_type; - AttributeExpr attribute_expr; - V attribute_value; -}; - -using TensorAttributeConstraint = - AttributeConstraint; -using OperatorAttributeConstraint = - AttributeConstraint; - -struct OperatorPattern { - std::unordered_set attribute_constraints; -}; - -struct ParallelTensorPattern { - std::unordered_set attribute_constraints; -}; - -struct SubstitutionPattern - : public strong_typedef< - SubstitutionPattern, - LabelledOpenMultiDiGraph> { - using strong_typedef::strong_typedef; -}; - -// struct SubstitutionPattern { -// OperatorPattern at(utils::Node) const; -// ParallelTensorPattern at(PatternEdge) const; - -// MultiDiGraphPattern graph; -// utils::bidict node_map; -// utils::bidict edge_map; -// }; - -bool assignment_satisfies(SubstitutionPattern const &, - DiGraphPatternMatch const &); - -} // namespace FlexFlow - -#endif diff --git a/lib/substitutions/include/substitutions/substitutions_old.h b/lib/substitutions/include/substitutions/substitutions_old.h deleted file mode 100644 index 33ffc9704d..0000000000 --- a/lib/substitutions/include/substitutions/substitutions_old.h +++ /dev/null @@ -1,101 +0,0 @@ -#ifndef _FLEXFLOW_SUBSTITUTION_LOADER_H -#define _FLEXFLOW_SUBSTITUTION_LOADER_H - -#include "op-meta/op-meta.h" -#include "tl/optional.hpp" -#include -#include - -namespace FlexFlow { -namespace substitutions { - -enum class ParameterAttribute { - OP_TYPE, // AnyOp - NUM_INPUTS, // AnyOp - NUM_OUTPUTS, // AnyOp - GROUP, // Conv2D - KERNEL_H, // Conv2D, Pool2D - KERNEL_W, // Conv2D, Pool2D - STRIDE_H, // Conv2D, Pool2D - STRIDE_W, // Conv2D, Pool2D - PADDING_H, // Conv2D, Pool2D - PADDING_W, // Conv2D, Pool2D - ACTIVATION, // Conv2D, Pool2D - NUMDIM, // Concat, Transpose - AXIS, // Concat, Split - PERM, // Transpose - OUTSHUFFLE, // Transpose - MERGE_GCONV_COUNT, // MergeGConv - AXES, // Squeeze, Unsqueeze, Reduce* - KEEP_DIMS, // Reduce* - EPSILON, // BatchNorm - REPARTITION_DIM, // Repartition - REPARTITION_DEGREE, // Repartition - REPLICATE_DIM, // Replicate - REPLICATE_DEGREE, // Replicate - COMBINE_DIM, // Combine - COMBINE_DEGREE, // Combine - REDUCTION_DIM, // Reduction - REDUCTION_DEGREE, // Reduction - SOFTMAX_DIM, // Softmax - NUM_HEADS, // MultiHeadAttention - INVALID, - PARALLEL_DIM, - PARALLEL_DEGREE, - PAD, -}; - -enum class ConstraintType { - Equal, - NotEqual, - LessThan, - LessThanEqual, - GreaterThan, - GreaterThanEqual, -}; - -struct OperatorAttributeConstraint { - ParameterAttribute key; - ConstraintType constraint; - int value; -}; - -struct TensorConstraint {}; - -struct Tensor { - int opId; - int tsId; - - std::vector constraints; -}; - -struct OperatorConstraint { - OperatorType op_type; - std::vector inputs; - std::vector constraints; - - tl::optional at(ParameterAttribute key) const; -}; - -struct MapOutput { - int dstOpId; - int dstTsId; - int srcOpId; - int srcTsId; -}; - -struct Substitution { - std::string name; - std::vector srcOp; - std::vector dstOp; - std::vector mappedOutput; -}; - -struct SubstitutionCollection { - std::vector substitutions; -}; - -} // namespace substitutions -} // namespace FlexFlow - -#endif diff --git a/lib/substitutions/src/graph_pattern.cc b/lib/substitutions/src/graph_pattern.cc index e29bae0a92..59c9ed6b0b 100644 --- a/lib/substitutions/src/graph_pattern.cc +++ b/lib/substitutions/src/graph_pattern.cc @@ -1,261 +1,212 @@ -#include "substitutions/graph_pattern.h" -#include "utils/hash-utils.h" -#include +#include "op-attrs/operator_attrs.h" +#include "op-attrs/parallel_tensor_shape.h" +#include "substitutions/get_attribute.h" +#include "substitutions/substitutions.h" namespace FlexFlow { -DiGraphPatternMatch narrow_match(DiGraphPatternMatch const &match, - OpenMultiDiGraphView const &pattern) { - DiGraphPatternMatch result; - std::unordered_set nodes = get_nodes(pattern); - for (auto const &kv : match.nodeAssignment) { - Node pattern_node = kv.first; - if (contains(nodes, pattern_node)) { - result.nodeAssignment.equate(kv.first, kv.second); - } +bool satisfies(Operator const &, + std::vector const &, + OperatorAttributeConstraint const &); + +template +optional evaluate_list_index_access(ListIndexAccess const &index_access, + optional const &v) { + if (!v.has_value() || !holds_alternative>(v.value())) { + return nullopt; } - std::unordered_set edges = get_edges(pattern); - for (auto const &kv : match.edgeAssignment) { - OpenMultiDiEdge pattern_edge = kv.first; - if (contains(edges, pattern_edge)) { - result.edgeAssignment.equate(kv.first, kv.second); - } + auto vec = get>(v.value()); + if (index_access.index >= vec.size()) { + return nullopt; } - return result; + return vec.at(index_access.index); } -GraphSplit split_pattern(OpenMultiDiGraphView const &pattern) { - std::vector topological_ordering = get_topological_ordering(pattern); - assert(topological_ordering.size() >= 2); +template +optional evaluate_list_size(optional const &v) { + if (!v.has_value() || !holds_alternative>(v.value())) { + return nullopt; + } - int split_point = topological_ordering.size() / 2; - auto split = vector_split(topological_ordering, split_point); - std::unordered_set prefix(split.first.begin(), split.first.end()); - std::unordered_set postfix(split.second.begin(), split.second.end()); - return {prefix, postfix}; + return (int)get>(v.value()).size(); } -std::pair - apply_split(OpenMultiDiGraphView const &pattern, GraphSplit const &split) { - return {get_subgraph(pattern, split.first), - get_subgraph(pattern, split.second)}; -} +struct EvaluateOperatorAttributeExpr { + EvaluateOperatorAttributeExpr(Operator const &attrs) : attrs(attrs) {} -/* -Given a match and a pattern split, gets the submatches in subpatterns. -*/ -MatchSplit apply_split(OpenMultiDiGraphView const &pattern, - DiGraphPatternMatch const &match, - GraphSplit const &split) { - auto prefix = split.first; - auto postfix = split.second; - - MatchSplit result; - - for (auto const &kv : match.nodeAssignment) { - Node pattern_node = kv.first; - Node graph_node = kv.second; - if (contains(split.first, pattern_node)) { - result.prefix_submatch.nodeAssignment.equate(pattern_node, graph_node); - } else { - assert(contains(split.second, pattern_node)); - result.postfix_submatch.nodeAssignment.equate(pattern_node, graph_node); - } + optional operator()(OperatorAttributeKey key) { + return get_attribute(this->attrs, key); } - auto edge_splits = get_edge_splits(pattern, split); - - std::function handle_edge = - [&](OpenMultiDiEdge const &pattern_edge) -> void { - MultiDiEdge graph_edge = match.edgeAssignment.at_l(pattern_edge); - auto edge_nodes = get_nodes(pattern_edge); - if (is_subseteq_of(edge_nodes, prefix)) { - result.prefix_submatch.edgeAssignment.equate(pattern_edge, graph_edge); - } else if (is_subseteq_of(edge_nodes, postfix)) { - result.postfix_submatch.edgeAssignment.equate(pattern_edge, graph_edge); - } else { - assert(is_standard_edge(pattern_edge)); - auto standard_edge = mpark::get(pattern_edge); - auto divided = edge_splits.at_l(standard_edge); - handle_edge(divided.first); - handle_edge(divided.second); - } - }; + optional + operator()(ListIndexAccess const &index_access) { + optional v = + get_attribute(this->attrs, index_access.attribute_key); + return evaluate_list_index_access(index_access, v); + } - for (auto const &kv : match.edgeAssignment) { - OpenMultiDiEdge pattern_edge = kv.first; - handle_edge(pattern_edge); + optional + operator()(ListSize const &list_size) { + optional v = + get_attribute(this->attrs, list_size.attribute_key); + return evaluate_list_size(v); } - return result; -} +private: + Operator attrs; +}; -bool is_singleton_pattern(OpenMultiDiGraphView const &pattern) { - return num_nodes(pattern) == 1; -} +optional + evaluate_tensor_attribute_expr(ParallelTensor const &, + AttributeExpr const &); -template -bool pattern_matches(OpenMultiDiGraphView const &pattern, - MultiDiGraphView const &graph, - DiGraphPatternMatch const &match, - F const &additional_criterion) { - if (is_singleton_pattern(pattern)) { - Node pattern_node = get_only(get_nodes(pattern)); - Node graph_matched_node = match.nodeAssignment.at_l(pattern_node); - if (!additional_criterion(pattern_node, graph_matched_node)) { - return false; - } - for (OpenMultiDiEdge const &e : get_edges(pattern)) { - MultiDiEdge graph_matched_edge = match.edgeAssignment.at_l(e); - - assert(is_input_edge(e) || is_output_edge(e)); - if (is_input_edge(e)) { - InputMultiDiEdge input_edge = mpark::get(e); - if (match.nodeAssignment.at_l(input_edge.dst) != - graph_matched_edge.dst || - input_edge.dstIdx != graph_matched_edge.dstIdx) { - return false; - } - } else { - OutputMultiDiEdge output_edge = mpark::get(e); - if (match.nodeAssignment.at_l(output_edge.src) != - graph_matched_edge.src || - output_edge.srcIdx != graph_matched_edge.srcIdx) { - return false; +struct EvaluateTensorAttributeExpr { + EvaluateTensorAttributeExpr(ParallelTensor const &tensor_shape) + : tensor_shape(tensor_shape) {} + + template + optional evaluate(T const &t) { + return this->operator()(t); + } + + optional operator()(TensorAttributeKey key) { + switch (key) { + case TensorAttributeKey::DIM_SIZES: { + std::vector result; + for (ParallelDim const &dim : this->tensor_shape) { + result.push_back(dim.size); } + return result; } - - if (!additional_criterion(e, graph_matched_edge)) { - return false; + case TensorAttributeKey::DIM_DEGREES: { + std::vector result; + for (ParallelDim const &dim : this->tensor_shape) { + result.push_back(dim.degree); + } + return result; } + default: + throw std::runtime_error("Unknown TensorAttributeKey"); } - - return true; } - auto split = split_pattern(pattern); - auto subpatterns = apply_split(pattern, split); - auto submatches = apply_split(pattern, match, split); - - return pattern_matches(*subpatterns.first, - graph, - submatches.prefix_submatch, - additional_criterion) && - pattern_matches(*subpatterns.second, - graph, - submatches.postfix_submatch, - additional_criterion); -} + optional + operator()(ListIndexAccess const &index_access) { + auto v = this->evaluate(index_access.attribute_key); + return evaluate_list_index_access(index_access, v); + } -optional - get_candidate_singleton_match(OpenMultiDiGraphView const &pattern, - MultiDiGraphView const &graph, - Node const &graph_node) { - assert(is_singleton_pattern(pattern)); - - Node pattern_node = get_only(get_nodes(pattern)); - - DiGraphPatternMatch match; - match.nodeAssignment.equate(pattern_node, graph_node); - - auto incoming = get_incoming_edges_by_idx(graph, graph_node); - auto outgoing = get_outgoing_edges_by_idx(graph, graph_node); - for (OpenMultiDiEdge const &pattern_edge : get_edges(pattern)) { - assert(is_input_edge(pattern_edge) || is_output_edge(pattern_edge)); - if (is_input_edge(pattern_edge)) { - InputMultiDiEdge input_edge = mpark::get(pattern_edge); - if (!contains_key(incoming, input_edge.dstIdx)) { - return tl::nullopt; - } - match.edgeAssignment.equate(input_edge, - get_only(incoming.at(input_edge.dstIdx))); - } else { - OutputMultiDiEdge output_edge = - mpark::get(pattern_edge); - if (!contains_key(outgoing, output_edge.srcIdx)) { - return tl::nullopt; - } - match.edgeAssignment.equate(output_edge, - get_only(outgoing.at(output_edge.srcIdx))); - } + optional + operator()(ListSize const &list_size) { + return evaluate_list_size(this->evaluate(list_size.attribute_key)); } - return match; +private: + ParallelTensor tensor_shape; +}; + +optional + evaluate_attribute_expr(ParallelTensor const &tensor_shape, + AttributeExpr const &expr) { + return visit(EvaluateTensorAttributeExpr(tensor_shape), expr); } -tl::optional unsplit_matches( - DiGraphPatternMatch const &prefix, - DiGraphPatternMatch const &postfix, - bidict> const - &edge_splits) { - DiGraphPatternMatch result; - std::unordered_set handled; - for (auto const &kv : edge_splits) { - MultiDiEdge standard_edge = kv.first; - OutputMultiDiEdge output_edge = kv.second.first; - InputMultiDiEdge input_edge = kv.second.second; - handled.insert(output_edge); - handled.insert(input_edge); - - MultiDiEdge output_graph_edge = prefix.edgeAssignment.at_l(output_edge); - MultiDiEdge input_graph_edge = postfix.edgeAssignment.at_l(input_edge); - if (output_graph_edge == input_graph_edge) { - result.edgeAssignment.equate(standard_edge, output_graph_edge); - } else { - return tl::nullopt; - } +optional + evaluate_attribute_expr(Operator const &attrs, + AttributeExpr const &expr) { + return visit(EvaluateOperatorAttributeExpr(attrs), expr); +} + +template +optional satisfies(ConstraintType constraint_type, + V const &constraint_value, + optional const &maybe_attribute_value) { + if (!maybe_attribute_value.has_value()) { + return nullopt; } + V attr_val = maybe_attribute_value.value(); - for (auto const &kv : - merge_maps(prefix.edgeAssignment, postfix.edgeAssignment)) { - if (!contains(handled, kv.first)) { - result.edgeAssignment.equate(kv.first, kv.second); - } + if (attr_val.index() != constraint_value.index()) { + return nullopt; } - result.nodeAssignment = - merge_maps(prefix.nodeAssignment, postfix.nodeAssignment); + if (constraint_type == ConstraintType::EQUAL) { + return attr_val == constraint_value; + } else { + throw std::runtime_error("Unknown constraint_type"); + } +} - return result; +optional satisfies(ParallelTensor const &tensor_shape, + TensorAttributeConstraint const &constraint) { + auto value = evaluate_attribute_expr(tensor_shape, constraint.attribute_expr); + return satisfies( + constraint.constraint_type, constraint.attribute_value, value); } -template -std::unordered_set - find_pattern_matches(IOpenMultiDiGraphView const &pattern, - IMultiDiGraph const &graph, - F const &additional_criterion) { - std::unordered_set matches; - if (is_singleton_pattern(pattern)) { - for (Node const &graph_node : get_nodes(graph)) { - tl::optional candidate = - get_candidate_singleton_match(pattern, graph, graph_node); - if (candidate.has_value() || - pattern_matches(pattern, graph, candidate.value())) { - matches.insert(candidate.value()); - } +optional satisfies(Operator const ¶ms, + OperatorAttributeConstraint const &constraint) { + auto value = evaluate_attribute_expr(params, constraint.attribute_expr); + return satisfies( + constraint.constraint_type, constraint.attribute_value, value); +} + +template +optional optional_all_of(Container const &container, + Function const &func) { + for (auto const &element : container) { + optional condition = func(element); + if (!condition.has_value()) { + return nullopt; } - } else { - GraphSplit split = split_pattern(pattern); - auto subpatterns = apply_split(pattern, split); - auto prefix_matches = - find_pattern_matches(subpatterns.first, graph, additional_criterion); - auto postfix_matches = - find_pattern_matches(subpatterns.first, graph, additional_criterion); - auto edge_splits = get_edge_splits(pattern, split); - for (DiGraphPatternMatch const &prefix_match : prefix_matches) { - for (DiGraphPatternMatch const &postfix_match : postfix_matches) { - tl::optional unsplit = - unsplit_matches(prefix_match, postfix_match, edge_splits); - if (unsplit.has_value()) { - matches.insert(unsplit.value()); - } - } + + if (!condition.value()) { + return false; } } + return true; +} - return matches; +optional satisfies(Operator const ¶ms, + OperatorPattern const &pattern) { + return optional_all_of(pattern.attribute_constraints, + [&](OperatorAttributeConstraint const &c) { + return satisfies(params, c); + }); } +optional satisfies(ParallelTensor const ¶ms, + ParallelTensorPattern const &pattern) { + return optional_all_of( + pattern.attribute_constraints, + [&](TensorAttributeConstraint const &c) { return satisfies(params, c); }); +} + +bool assignment_satisfies(ParallelComputationGraph const &pcg, + GraphPattern const &pattern, + DiGraphPatternMatch const &patternMatch) { + bool result = true; + for (auto const &kv : patternMatch.nodeAssignment) { + auto patternNode = kv.first; + auto pcgNode = kv.second; + optional constraintResult = + satisfies(pcg.at(pcgNode), pattern.at(patternNode)); + result &= constraintResult.value_or(false); + } + + for (auto const &kv : patternMatch.edgeAssignment) { + auto patternEdge = kv.first; + auto pcgEdge = kv.second; + optional constraintResult = + satisfies(pcg.at(pcgEdge), pattern.at(patternEdge)); + result &= constraintResult.value_or(false); + } + + result &= pattern_matches( + OpenMultiDiGraphView(pattern), MultiDiGraphView(pcg), patternMatch); + + return result; +} } // namespace FlexFlow diff --git a/lib/substitutions/src/graph_pattern_match.cc b/lib/substitutions/src/graph_pattern_match.cc new file mode 100644 index 0000000000..a5c185aba0 --- /dev/null +++ b/lib/substitutions/src/graph_pattern_match.cc @@ -0,0 +1,261 @@ +#include "substitutions/graph_pattern.h" +#include "utils/hash-utils.h" +#include + +namespace FlexFlow { + +// DiGraphPatternMatch narrow_match(DiGraphPatternMatch const &match, +// OpenMultiDiGraphView const &pattern) { +// DiGraphPatternMatch result; +// std::unordered_set nodes = get_nodes(pattern); +// for (auto const &kv : match.nodeAssignment) { +// Node pattern_node = kv.first; +// if (contains(nodes, pattern_node)) { +// result.nodeAssignment.equate(kv.first, kv.second); +// } +// } + +// std::unordered_set edges = get_edges(pattern); +// for (auto const &kv : match.edgeAssignment) { +// OpenMultiDiEdge pattern_edge = kv.first; +// if (contains(edges, pattern_edge)) { +// result.edgeAssignment.equate(kv.first, kv.second); +// } +// } + +// return result; +// } + +GraphSplit split_pattern(OpenMultiDiGraphView const &pattern) { + std::vector topological_ordering = get_topological_ordering(pattern); + assert(topological_ordering.size() >= 2); + + int split_point = topological_ordering.size() / 2; + auto split = vector_split(topological_ordering, split_point); + std::unordered_set prefix(split.first.begin(), split.first.end()); + std::unordered_set postfix(split.second.begin(), split.second.end()); + return {prefix, postfix}; +} + +std::pair + apply_split(OpenMultiDiGraphView const &pattern, GraphSplit const &split) { + return {get_subgraph(pattern, split.first), + get_subgraph(pattern, split.second)}; +} + +/* +Given a match and a pattern split, gets the submatches in subpatterns. +*/ +MatchSplit apply_split(OpenMultiDiGraphView const &pattern, + DiGraphPatternMatch const &match, + GraphSplit const &split) { + auto prefix = split.first; + auto postfix = split.second; + + MatchSplit result; + + for (auto const &kv : match.nodeAssignment) { + Node pattern_node = kv.first; + Node graph_node = kv.second; + if (contains(split.first, pattern_node)) { + result.prefix_submatch.nodeAssignment.equate(pattern_node, graph_node); + } else { + assert(contains(split.second, pattern_node)); + result.postfix_submatch.nodeAssignment.equate(pattern_node, graph_node); + } + } + + auto edge_splits = get_edge_splits(pattern, split); + + std::function handle_edge = + [&](OpenMultiDiEdge const &pattern_edge) -> void { + MultiDiEdge graph_edge = match.edgeAssignment.at_l(pattern_edge); + auto edge_nodes = get_nodes(pattern_edge); + if (is_subseteq_of(edge_nodes, prefix)) { + result.prefix_submatch.edgeAssignment.equate(pattern_edge, graph_edge); + } else if (is_subseteq_of(edge_nodes, postfix)) { + result.postfix_submatch.edgeAssignment.equate(pattern_edge, graph_edge); + } else { + assert(is_standard_edge(pattern_edge)); + auto standard_edge = mpark::get(pattern_edge); + auto divided = edge_splits.at_l(standard_edge); + handle_edge(divided.first); + handle_edge(divided.second); + } + }; + + for (auto const &kv : match.edgeAssignment) { + OpenMultiDiEdge pattern_edge = kv.first; + handle_edge(pattern_edge); + } + + return result; +} + +bool is_singleton_pattern(OpenMultiDiGraphView const &pattern) { + return num_nodes(pattern) == 1; +} + +template +bool pattern_matches(OpenMultiDiGraphView const &pattern, + MultiDiGraphView const &graph, + DiGraphPatternMatch const &match, + F const &additional_criterion) { + if (is_singleton_pattern(pattern)) { + Node pattern_node = get_only(get_nodes(pattern)); + Node graph_matched_node = match.nodeAssignment.at_l(pattern_node); + if (!additional_criterion(pattern_node, graph_matched_node)) { + return false; + } + for (OpenMultiDiEdge const &e : get_edges(pattern)) { + MultiDiEdge graph_matched_edge = match.edgeAssignment.at_l(e); + + assert(is_input_edge(e) || is_output_edge(e)); + if (is_input_edge(e)) { + InputMultiDiEdge input_edge = mpark::get(e); + if (match.nodeAssignment.at_l(input_edge.dst) != + graph_matched_edge.dst || + input_edge.dstIdx != graph_matched_edge.dstIdx) { + return false; + } + } else { + OutputMultiDiEdge output_edge = mpark::get(e); + if (match.nodeAssignment.at_l(output_edge.src) != + graph_matched_edge.src || + output_edge.srcIdx != graph_matched_edge.srcIdx) { + return false; + } + } + + if (!additional_criterion(e, graph_matched_edge)) { + return false; + } + } + + return true; + } + + auto split = split_pattern(pattern); + auto subpatterns = apply_split(pattern, split); + auto submatches = apply_split(pattern, match, split); + + return pattern_matches(subpatterns.first, + graph, + submatches.prefix_submatch, + additional_criterion) && + pattern_matches(subpatterns.second, + graph, + submatches.postfix_submatch, + additional_criterion); +} + +optional + get_candidate_singleton_match(OpenMultiDiGraphView const &pattern, + MultiDiGraphView const &graph, + Node const &graph_node) { + assert(is_singleton_pattern(pattern)); + + Node pattern_node = get_only(get_nodes(pattern)); + + DiGraphPatternMatch match; + match.nodeAssignment.equate(pattern_node, graph_node); + + auto incoming = get_incoming_edges_by_idx(graph, graph_node); + auto outgoing = get_outgoing_edges_by_idx(graph, graph_node); + for (OpenMultiDiEdge const &pattern_edge : get_edges(pattern)) { + assert(is_input_edge(pattern_edge) || is_output_edge(pattern_edge)); + if (is_input_edge(pattern_edge)) { + InputMultiDiEdge input_edge = mpark::get(pattern_edge); + if (!contains_key(incoming, input_edge.dstIdx)) { + return nullopt; + } + match.edgeAssignment.equate(input_edge, + get_only(incoming.at(input_edge.dstIdx))); + } else { + OutputMultiDiEdge output_edge = + mpark::get(pattern_edge); + if (!contains_key(outgoing, output_edge.srcIdx)) { + return nullopt; + } + match.edgeAssignment.equate(output_edge, + get_only(outgoing.at(output_edge.srcIdx))); + } + } + + return match; +} + +optional unsplit_matches( + DiGraphPatternMatch const &prefix, + DiGraphPatternMatch const &postfix, + bidict> const + &edge_splits) { + DiGraphPatternMatch result; + std::unordered_set handled; + for (auto const &kv : edge_splits) { + MultiDiEdge standard_edge = kv.first; + OutputMultiDiEdge output_edge = kv.second.first; + InputMultiDiEdge input_edge = kv.second.second; + handled.insert(output_edge); + handled.insert(input_edge); + + MultiDiEdge output_graph_edge = prefix.edgeAssignment.at_l(output_edge); + MultiDiEdge input_graph_edge = postfix.edgeAssignment.at_l(input_edge); + if (output_graph_edge == input_graph_edge) { + result.edgeAssignment.equate(standard_edge, output_graph_edge); + } else { + return nullopt; + } + } + + for (auto const &kv : + merge_maps(prefix.edgeAssignment, postfix.edgeAssignment)) { + if (!contains(handled, kv.first)) { + result.edgeAssignment.equate(kv.first, kv.second); + } + } + + result.nodeAssignment = + merge_maps(prefix.nodeAssignment, postfix.nodeAssignment); + + return result; +} + +template +std::unordered_set + find_pattern_matches(OpenMultiDiGraphView const &pattern, + MultiDiGraphView const &graph, + F const &additional_criterion) { + std::unordered_set matches; + if (is_singleton_pattern(pattern)) { + for (Node const &graph_node : get_nodes(graph)) { + optional candidate = + get_candidate_singleton_match(pattern, graph, graph_node); + if (candidate.has_value() || + pattern_matches(pattern, graph, candidate.value())) { + matches.insert(candidate.value()); + } + } + } else { + GraphSplit split = split_pattern(pattern); + auto subpatterns = apply_split(pattern, split); + auto prefix_matches = + find_pattern_matches(subpatterns.first, graph, additional_criterion); + auto postfix_matches = + find_pattern_matches(subpatterns.first, graph, additional_criterion); + auto edge_splits = get_edge_splits(pattern, split); + for (DiGraphPatternMatch const &prefix_match : prefix_matches) { + for (DiGraphPatternMatch const &postfix_match : postfix_matches) { + optional unsplit = + unsplit_matches(prefix_match, postfix_match, edge_splits); + if (unsplit.has_value()) { + matches.insert(unsplit.value()); + } + } + } + } + + return matches; +} + +} // namespace FlexFlow diff --git a/lib/substitutions/src/operator_attributes.cc b/lib/substitutions/src/operator_attributes.cc index 84b5139e10..3adb3393ef 100644 --- a/lib/substitutions/src/operator_attributes.cc +++ b/lib/substitutions/src/operator_attributes.cc @@ -1,67 +1,65 @@ #include "substitutions/get_attribute.h" -#include "substitutions/substitutions_v2.h" namespace FlexFlow { -namespace substitutions { -tl::optional get_attribute(AggregateAttrs const &p, - OperatorAttributeKey key) { +optional get_attribute(AggregateAttrs const &p, + OperatorAttributeKey key) { switch (key) { default: - return tl::nullopt; + return nullopt; } } -tl::optional get_attribute(AggregateSpecAttrs const &p, - OperatorAttributeKey key) { +optional get_attribute(AggregateSpecAttrs const &p, + OperatorAttributeKey key) { switch (key) { default: - return tl::nullopt; + return nullopt; } } -tl::optional get_attribute(BatchMatmulAttrs const &p, - OperatorAttributeKey key) { +optional get_attribute(BatchMatmulAttrs const &p, + OperatorAttributeKey key) { switch (key) { default: - return tl::nullopt; + return nullopt; } } -tl::optional get_attribute(CastAttrs const &p, - OperatorAttributeKey key) { +optional get_attribute(CastAttrs const &p, + OperatorAttributeKey key) { switch (key) { case OperatorAttributeKey::DATA_TYPE: return p.dtype; default: - return tl::nullopt; + return nullopt; } } -tl::optional get_attribute(CombineAttrs const &p, - OperatorAttributeKey key) { +optional get_attribute(CombineAttrs const &p, + OperatorAttributeKey key) { switch (key) { case OperatorAttributeKey::PARALLEL_OP_DIM: return p.combine_legion_dim; case OperatorAttributeKey::PARALLEL_DIM: return p.combine_degree; default: - return tl::nullopt; + return nullopt; } } -tl::optional get_attribute(ConcatAttrs const &p, - OperatorAttributeKey key) { +optional get_attribute(ConcatAttrs const &p, + OperatorAttributeKey key) { switch (key) { case OperatorAttributeKey::AXIS: return p.axis; default: - return tl::nullopt; + return nullopt; } } -tl::optional get_attribute(Conv2DAttrs const &p, - OperatorAttributeKey key) { +optional get_attribute(Conv2DAttrs const &p, + OperatorAttributeKey key) { switch (key) { case OperatorAttributeKey::KERNEL_H: return p.kernel_h; @@ -82,38 +80,38 @@ tl::optional get_attribute(Conv2DAttrs const &p, case OperatorAttributeKey::USE_BIAS: return p.use_bias; default: - return tl::nullopt; + return nullopt; } } -tl::optional get_attribute(ElementBinaryAttrs const &p, - OperatorAttributeKey key) { +optional get_attribute(ElementBinaryAttrs const &p, + OperatorAttributeKey key) { switch (key) { default: - return tl::nullopt; + return nullopt; } } -tl::optional get_attribute(ElementUnaryAttrs const &p, - OperatorAttributeKey key) { +optional get_attribute(ElementUnaryAttrs const &p, + OperatorAttributeKey key) { switch (key) { case OperatorAttributeKey::SCALAR: return p.scalar; default: - return tl::nullopt; + return nullopt; } } -tl::optional get_attribute(DropoutAttrs const &p, - OperatorAttributeKey key) { +optional get_attribute(DropoutAttrs const &p, + OperatorAttributeKey key) { switch (key) { default: - return tl::nullopt; + return nullopt; } } -tl::optional get_attribute(EmbeddingAttrs const &p, - OperatorAttributeKey key) { +optional get_attribute(EmbeddingAttrs const &p, + OperatorAttributeKey key) { switch (key) { case OperatorAttributeKey::DATA_TYPE: return p.data_type; @@ -124,46 +122,46 @@ tl::optional get_attribute(EmbeddingAttrs const &p, case OperatorAttributeKey::OUT_CHANNELS: return p.out_channels; default: - return tl::nullopt; + return nullopt; } } -tl::optional get_attribute(FlatAttrs const &p, - OperatorAttributeKey key) { +optional get_attribute(FlatAttrs const &p, + OperatorAttributeKey key) { switch (key) { default: - return tl::nullopt; + return nullopt; } } -tl::optional get_attribute(GatherAttrs const &p, - OperatorAttributeKey key) { +optional get_attribute(GatherAttrs const &p, + OperatorAttributeKey key) { switch (key) { case OperatorAttributeKey::AXIS: return p.legion_dim; default: - return tl::nullopt; + return nullopt; } } -tl::optional get_attribute(Group_byAttrs const &p, - OperatorAttributeKey key) { +optional get_attribute(Group_byAttrs const &p, + OperatorAttributeKey key) { switch (key) { default: - return tl::nullopt; + return nullopt; } } -tl::optional get_attribute(LayerNormAttrs const &p, - OperatorAttributeKey key) { +optional get_attribute(LayerNormAttrs const &p, + OperatorAttributeKey key) { switch (key) { default: - return tl::nullopt; + return nullopt; } } -tl::optional get_attribute(LinearAttrs const &p, - OperatorAttributeKey key) { +optional get_attribute(LinearAttrs const &p, + OperatorAttributeKey key) { switch (key) { case OperatorAttributeKey::OUT_CHANNELS: return p.out_channels; @@ -174,24 +172,24 @@ tl::optional get_attribute(LinearAttrs const &p, case OperatorAttributeKey::ACTIVATION: return p.activation; default: - return tl::nullopt; + return nullopt; } } -tl::optional - get_attribute(MultiHeadAttentionAttrs const &p, OperatorAttributeKey key) { +optional get_attribute(MultiHeadAttentionAttrs const &p, + OperatorAttributeKey key) { switch (key) { case OperatorAttributeKey::NUM_HEADS: return p.num_heads; case OperatorAttributeKey::USE_BIAS: return p.bias; default: - return tl::nullopt; + return nullopt; } } -tl::optional get_attribute(Pool2DAttrs const &p, - OperatorAttributeKey key) { +optional get_attribute(Pool2DAttrs const &p, + OperatorAttributeKey key) { switch (key) { case OperatorAttributeKey::KERNEL_H: return p.kernel_h; @@ -210,105 +208,105 @@ tl::optional get_attribute(Pool2DAttrs const &p, case OperatorAttributeKey::ACTIVATION: return p.activation; default: - return tl::nullopt; + return nullopt; } } -tl::optional get_attribute(ReduceAttrs const &p, - OperatorAttributeKey key) { +optional get_attribute(ReduceAttrs const &p, + OperatorAttributeKey key) { switch (key) { default: - return tl::nullopt; + return nullopt; } } -tl::optional get_attribute(ReductionAttrs const &p, - OperatorAttributeKey key) { +optional get_attribute(ReductionAttrs const &p, + OperatorAttributeKey key) { switch (key) { case OperatorAttributeKey::PARALLEL_OP_DIM: return p.reduction_legion_dim; case OperatorAttributeKey::PARALLEL_OP_DEGREE: return p.reduction_degree; default: - return tl::nullopt; + return nullopt; } } -tl::optional get_attribute(RepartitionAttrs const &p, - OperatorAttributeKey key) { +optional get_attribute(RepartitionAttrs const &p, + OperatorAttributeKey key) { switch (key) { case OperatorAttributeKey::PARALLEL_OP_DIM: return p.repartition_legion_dim; case OperatorAttributeKey::PARALLEL_OP_DEGREE: return p.repartition_degree; default: - return tl::nullopt; + return nullopt; } } -tl::optional get_attribute(ReplicateAttrs const &p, - OperatorAttributeKey key) { +optional get_attribute(ReplicateAttrs const &p, + OperatorAttributeKey key) { switch (key) { case OperatorAttributeKey::PARALLEL_OP_DIM: return p.replicate_legion_dim; case OperatorAttributeKey::PARALLEL_OP_DEGREE: return p.replicate_degree; default: - return tl::nullopt; + return nullopt; } } -tl::optional get_attribute(ReshapeAttrs const &p, - OperatorAttributeKey key) { +optional get_attribute(ReshapeAttrs const &p, + OperatorAttributeKey key) { switch (key) { default: - return tl::nullopt; + return nullopt; } } -tl::optional get_attribute(SplitAttrs const &p, - OperatorAttributeKey key) { +optional get_attribute(SplitAttrs const &p, + OperatorAttributeKey key) { switch (key) { case OperatorAttributeKey::AXIS: return p.legion_axis; default: - return tl::nullopt; + return nullopt; } } -tl::optional get_attribute(SoftmaxAttrs const &p, - OperatorAttributeKey key) { +optional get_attribute(SoftmaxAttrs const &p, + OperatorAttributeKey key) { switch (key) { case OperatorAttributeKey::AXIS: return p.dim; default: - return tl::nullopt; + return nullopt; } } -tl::optional get_attribute(TopKAttrs const &p, - OperatorAttributeKey key) { +optional get_attribute(TopKAttrs const &p, + OperatorAttributeKey key) { switch (key) { default: - return tl::nullopt; + return nullopt; } } -tl::optional get_attribute(TransposeAttrs const &p, - OperatorAttributeKey key) { +optional get_attribute(TransposeAttrs const &p, + OperatorAttributeKey key) { switch (key) { case OperatorAttributeKey::PERMUTATION: return p.perm; default: - return tl::nullopt; + return nullopt; } } -tl::optional - get_attribute(FusedParallelOpAttrs const &p, OperatorAttributeKey key) { +optional get_attribute(FusedParallelOpAttrs const &p, + OperatorAttributeKey key) { switch (key) { default: - return tl::nullopt; + return nullopt; } } @@ -316,7 +314,7 @@ struct GetAttribute { GetAttribute(OperatorAttributeKey key) : key(key) {} template - tl::optional operator()(T const &t) { + optional operator()(T const &t) { return get_attribute(t, this->key); } @@ -324,10 +322,9 @@ struct GetAttribute { OperatorAttributeKey key; }; -tl::optional get_attribute(PCGOperatorAttrs const &p, - OperatorAttributeKey key) { - return mpark::visit(GetAttribute(key), p); +optional get_attribute(PCGOperatorAttrs const &p, + OperatorAttributeKey key) { + return visit(GetAttribute(key), p); } -} // namespace substitutions } // namespace FlexFlow diff --git a/lib/substitutions/src/substitutions.cc b/lib/substitutions/src/substitutions.cc deleted file mode 100644 index cb44684148..0000000000 --- a/lib/substitutions/src/substitutions.cc +++ /dev/null @@ -1,212 +0,0 @@ -#include "substitutions/substitutions.h" -#include "op-attrs/operator_attrs.h" -#include "op-attrs/parallel_tensor_shape.h" -#include "substitutions/get_attribute.h" - -namespace FlexFlow { - -bool satisfies(Operator const &, - std::vector const &, - OperatorAttributeConstraint const &); - -template -optional evaluate_list_index_access(ListIndexAccess const &index_access, - optional const &v) { - if (!v.has_value() || !holds_alternative>(v.value())) { - return nullopt; - } - - auto vec = get>(v.value()); - if (index_access.index >= vec.size()) { - return nullopt; - } - - return vec.at(index_access.index); -} - -template -optional evaluate_list_size(optional const &v) { - if (!v.has_value() || !holds_alternative>(v.value())) { - return nullopt; - } - - return (int)get>(v.value()).size(); -} - -struct EvaluateOperatorAttributeExpr { - EvaluateOperatorAttributeExpr(Operator const &attrs) : attrs(attrs) {} - - optional operator()(OperatorAttributeKey key) { - return get_attribute(this->attrs, key); - } - - optional - operator()(ListIndexAccess const &index_access) { - optional v = - get_attribute(this->attrs, index_access.attribute_key); - return evaluate_list_index_access(index_access, v); - } - - optional - operator()(ListSize const &list_size) { - optional v = - get_attribute(this->attrs, list_size.attribute_key); - return evaluate_list_size(v); - } - -private: - Operator attrs; -}; - -optional - evaluate_tensor_attribute_expr(ParallelTensor const &, - AttributeExpr const &); - -struct EvaluateTensorAttributeExpr { - EvaluateTensorAttributeExpr(ParallelTensor const &tensor_shape) - : tensor_shape(tensor_shape) {} - - template - optional evaluate(T const &t) { - return this->operator()(t); - } - - optional operator()(TensorAttributeKey key) { - switch (key) { - case TensorAttributeKey::DIM_SIZES: { - std::vector result; - for (ParallelDim const &dim : this->tensor_shape) { - result.push_back(dim.size); - } - return result; - } - case TensorAttributeKey::DIM_DEGREES: { - std::vector result; - for (ParallelDim const &dim : this->tensor_shape) { - result.push_back(dim.degree); - } - return result; - } - default: - throw std::runtime_error("Unknown TensorAttributeKey"); - } - } - - optional - operator()(ListIndexAccess const &index_access) { - auto v = this->evaluate(index_access.attribute_key); - return evaluate_list_index_access(index_access, v); - } - - optional - operator()(ListSize const &list_size) { - return evaluate_list_size(this->evaluate(list_size.attribute_key)); - } - -private: - ParallelTensor tensor_shape; -}; - -optional - evaluate_attribute_expr(ParallelTensor const &tensor_shape, - AttributeExpr const &expr) { - return visit(EvaluateTensorAttributeExpr(tensor_shape), expr); -} - -optional - evaluate_attribute_expr(Operator const &attrs, - AttributeExpr const &expr) { - return visit(EvaluateOperatorAttributeExpr(attrs), expr); -} - -template -optional satisfies(ConstraintType constraint_type, - V const &constraint_value, - optional const &maybe_attribute_value) { - if (!maybe_attribute_value.has_value()) { - return nullopt; - } - V attr_val = maybe_attribute_value.value(); - - if (attr_val.index() != constraint_value.index()) { - return nullopt; - } - - if (constraint_type == ConstraintType::EQUAL) { - return attr_val == constraint_value; - } else { - throw std::runtime_error("Unknown constraint_type"); - } -} - -optional satisfies(ParallelTensor const &tensor_shape, - TensorAttributeConstraint const &constraint) { - auto value = evaluate_attribute_expr(tensor_shape, constraint.attribute_expr); - return satisfies( - constraint.constraint_type, constraint.attribute_value, value); -} - -optional satisfies(Operator const ¶ms, - OperatorAttributeConstraint const &constraint) { - auto value = evaluate_attribute_expr(params, constraint.attribute_expr); - return satisfies( - constraint.constraint_type, constraint.attribute_value, value); -} - -template -optional optional_all_of(Container const &container, - Function const &func) { - for (auto const &element : container) { - optional condition = func(element); - if (!condition.has_value()) { - return nullopt; - } - - if (!condition.value()) { - return false; - } - } - return true; -} - -optional satisfies(Operator const ¶ms, - OperatorPattern const &pattern) { - return optional_all_of(pattern.attribute_constraints, - [&](OperatorAttributeConstraint const &c) { - return satisfies(params, c); - }); -} - -optional satisfies(ParallelTensor const ¶ms, - ParallelTensorPattern const &pattern) { - return optional_all_of( - pattern.attribute_constraints, - [&](TensorAttributeConstraint const &c) { return satisfies(params, c); }); -} - -bool assignment_satisfies(ParallelComputationGraph const &pcg, - SubstitutionPattern const &pattern, - DiGraphPatternMatch const &patternMatch) { - bool result = true; - for (auto const &kv : patternMatch.nodeAssignment) { - auto patternNode = kv.first; - auto pcgNode = kv.second; - optional constraintResult = - satisfies(pcg.at(pcgNode), pattern.at(patternNode)); - result &= constraintResult.value_or(false); - } - - for (auto const &kv : patternMatch.edgeAssignment) { - auto patternEdge = kv.first; - auto pcgEdge = kv.second; - optional constraintResult = - satisfies(pcg.at(pcgEdge), pattern.at(patternEdge)); - result &= constraintResult.value_or(false); - } - - result &= pattern_matches( - OpenMultiDiGraphView(pattern), MultiDiGraphView(pcg), patternMatch); - - return result; -} -} // namespace FlexFlow From fc807b4b2c7818a05a4d0e4220f25ef712b68265 Mon Sep 17 00:00:00 2001 From: wmdi Date: Sun, 20 Aug 2023 23:30:14 -0400 Subject: [PATCH 20/61] format --- .../substitutions/attribute_expr_binding.h | 3 ++- .../include/substitutions/operator_pattern.h | 5 +++-- .../include/substitutions/output_graph.h | 21 +++++++++---------- .../substitutions/parallel_tensor_pattern.h | 2 +- .../include/substitutions/substitution.h | 6 ++++-- 5 files changed, 20 insertions(+), 17 deletions(-) diff --git a/lib/substitutions/include/substitutions/attribute_expr_binding.h b/lib/substitutions/include/substitutions/attribute_expr_binding.h index f719cfaf9d..831e32b6a9 100644 --- a/lib/substitutions/include/substitutions/attribute_expr_binding.h +++ b/lib/substitutions/include/substitutions/attribute_expr_binding.h @@ -18,10 +18,11 @@ struct AttributeExprBinding { GraphAttributeExpr get_expr(attr_expr_id const &id) const { return binding.at(id); } + private: std::unordered_map> binding; }; -} +} // namespace FlexFlow #endif diff --git a/lib/substitutions/include/substitutions/operator_pattern.h b/lib/substitutions/include/substitutions/operator_pattern.h index 288f01970b..11d5659bf3 100644 --- a/lib/substitutions/include/substitutions/operator_pattern.h +++ b/lib/substitutions/include/substitutions/operator_pattern.h @@ -39,7 +39,8 @@ enum class OperatorAttributeKey { PAD, }; -using OperatorAttributeValue = variant, OperatorType, Activation>; +using OperatorAttributeValue = + variant, OperatorType, Activation>; using OperatorAttributeConstraint = AttributeConstraint; @@ -48,6 +49,6 @@ struct OperatorPattern { std::unordered_set attribute_constraints; }; -} +} // namespace FlexFlow #endif diff --git a/lib/substitutions/include/substitutions/output_graph.h b/lib/substitutions/include/substitutions/output_graph.h index 146c3002cf..3e6cf52f6b 100644 --- a/lib/substitutions/include/substitutions/output_graph.h +++ b/lib/substitutions/include/substitutions/output_graph.h @@ -18,12 +18,7 @@ struct EdgeAttrAccess { T attr_expr; }; -enum class AttrBinaryOpType { - ADD, - SUB, - MUL, - DIV -}; +enum class AttrBinaryOpType { ADD, SUB, MUL, DIV }; template struct AttrBinary { @@ -38,12 +33,15 @@ using GraphAttributeExpr = variant, EdgeAttrAccess>; template using GraphAttributeExpr = AttrBinary; -using GraphAttributeValue = variant, OperatorType, Activation>; +using GraphAttributeValue = + variant, OperatorType, Activation>; -// NOTE(@wmdi): Not sure if it aligns with other design. Or alternatively we can define the assignment for each operator type. +// NOTE(@wmdi): Not sure if it aligns with other design. Or alternatively we can +// define the assignment for each operator type. template struct OperatorAttrAssignment { - std::vector>> assignment; + std::vector>> + assignment; }; template @@ -51,10 +49,11 @@ struct ParallelTensorAttrAssignment { std::vector>> assignment; }; -struct OutputGraph : public strong_typedef { +struct OutputGraph : public strong_typedef { using strong_typedef::strong_typedef; }; -} +} // namespace FlexFlow #endif diff --git a/lib/substitutions/include/substitutions/parallel_tensor_pattern.h b/lib/substitutions/include/substitutions/parallel_tensor_pattern.h index f04abfd441..9a1aac0603 100644 --- a/lib/substitutions/include/substitutions/parallel_tensor_pattern.h +++ b/lib/substitutions/include/substitutions/parallel_tensor_pattern.h @@ -27,6 +27,6 @@ struct ParallelTensorPattern { std::unordered_set attribute_constraints; }; -} +} // namespace FlexFlow #endif diff --git a/lib/substitutions/include/substitutions/substitution.h b/lib/substitutions/include/substitutions/substitution.h index ae68063d35..c69d65ca36 100644 --- a/lib/substitutions/include/substitutions/substitution.h +++ b/lib/substitutions/include/substitutions/substitution.h @@ -11,8 +11,10 @@ struct Substitution { OutputGraph output_graph; }; -ParallelComputationGraph apply_substitution(ParallelComputationGraph const &, Substitution const &, DiGraphPatternMatch const &); +ParallelComputationGraph apply_substitution(ParallelComputationGraph const &, + Substitution const &, + DiGraphPatternMatch const &); -} +} // namespace FlexFlow #endif From 55e8de37f4efe322f9d673b52cf106d09f9a9056 Mon Sep 17 00:00:00 2001 From: wmdi Date: Mon, 21 Aug 2023 16:13:47 -0400 Subject: [PATCH 21/61] further draft substitution --- .../include/substitutions/operator_pattern.h | 4 + .../include/substitutions/output_graph.h | 27 +++- .../substitutions/parallel_tensor_pattern.h | 4 + .../include/substitutions/substitution.h | 2 + lib/substitutions/src/graph_pattern.cc | 3 +- lib/substitutions/src/substitution.cc | 116 ++++++++++++++++++ 6 files changed, 149 insertions(+), 7 deletions(-) create mode 100644 lib/substitutions/src/substitution.cc diff --git a/lib/substitutions/include/substitutions/operator_pattern.h b/lib/substitutions/include/substitutions/operator_pattern.h index 11d5659bf3..bba237fc47 100644 --- a/lib/substitutions/include/substitutions/operator_pattern.h +++ b/lib/substitutions/include/substitutions/operator_pattern.h @@ -49,6 +49,10 @@ struct OperatorPattern { std::unordered_set attribute_constraints; }; +optional + evaluate_attribute_expr(Operator const &attrs, + AttributeExpr const &expr); + } // namespace FlexFlow #endif diff --git a/lib/substitutions/include/substitutions/output_graph.h b/lib/substitutions/include/substitutions/output_graph.h index 3e6cf52f6b..617facef05 100644 --- a/lib/substitutions/include/substitutions/output_graph.h +++ b/lib/substitutions/include/substitutions/output_graph.h @@ -18,20 +18,32 @@ struct EdgeAttrAccess { T attr_expr; }; -enum class AttrBinaryOpType { ADD, SUB, MUL, DIV }; +enum class AttrOpType { ADD, SUB, MUL, DIV }; + +template +struct AttrConstant { + T value; +}; +template +struct AttrUnary { + AttrOpType op_type; + L lhs; + R rhs; +}; template struct AttrBinary { - AttrBinaryOpType op_type; + AttrOpType op_type; L lhs; R rhs; }; template -using GraphAttributeExpr = variant, EdgeAttrAccess>; +using GraphAttributeExpr = + variant, EdgeAttrAccess, AttrConstant>; template -using GraphAttributeExpr = AttrBinary; +using GraphAttributeExpr = variant, AttrBinary>; using GraphAttributeValue = variant, OperatorType, Activation>; @@ -49,8 +61,11 @@ struct ParallelTensorAttrAssignment { std::vector>> assignment; }; -struct OutputGraph : public strong_typedef { +struct OutputGraph + : public strong_typedef< + OutputGraph, + OutputLabelledMultiDiGraph> { using strong_typedef::strong_typedef; }; diff --git a/lib/substitutions/include/substitutions/parallel_tensor_pattern.h b/lib/substitutions/include/substitutions/parallel_tensor_pattern.h index 9a1aac0603..f7a341b6a2 100644 --- a/lib/substitutions/include/substitutions/parallel_tensor_pattern.h +++ b/lib/substitutions/include/substitutions/parallel_tensor_pattern.h @@ -27,6 +27,10 @@ struct ParallelTensorPattern { std::unordered_set attribute_constraints; }; +optional + evaluate_attribute_expr(ParallelTensor const &tensor_shape, + AttributeExpr const &expr); + } // namespace FlexFlow #endif diff --git a/lib/substitutions/include/substitutions/substitution.h b/lib/substitutions/include/substitutions/substitution.h index c69d65ca36..2ab1dc998a 100644 --- a/lib/substitutions/include/substitutions/substitution.h +++ b/lib/substitutions/include/substitutions/substitution.h @@ -9,6 +9,8 @@ namespace FlexFlow { struct Substitution { GraphPattern input_graph; OutputGraph output_graph; + bidict input_mapping; + bidict output_mapping; }; ParallelComputationGraph apply_substitution(ParallelComputationGraph const &, diff --git a/lib/substitutions/src/graph_pattern.cc b/lib/substitutions/src/graph_pattern.cc index 59c9ed6b0b..8dd16c393a 100644 --- a/lib/substitutions/src/graph_pattern.cc +++ b/lib/substitutions/src/graph_pattern.cc @@ -1,7 +1,8 @@ #include "op-attrs/operator_attrs.h" #include "op-attrs/parallel_tensor_shape.h" #include "substitutions/get_attribute.h" -#include "substitutions/substitutions.h" +#include "substitutions/operator_pattern.h" +#include "substitutions/parallel_tensor_pattern.h" namespace FlexFlow { diff --git a/lib/substitutions/src/substitution.cc b/lib/substitutions/src/substitution.cc new file mode 100644 index 0000000000..64ac7feffd --- /dev/null +++ b/lib/substitutions/src/substitution.cc @@ -0,0 +1,116 @@ +#include "substitutions/substitution.h" + +namespace FlexFlow { + +template +GraphAttributeValue graph_attribute_value_op(AttrOpType op, T const &lhs, T const &rhs) { + switch (op) { + case AttrOpType::ADD: + return lhs + rhs; + break; + case AttrOpType::SUB: + return lhs - rhs; + break; + case AttrOpType::MUL: + return lhs * rhs; + break; + case AttrOpType::DIV: + return lhs / rhs; + break; + default: + mk_runtime_error("Unknown attribute operator type"); + } +} + +struct EvaluateGraphAttributeExpr { + template + GraphAttributeValue operator()(Ts... const &ts) { + return evaluate(ts); + } + + template + GraphAttributeValue evaluate(NodeAttrAccess const &t) { + Node node_in_pattern = t.node; + Node node_in_pcg = match.nodeAssignment.at_l(node_in_pattern); + return evaluate_attribute_expr(node_in_pcg, t.attr_expr); + } + + template + GraphAttributeValue evaluate(EdgeAttrAccess const &t) { + OpenMultiDiEdge edge_in_pattern = t.edge; + MultiDiEdge edge_in_pcg = match.edgeAssignment.at_l(edge_in_pattern); + return evaluate_attribute_expr(edge_in_pcg, t.attr_expr); + } + + template + GraphAttributeValue evaluate(AttrUnary const &t) { + auto lhs = (*this)(t.lhs).value(); + auto rhs = t.rhs; + return graph_attribute_value_op(lhs, rhs); + } + + template + GraphAttributeValue evaluate(AttrBinary const &t) { + auto lhs = (*this)(t.lhs).value(); + auto rhs = (*this)(t.rhs).value(); + return graph_attribute_value_op(lhs, rhs); + } + + EvaluateGraphAttributeExpr(ParallelComputationGraph const &graph, DiGraphPatternMatch const &match) : graph(graph), match(match) {} + + ParallelComputationGraph const &graph; + DiGraphPatternMatch const &match; +}; + +template +GraphAttributeValue evaluate_graph_attribute_expr(ParallelComputationGraph const &graph, DiGraphPatternMatch const &match, GraphAttributeExpr const &expr) { + return visit(EvaluateGraphAttributeExpr(graph, match), expr); +} + +Operator get_operator_attrs(ParallelComputationGraph const &graph, DiGraphPatternMatch const &match, OperatorAttrAssignment const &assignment) { + NOT_IMPLEMENTED(); +} + +ParallelTensor get_parallel_tensor_attrs(ParallelComputationGraph const &graph, DiGraphPatternMatch const &match, ParallelTensorAttrAssignment const &assignment) { + NOT_IMPLEMENTED(); +} + +ParallelComputationGraph apply_substitution(ParallelComputationGraph const &pcg, + Substitution const &substitution, + DiGraphPatternMatch const &match) { + ParallelComputationGraph new_pcg = ParallelComputationGraph::create(); + bidict node_mapping; // Refactor it with global nodes + for (Node const &node : get_nodes(pcg)) { + if (!contains_r(match.nodeAssignment)) { + node_mapping.equate(node, new_pcg.add_node(pcg.at(node))); + } + } + for (MultiDiEdge const &edge : get_edges(pcg)) { + if (!contains_r(match.edgeAssignment)) { + new_pcg.add_edge(MultiDiEdge{node_mapping.at_l(edge.src), node_mapping.at_r(edge.dst), new_pcg.add_node_port(), new_pcg.add_node_port()}); + } + } + for (Node const &output_node : get_nodes(substitution.output_graph)) { + Node new_node = new_pcg.add_node(get_operator_attrs(pcg, match, substitution.output_graph.at(output_node))); + node_mapping.equate(output_node, new_node); + } + for (OpenMultiDiEdge const &output_edge : get_edges(substitution.output_graph)) { + if (holds_alternative(output_edge)) { + MultiDiEdge origin_edge = match.edgeAssignment.at_r(substitution.input_mapping.at_r(output_edge)); + new_pcg.add_edge(MultiDiEdge{node_mapping.at_l(origin_edge.src), node_mapping.at_l(output_edge.dst), new_pcg.add_node_port(), new_pcg.add_node_port()}); + } else if (holds_alternative(output_edge)) { + MultiDiEdge origin_edge = match.edgeAssignment.at_r(substitution.output_mapping.at_r(output_edge)); + new_pcg.add_edge(MultiDiEdge{node_mapping.at_l(output_edge.src), node_mapping.at_l(origin_edge.dst), new_pcg.add_node_port(), new_pcg.add_node_port()}); + } else { + assert(holds_alternative(output_edge)); + new_pcg.add_edge(MultiDiEdge{node_mapping.at_l(output_edge.src), node_mapping.at_l(output_edge.dst), new_pcg.add_node_port(), new_pcg.add_node_port()}); + } + } + for (MultiDiOutput const &output : get_outputs(substitution.output_graph)) { + new_pcg.add_output(MultiDiOutput{node_mapping.at_l(output.src), new_pcg.add_node_port()}, get_parallel_tensor_attrs(pcg, match, substitution.output_graph.at(output))); + } + + return new_pcg; +} + +} \ No newline at end of file From 4f0e4d271657c749f8f6c37bfd3105bbb20fc50d Mon Sep 17 00:00:00 2001 From: wmdi Date: Mon, 21 Aug 2023 17:50:35 -0400 Subject: [PATCH 22/61] format --- .../include/substitutions/output_graph.h | 2 +- lib/substitutions/src/substitution.cc | 81 +++++++++++++------ 2 files changed, 57 insertions(+), 26 deletions(-) diff --git a/lib/substitutions/include/substitutions/output_graph.h b/lib/substitutions/include/substitutions/output_graph.h index 617facef05..73231210fc 100644 --- a/lib/substitutions/include/substitutions/output_graph.h +++ b/lib/substitutions/include/substitutions/output_graph.h @@ -65,7 +65,7 @@ struct OutputGraph : public strong_typedef< OutputGraph, OutputLabelledMultiDiGraph> { + ParallelTensorAttrAssignment>> { using strong_typedef::strong_typedef; }; diff --git a/lib/substitutions/src/substitution.cc b/lib/substitutions/src/substitution.cc index 64ac7feffd..94c12af9b3 100644 --- a/lib/substitutions/src/substitution.cc +++ b/lib/substitutions/src/substitution.cc @@ -3,22 +3,23 @@ namespace FlexFlow { template -GraphAttributeValue graph_attribute_value_op(AttrOpType op, T const &lhs, T const &rhs) { +GraphAttributeValue + graph_attribute_value_op(AttrOpType op, T const &lhs, T const &rhs) { switch (op) { case AttrOpType::ADD: - return lhs + rhs; - break; + return lhs + rhs; + break; case AttrOpType::SUB: - return lhs - rhs; - break; + return lhs - rhs; + break; case AttrOpType::MUL: - return lhs * rhs; - break; + return lhs * rhs; + break; case AttrOpType::DIV: - return lhs / rhs; - break; + return lhs / rhs; + break; default: - mk_runtime_error("Unknown attribute operator type"); + mk_runtime_error("Unknown attribute operator type"); } } @@ -56,29 +57,40 @@ struct EvaluateGraphAttributeExpr { return graph_attribute_value_op(lhs, rhs); } - EvaluateGraphAttributeExpr(ParallelComputationGraph const &graph, DiGraphPatternMatch const &match) : graph(graph), match(match) {} + EvaluateGraphAttributeExpr(ParallelComputationGraph const &graph, + DiGraphPatternMatch const &match) + : graph(graph), match(match) {} ParallelComputationGraph const &graph; DiGraphPatternMatch const &match; }; template -GraphAttributeValue evaluate_graph_attribute_expr(ParallelComputationGraph const &graph, DiGraphPatternMatch const &match, GraphAttributeExpr const &expr) { +GraphAttributeValue + evaluate_graph_attribute_expr(ParallelComputationGraph const &graph, + DiGraphPatternMatch const &match, + GraphAttributeExpr const &expr) { return visit(EvaluateGraphAttributeExpr(graph, match), expr); } -Operator get_operator_attrs(ParallelComputationGraph const &graph, DiGraphPatternMatch const &match, OperatorAttrAssignment const &assignment) { +Operator get_operator_attrs(ParallelComputationGraph const &graph, + DiGraphPatternMatch const &match, + OperatorAttrAssignment const &assignment) { NOT_IMPLEMENTED(); } -ParallelTensor get_parallel_tensor_attrs(ParallelComputationGraph const &graph, DiGraphPatternMatch const &match, ParallelTensorAttrAssignment const &assignment) { +ParallelTensor + get_parallel_tensor_attrs(ParallelComputationGraph const &graph, + DiGraphPatternMatch const &match, + ParallelTensorAttrAssignment const &assignment) { NOT_IMPLEMENTED(); } ParallelComputationGraph apply_substitution(ParallelComputationGraph const &pcg, Substitution const &substitution, DiGraphPatternMatch const &match) { - ParallelComputationGraph new_pcg = ParallelComputationGraph::create(); + ParallelComputationGraph new_pcg = + ParallelComputationGraph::create(); bidict node_mapping; // Refactor it with global nodes for (Node const &node : get_nodes(pcg)) { if (!contains_r(match.nodeAssignment)) { @@ -87,30 +99,49 @@ ParallelComputationGraph apply_substitution(ParallelComputationGraph const &pcg, } for (MultiDiEdge const &edge : get_edges(pcg)) { if (!contains_r(match.edgeAssignment)) { - new_pcg.add_edge(MultiDiEdge{node_mapping.at_l(edge.src), node_mapping.at_r(edge.dst), new_pcg.add_node_port(), new_pcg.add_node_port()}); + new_pcg.add_edge(MultiDiEdge{node_mapping.at_l(edge.src), + node_mapping.at_r(edge.dst), + new_pcg.add_node_port(), + new_pcg.add_node_port()}); } } for (Node const &output_node : get_nodes(substitution.output_graph)) { - Node new_node = new_pcg.add_node(get_operator_attrs(pcg, match, substitution.output_graph.at(output_node))); + Node new_node = new_pcg.add_node(get_operator_attrs( + pcg, match, substitution.output_graph.at(output_node))); node_mapping.equate(output_node, new_node); } - for (OpenMultiDiEdge const &output_edge : get_edges(substitution.output_graph)) { + for (OpenMultiDiEdge const &output_edge : + get_edges(substitution.output_graph)) { if (holds_alternative(output_edge)) { - MultiDiEdge origin_edge = match.edgeAssignment.at_r(substitution.input_mapping.at_r(output_edge)); - new_pcg.add_edge(MultiDiEdge{node_mapping.at_l(origin_edge.src), node_mapping.at_l(output_edge.dst), new_pcg.add_node_port(), new_pcg.add_node_port()}); + MultiDiEdge origin_edge = match.edgeAssignment.at_r( + substitution.input_mapping.at_r(output_edge)); + new_pcg.add_edge(MultiDiEdge{node_mapping.at_l(origin_edge.src), + node_mapping.at_l(output_edge.dst), + new_pcg.add_node_port(), + new_pcg.add_node_port()}); } else if (holds_alternative(output_edge)) { - MultiDiEdge origin_edge = match.edgeAssignment.at_r(substitution.output_mapping.at_r(output_edge)); - new_pcg.add_edge(MultiDiEdge{node_mapping.at_l(output_edge.src), node_mapping.at_l(origin_edge.dst), new_pcg.add_node_port(), new_pcg.add_node_port()}); + MultiDiEdge origin_edge = match.edgeAssignment.at_r( + substitution.output_mapping.at_r(output_edge)); + new_pcg.add_edge(MultiDiEdge{node_mapping.at_l(output_edge.src), + node_mapping.at_l(origin_edge.dst), + new_pcg.add_node_port(), + new_pcg.add_node_port()}); } else { assert(holds_alternative(output_edge)); - new_pcg.add_edge(MultiDiEdge{node_mapping.at_l(output_edge.src), node_mapping.at_l(output_edge.dst), new_pcg.add_node_port(), new_pcg.add_node_port()}); + new_pcg.add_edge(MultiDiEdge{node_mapping.at_l(output_edge.src), + node_mapping.at_l(output_edge.dst), + new_pcg.add_node_port(), + new_pcg.add_node_port()}); } } for (MultiDiOutput const &output : get_outputs(substitution.output_graph)) { - new_pcg.add_output(MultiDiOutput{node_mapping.at_l(output.src), new_pcg.add_node_port()}, get_parallel_tensor_attrs(pcg, match, substitution.output_graph.at(output))); + new_pcg.add_output( + MultiDiOutput{node_mapping.at_l(output.src), new_pcg.add_node_port()}, + get_parallel_tensor_attrs( + pcg, match, substitution.output_graph.at(output))); } return new_pcg; } -} \ No newline at end of file +} // namespace FlexFlow From cc5837bb9e0544ead9e01f06e9cc6e0af6a9e04b Mon Sep 17 00:00:00 2001 From: wmdi Date: Wed, 23 Aug 2023 17:34:31 -0400 Subject: [PATCH 23/61] minor fix --- .../include/substitutions/attribute_expr.h | 2 +- .../substitutions/attribute_expr_binding.h | 1 + .../include/substitutions/get_attribute.h | 12 ++--- .../include/substitutions/graph_pattern.h | 3 ++ .../substitutions/graph_pattern_match.h | 14 +++--- .../include/substitutions/operator_pattern.h | 7 ++- .../include/substitutions/output_graph.h | 47 ++++++++----------- .../substitutions/parallel_tensor_pattern.h | 3 +- lib/substitutions/src/graph_pattern.cc | 6 +-- lib/utils/include/utils/graph/algorithms.h | 6 ++- 10 files changed, 50 insertions(+), 51 deletions(-) diff --git a/lib/substitutions/include/substitutions/attribute_expr.h b/lib/substitutions/include/substitutions/attribute_expr.h index 52dd6558af..ee1c6dedd2 100644 --- a/lib/substitutions/include/substitutions/attribute_expr.h +++ b/lib/substitutions/include/substitutions/attribute_expr.h @@ -1,7 +1,7 @@ #ifndef _FLEXFLOW_SUBSTITUTIONS_CONSTRAINT_H #define _FLEXFLOW_SUBSTITUTIONS_CONSTRAINT_H -#include "mpark/variant.hpp" +#include "utils/variant.h" namespace FlexFlow { diff --git a/lib/substitutions/include/substitutions/attribute_expr_binding.h b/lib/substitutions/include/substitutions/attribute_expr_binding.h index 831e32b6a9..ff5303d171 100644 --- a/lib/substitutions/include/substitutions/attribute_expr_binding.h +++ b/lib/substitutions/include/substitutions/attribute_expr_binding.h @@ -2,6 +2,7 @@ #define _FLEXFLOW_SUBSTITUTIONS_ATTRIBUTE_EXPR_BINDING_H #include "attribute_expr.h" +#include "utils/strong_typedef.h" namespace FlexFlow { diff --git a/lib/substitutions/include/substitutions/get_attribute.h b/lib/substitutions/include/substitutions/get_attribute.h index 0a4d5b99fd..9bc25947ba 100644 --- a/lib/substitutions/include/substitutions/get_attribute.h +++ b/lib/substitutions/include/substitutions/get_attribute.h @@ -1,9 +1,9 @@ -#ifndef _FLEXFLOW_SUBSTITUTIONS_OPERATOR_ATTRIBUTES_H -#define _FLEXFLOW_SUBSTITUTIONS_OPERATOR_ATTRIBUTES_H +#ifndef _FLEXFLOW_SUBSTITUTIONS_GET_ATTRIBUTES_H +#define _FLEXFLOW_SUBSTITUTIONS_GET_ATTRIBUTES_H #include "op-attrs/operator_attrs.h" -#include "substitutions/substitutions.h" -#include "tl/optional.hpp" +#include "operator_pattern.h" +#include "utils/optional.h" namespace FlexFlow { @@ -67,8 +67,8 @@ optional get_attribute(TopKAttrs const &p, OperatorAttributeKey); optional get_attribute(TransposeAttrs const &p, OperatorAttributeKey); -optional get_attribute(FusedParallelOpAttrs const &p, - OperatorAttributeKey); +// optional get_attribute(FusedParallelOpAttrs const &p, +// OperatorAttributeKey); } // namespace FlexFlow diff --git a/lib/substitutions/include/substitutions/graph_pattern.h b/lib/substitutions/include/substitutions/graph_pattern.h index 8a2bd015bc..05b6b18053 100644 --- a/lib/substitutions/include/substitutions/graph_pattern.h +++ b/lib/substitutions/include/substitutions/graph_pattern.h @@ -2,6 +2,9 @@ #define _FLEXFLOW_SUBSTITUTIONS_SUBSTITUTIONS_H #include "graph_pattern_match.h" +#include "operator_pattern.h" +#include "parallel_tensor_pattern.h" +#include "pcg/parallel_computation_graph.h" namespace FlexFlow { diff --git a/lib/substitutions/include/substitutions/graph_pattern_match.h b/lib/substitutions/include/substitutions/graph_pattern_match.h index 2e7ddd852c..449c26c846 100644 --- a/lib/substitutions/include/substitutions/graph_pattern_match.h +++ b/lib/substitutions/include/substitutions/graph_pattern_match.h @@ -1,24 +1,22 @@ -#ifndef _FLEXFLOW_SUBSTITUTIONS_GRAPH_PATTERN_H -#define _FLEXFLOW_SUBSTITUTIONS_GRAPH_PATTERN_H +#ifndef _FLEXFLOW_SUBSTITUTIONS_GRAPH_PATTERN_MATCH_H +#define _FLEXFLOW_SUBSTITUTIONS_GRAPH_PATTERN_MATCH_H #include "utils/graph.h" +#include "utils/visitable.h" namespace FlexFlow { struct DiGraphPatternMatch { bidict nodeAssignment; - req> edgeAssignment; + bidict edgeAssignment; }; -FF_VISITABLE_STRUCT(DiGraphPatternMatch, nodeAssignment, edgeAssignment); - struct MatchSplit { DiGraphPatternMatch prefix_submatch; - req postfix_submatch; + DiGraphPatternMatch postfix_submatch; }; -FF_VISITABLE_STRUCT(MatchSplit, prefix_submatch, postfix_submatch); - +template bool pattern_matches(OpenMultiDiGraphView const &, MultiDiGraphView const &, DiGraphPatternMatch const &, diff --git a/lib/substitutions/include/substitutions/operator_pattern.h b/lib/substitutions/include/substitutions/operator_pattern.h index bba237fc47..b2dda4f638 100644 --- a/lib/substitutions/include/substitutions/operator_pattern.h +++ b/lib/substitutions/include/substitutions/operator_pattern.h @@ -1,7 +1,12 @@ #ifndef _FLEXFLOW_SUBSTITUTIONS_OPERATOR_PATTERN_H #define _FLEXFLOW_SUBSTITUTIONS_OPERATOR_PATTERN_H -#include "constraint.h" +#include "attribute_expr.h" +#include "op-attrs/activation.h" +#include "op-attrs/op.h" +#include "pcg/operator.h" +#include +#include namespace FlexFlow { diff --git a/lib/substitutions/include/substitutions/output_graph.h b/lib/substitutions/include/substitutions/output_graph.h index 73231210fc..b9db236390 100644 --- a/lib/substitutions/include/substitutions/output_graph.h +++ b/lib/substitutions/include/substitutions/output_graph.h @@ -5,60 +5,53 @@ namespace FlexFlow { +using GraphAttributeKey = variant; +using GraphAttributeValue = + variant, OperatorType, Activation>; + // NOTE(@wmdi) I am not sure whether these should be part of attribute expr. -template struct NodeAttrAccess { Node node; - T attr_expr; + GraphAttributeKey attr_expr; }; -template struct EdgeAttrAccess { OpenMultiDiEdge edge; - T attr_expr; + GraphAttributeKey attr_expr; }; -enum class AttrOpType { ADD, SUB, MUL, DIV }; - -template struct AttrConstant { - T value; + GraphAttributeValue value; }; -template + +using GraphAttributeExprLeaf = + variant; + +enum class AttrOpType { ADD, SUB, MUL, DIV }; + struct AttrUnary { AttrOpType op_type; - L lhs; - R rhs; + GraphAttributeExprLeaf lhs; + GraphAttributeExprLeaf rhs; }; -template struct AttrBinary { AttrOpType op_type; - L lhs; - R rhs; + GraphAttributeExprLeaf lhs; + GraphAttributeExprLeaf rhs; }; -template using GraphAttributeExpr = - variant, EdgeAttrAccess, AttrConstant>; - -template -using GraphAttributeExpr = variant, AttrBinary>; - -using GraphAttributeValue = - variant, OperatorType, Activation>; + variant; // NOTE(@wmdi): Not sure if it aligns with other design. Or alternatively we can // define the assignment for each operator type. -template struct OperatorAttrAssignment { - std::vector>> - assignment; + std::vector> assignment; }; -template struct ParallelTensorAttrAssignment { - std::vector>> assignment; + std::vector> assignment; }; struct OutputGraph diff --git a/lib/substitutions/include/substitutions/parallel_tensor_pattern.h b/lib/substitutions/include/substitutions/parallel_tensor_pattern.h index f7a341b6a2..c62237d0fd 100644 --- a/lib/substitutions/include/substitutions/parallel_tensor_pattern.h +++ b/lib/substitutions/include/substitutions/parallel_tensor_pattern.h @@ -1,7 +1,8 @@ #ifndef _FLEXFLOW_SUBSTITUTIONS_TENSOR_PATTERN_H #define _FLEXFLOW_SUBSTITUTIONS_TENSOR_PATTERN_H -#include "constraint.h" +#include "attribute_expr.h" +#include "pcg/parallel_tensor.h" namespace FlexFlow { diff --git a/lib/substitutions/src/graph_pattern.cc b/lib/substitutions/src/graph_pattern.cc index 8dd16c393a..8d8a8d1d52 100644 --- a/lib/substitutions/src/graph_pattern.cc +++ b/lib/substitutions/src/graph_pattern.cc @@ -6,10 +6,6 @@ namespace FlexFlow { -bool satisfies(Operator const &, - std::vector const &, - OperatorAttributeConstraint const &); - template optional evaluate_list_index_access(ListIndexAccess const &index_access, optional const &v) { @@ -37,7 +33,7 @@ optional evaluate_list_size(optional const &v) { struct EvaluateOperatorAttributeExpr { EvaluateOperatorAttributeExpr(Operator const &attrs) : attrs(attrs) {} - optional operator()(OperatorAttributeKey key) { + optional operator()(OperatorAttributeKey const &key) { return get_attribute(this->attrs, key); } diff --git a/lib/utils/include/utils/graph/algorithms.h b/lib/utils/include/utils/graph/algorithms.h index e08ef30ccd..c83fcf0fdb 100644 --- a/lib/utils/include/utils/graph/algorithms.h +++ b/lib/utils/include/utils/graph/algorithms.h @@ -20,6 +20,7 @@ namespace FlexFlow { std::vector add_nodes(Graph &, int); std::unordered_set get_nodes(GraphView const &); +std::unordered_set get_nodes(OpenMultiDiGraphView const &); std::unordered_set get_node_ports(MultiDiGraphView const &); std::unordered_set get_nodes(OpenMultiDiEdge const &); @@ -66,6 +67,7 @@ UndirectedGraphView apply_contraction(UndirectedGraphView const &, std::unordered_map const &); std::size_t num_nodes(GraphView const &); +std::size_t num_nodes(OpenMultiDiGraphView const &); bool empty(GraphView const &); void add_edges(MultiDiGraph &, std::vector const &); @@ -113,9 +115,9 @@ std::unordered_set get_incoming_edges(MultiDiGraphView const &, std::unordered_set get_incoming_edges(DiGraphView const &, std::unordered_set const &); -std::unordered_map> +std::unordered_map> get_incoming_edges_by_idx(MultiDiGraphView const &, Node const &); -std::unordered_map> +std::unordered_map> get_outgoing_edges_by_idx(MultiDiGraphView const &, Node const &); std::unordered_set get_outgoing_edges(MultiDiGraphView const &, From d1aa92f6e9a6517a4fdee551b51f79eb53b52f3e Mon Sep 17 00:00:00 2001 From: wmdi Date: Fri, 25 Aug 2023 16:32:54 -0400 Subject: [PATCH 24/61] refactor the pattern graph to be OutputLabelledOpenMultiDiGraph --- .../include/substitutions/graph_pattern.h | 3 +- .../include/substitutions/output_graph.h | 11 +- .../include/substitutions/substitution.h | 2 +- lib/substitutions/src/graph_pattern.cc | 10 +- lib/substitutions/src/substitution.cc | 109 +++++++++++------- .../graph/labelled/output_labelled_open.h | 17 +++ .../include/utils/graph/labelled_graphs.h | 1 + 7 files changed, 98 insertions(+), 55 deletions(-) create mode 100644 lib/utils/include/utils/graph/labelled/output_labelled_open.h diff --git a/lib/substitutions/include/substitutions/graph_pattern.h b/lib/substitutions/include/substitutions/graph_pattern.h index 05b6b18053..e2054f1a4f 100644 --- a/lib/substitutions/include/substitutions/graph_pattern.h +++ b/lib/substitutions/include/substitutions/graph_pattern.h @@ -11,7 +11,8 @@ namespace FlexFlow { struct GraphPattern : public strong_typedef< GraphPattern, - LabelledOpenMultiDiGraph> { + OutputLabelledOpenMultiDiGraph> { using strong_typedef::strong_typedef; }; diff --git a/lib/substitutions/include/substitutions/output_graph.h b/lib/substitutions/include/substitutions/output_graph.h index b9db236390..7b32ca9900 100644 --- a/lib/substitutions/include/substitutions/output_graph.h +++ b/lib/substitutions/include/substitutions/output_graph.h @@ -5,19 +5,18 @@ namespace FlexFlow { -using GraphAttributeKey = variant; using GraphAttributeValue = variant, OperatorType, Activation>; // NOTE(@wmdi) I am not sure whether these should be part of attribute expr. struct NodeAttrAccess { Node node; - GraphAttributeKey attr_expr; + AttributeExpr attr_expr; }; struct EdgeAttrAccess { OpenMultiDiEdge edge; - GraphAttributeKey attr_expr; + AttributeExpr attr_expr; }; struct AttrConstant { @@ -32,7 +31,7 @@ enum class AttrOpType { ADD, SUB, MUL, DIV }; struct AttrUnary { AttrOpType op_type; GraphAttributeExprLeaf lhs; - GraphAttributeExprLeaf rhs; + GraphAttributeValue rhs; }; struct AttrBinary { @@ -57,8 +56,8 @@ struct ParallelTensorAttrAssignment { struct OutputGraph : public strong_typedef< OutputGraph, - OutputLabelledMultiDiGraph> { + OutputLabelledOpenMultiDiGraph> { using strong_typedef::strong_typedef; }; diff --git a/lib/substitutions/include/substitutions/substitution.h b/lib/substitutions/include/substitutions/substitution.h index 2ab1dc998a..55820da33f 100644 --- a/lib/substitutions/include/substitutions/substitution.h +++ b/lib/substitutions/include/substitutions/substitution.h @@ -9,7 +9,7 @@ namespace FlexFlow { struct Substitution { GraphPattern input_graph; OutputGraph output_graph; - bidict input_mapping; + bidict input_mapping; bidict output_mapping; }; diff --git a/lib/substitutions/src/graph_pattern.cc b/lib/substitutions/src/graph_pattern.cc index 8d8a8d1d52..9b7529ad8a 100644 --- a/lib/substitutions/src/graph_pattern.cc +++ b/lib/substitutions/src/graph_pattern.cc @@ -1,6 +1,9 @@ +#include "substitutions/graph_pattern.h" #include "op-attrs/operator_attrs.h" #include "op-attrs/parallel_tensor_shape.h" +#include "pcg/parallel_computation_graph.h" #include "substitutions/get_attribute.h" +#include "substitutions/graph_pattern_match.h" #include "substitutions/operator_pattern.h" #include "substitutions/parallel_tensor_pattern.h" @@ -72,14 +75,14 @@ struct EvaluateTensorAttributeExpr { switch (key) { case TensorAttributeKey::DIM_SIZES: { std::vector result; - for (ParallelDim const &dim : this->tensor_shape) { + for (ParallelDim const &dim : this->tensor_shape.dims) { result.push_back(dim.size); } return result; } case TensorAttributeKey::DIM_DEGREES: { std::vector result; - for (ParallelDim const &dim : this->tensor_shape) { + for (ParallelDim const &dim : this->tensor_shape.dims) { result.push_back(dim.degree); } return result; @@ -201,8 +204,7 @@ bool assignment_satisfies(ParallelComputationGraph const &pcg, result &= constraintResult.value_or(false); } - result &= pattern_matches( - OpenMultiDiGraphView(pattern), MultiDiGraphView(pcg), patternMatch); + result &= pattern_matches(pattern, pcg, patternMatch); return result; } diff --git a/lib/substitutions/src/substitution.cc b/lib/substitutions/src/substitution.cc index 94c12af9b3..c5cc870a7a 100644 --- a/lib/substitutions/src/substitution.cc +++ b/lib/substitutions/src/substitution.cc @@ -1,60 +1,84 @@ #include "substitutions/substitution.h" +#include namespace FlexFlow { -template -GraphAttributeValue - graph_attribute_value_op(AttrOpType op, T const &lhs, T const &rhs) { - switch (op) { - case AttrOpType::ADD: - return lhs + rhs; - break; - case AttrOpType::SUB: - return lhs - rhs; - break; - case AttrOpType::MUL: - return lhs * rhs; - break; - case AttrOpType::DIV: - return lhs / rhs; - break; - default: - mk_runtime_error("Unknown attribute operator type"); +struct GraphAttributeValueOp { + template + GraphAttributeValue operator()(T const &lhs, T const &rhs) { + switch (op) { + case AttrOpType::ADD: + return lhs + rhs; + break; + case AttrOpType::SUB: + return lhs - rhs; + break; + case AttrOpType::MUL: + return lhs * rhs; + break; + case AttrOpType::DIV: + return lhs / rhs; + break; + default: + mk_runtime_error("Unknown attribute operator type"); + } } + AttrOpType op; +}; + +GraphAttributeValue graph_attribute_value_op(AttrOpType op, + GraphAttributeValue const &lhs, + GraphAttributeValue const &rhs) { + visit(GraphAttributeValueOp{op}, lhs, rhs); } -struct EvaluateGraphAttributeExpr { - template - GraphAttributeValue operator()(Ts... const &ts) { - return evaluate(ts); +struct EvaluateGraphAttributeExprLeaf { + template + GraphAttributeValue operator()(T const &t) { + return evaluate(t); } - template - GraphAttributeValue evaluate(NodeAttrAccess const &t) { + GraphAttributeValue evaluate(NodeAttrAccess const &t) { Node node_in_pattern = t.node; Node node_in_pcg = match.nodeAssignment.at_l(node_in_pattern); - return evaluate_attribute_expr(node_in_pcg, t.attr_expr); + return widen( + evaluate_attribute_expr(graph.at(node_in_pcg), t.attr_expr).value()); + } + + GraphAttributeValue evaluate(EdgeAttrAccess const &t) { + OpenMultiDiEdge output_in_pattern = t.edge; + MultiDiEdge output_in_pcg = match.edgeAssignment.at_l(output_in_pattern); + return widen( + evaluate_attribute_expr(graph.at(output_in_pcg), t.attr_expr).value()); } + ParallelComputationGraph const &graph; + DiGraphPatternMatch const &match; +}; + +GraphAttributeValue + evaluate_graph_attribute_expr_leaf(ParallelComputationGraph const &g, + DiGraphPatternMatch const &match, + GraphAttributeExprLeaf const &expr) { + return visit(EvaluateGraphAttributeExprLeaf{g, match}, expr); +} + +struct EvaluateGraphAttributeExpr { template - GraphAttributeValue evaluate(EdgeAttrAccess const &t) { - OpenMultiDiEdge edge_in_pattern = t.edge; - MultiDiEdge edge_in_pcg = match.edgeAssignment.at_l(edge_in_pattern); - return evaluate_attribute_expr(edge_in_pcg, t.attr_expr); + GraphAttributeValue operator()(T const &t) { + return evaluate(t); } - template - GraphAttributeValue evaluate(AttrUnary const &t) { - auto lhs = (*this)(t.lhs).value(); - auto rhs = t.rhs; - return graph_attribute_value_op(lhs, rhs); + GraphAttributeValue evaluate(AttrUnary const &expr) { + auto lhs = evaluate_graph_attribute_expr_leaf(graph, match, expr.lhs); + auto rhs = expr.rhs; + return graph_attribute_value_op(expr.op_type, lhs, rhs); } - template - GraphAttributeValue evaluate(AttrBinary const &t) { - auto lhs = (*this)(t.lhs).value(); - auto rhs = (*this)(t.rhs).value(); - return graph_attribute_value_op(lhs, rhs); + GraphAttributeValue evaluate(AttrBinary const &expr) { + auto lhs = evaluate_graph_attribute_expr_leaf(graph, match, expr.lhs); + auto rhs = evaluate_graph_attribute_expr_leaf(graph, match, expr.rhs); + return graph_attribute_value_op(expr.op_type, lhs, rhs); } EvaluateGraphAttributeExpr(ParallelComputationGraph const &graph, @@ -65,11 +89,10 @@ struct EvaluateGraphAttributeExpr { DiGraphPatternMatch const &match; }; -template GraphAttributeValue evaluate_graph_attribute_expr(ParallelComputationGraph const &graph, DiGraphPatternMatch const &match, - GraphAttributeExpr const &expr) { + GraphAttributeExpr const &expr) { return visit(EvaluateGraphAttributeExpr(graph, match), expr); } @@ -93,12 +116,12 @@ ParallelComputationGraph apply_substitution(ParallelComputationGraph const &pcg, ParallelComputationGraph::create(); bidict node_mapping; // Refactor it with global nodes for (Node const &node : get_nodes(pcg)) { - if (!contains_r(match.nodeAssignment)) { + if (!contains_r(match.nodeAssignment, node)) { node_mapping.equate(node, new_pcg.add_node(pcg.at(node))); } } for (MultiDiEdge const &edge : get_edges(pcg)) { - if (!contains_r(match.edgeAssignment)) { + if (!contains_r(match.edgeAssignment, edge)) { new_pcg.add_edge(MultiDiEdge{node_mapping.at_l(edge.src), node_mapping.at_r(edge.dst), new_pcg.add_node_port(), diff --git a/lib/utils/include/utils/graph/labelled/output_labelled_open.h b/lib/utils/include/utils/graph/labelled/output_labelled_open.h new file mode 100644 index 0000000000..d437e86530 --- /dev/null +++ b/lib/utils/include/utils/graph/labelled/output_labelled_open.h @@ -0,0 +1,17 @@ +#ifndef _FLEXFLOW_UTILS_GRAPH_LABELLED_OUTPUT_LABELLED_OPEN +#define _FLEXFLOW_UTILS_GRAPH_LABELLED_OUTPUT_LABELLED_OPEN + +namespace FlexFlow { + +template +struct OutputLabelledOpenMultiDiGraph { + OutputLabelledOpenMultiDiGraph() = delete; + OutputLabelledOpenMultiDiGraph(OutputLabelledOpenMultiDiGraph const &) = default; + OutputLabelledOpenMultiDiGraph& operator=(OutputLabelledOpenMultiDiGraph const &) = default; + + operator OpenMultiDiGraphView(); +}; + +} + +#endif diff --git a/lib/utils/include/utils/graph/labelled_graphs.h b/lib/utils/include/utils/graph/labelled_graphs.h index d13a197fa1..699369edb3 100644 --- a/lib/utils/include/utils/graph/labelled_graphs.h +++ b/lib/utils/include/utils/graph/labelled_graphs.h @@ -8,6 +8,7 @@ #include "labelled/node_labelled.h" #include "labelled/open_algorithms.h" #include "labelled/output_labelled.h" +#include "labelled/output_labelled_open.h" #include "labelled/standard_labelled.h" #include "labelled/unordered_labelled_graphs.h" From a5e111e278ca9743bfe35c897ffd34e362256c39 Mon Sep 17 00:00:00 2001 From: wmdi Date: Fri, 25 Aug 2023 16:33:43 -0400 Subject: [PATCH 25/61] format --- .../utils/graph/labelled/output_labelled_open.h | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/lib/utils/include/utils/graph/labelled/output_labelled_open.h b/lib/utils/include/utils/graph/labelled/output_labelled_open.h index d437e86530..5e5bd07976 100644 --- a/lib/utils/include/utils/graph/labelled/output_labelled_open.h +++ b/lib/utils/include/utils/graph/labelled/output_labelled_open.h @@ -3,15 +3,19 @@ namespace FlexFlow { -template +template struct OutputLabelledOpenMultiDiGraph { OutputLabelledOpenMultiDiGraph() = delete; - OutputLabelledOpenMultiDiGraph(OutputLabelledOpenMultiDiGraph const &) = default; - OutputLabelledOpenMultiDiGraph& operator=(OutputLabelledOpenMultiDiGraph const &) = default; + OutputLabelledOpenMultiDiGraph(OutputLabelledOpenMultiDiGraph const &) = + default; + OutputLabelledOpenMultiDiGraph & + operator=(OutputLabelledOpenMultiDiGraph const &) = default; operator OpenMultiDiGraphView(); }; -} +} // namespace FlexFlow #endif From 2fb2c7dfefa10f6c09dee74c625ae02804cc27c8 Mon Sep 17 00:00:00 2001 From: wmdi Date: Sun, 27 Aug 2023 21:18:42 -0400 Subject: [PATCH 26/61] minor fix --- lib/compiler/src/unity_algorithm.cc | 6 ++- .../include/substitutions/graph_pattern.h | 2 +- .../substitutions/graph_pattern_match.h | 23 ++++++++--- .../include/substitutions/output_graph.h | 8 ++-- .../substitutions/parallel_tensor_pattern.h | 9 ----- .../include/substitutions/substitution.h | 13 +++++-- lib/substitutions/src/graph_pattern.cc | 2 +- lib/substitutions/src/graph_pattern_match.cc | 32 +++++++-------- lib/substitutions/src/substitution.cc | 39 ++++++++++--------- 9 files changed, 73 insertions(+), 61 deletions(-) diff --git a/lib/compiler/src/unity_algorithm.cc b/lib/compiler/src/unity_algorithm.cc index ef093fc11e..86fdd88d92 100644 --- a/lib/compiler/src/unity_algorithm.cc +++ b/lib/compiler/src/unity_algorithm.cc @@ -1,6 +1,6 @@ #include "compiler/unity_algorithm.h" #include "graph_utils.h" -#include "substitutions_implementation.h" +#include "substitutions/substitution.h" #include "utils/deduplicated_priority_queue.h" namespace FlexFlow { @@ -14,7 +14,9 @@ std::unordered_set std::unordered_set apply_substitution(ParallelComputationGraph const &pcg, - Substitution const &) {} + Substitution const &) { + NOT_IMPLEMENTED(); +} Strategy graph_optimize(ComputationGraph &cg, diff --git a/lib/substitutions/include/substitutions/graph_pattern.h b/lib/substitutions/include/substitutions/graph_pattern.h index e2054f1a4f..7697ddf55d 100644 --- a/lib/substitutions/include/substitutions/graph_pattern.h +++ b/lib/substitutions/include/substitutions/graph_pattern.h @@ -22,7 +22,7 @@ bool is_singleton_pattern(OpenMultiDiGraphView const &); bool assignment_satisfies(ParallelComputationGraph const &, GraphPattern const &, - DiGraphPatternMatch const &); + MultiDiGraphPatternMatch const &); } // namespace FlexFlow diff --git a/lib/substitutions/include/substitutions/graph_pattern_match.h b/lib/substitutions/include/substitutions/graph_pattern_match.h index 449c26c846..498ec6cfd0 100644 --- a/lib/substitutions/include/substitutions/graph_pattern_match.h +++ b/lib/substitutions/include/substitutions/graph_pattern_match.h @@ -6,22 +6,33 @@ namespace FlexFlow { -struct DiGraphPatternMatch { - bidict nodeAssignment; - bidict edgeAssignment; +struct MultiDiGraphPatternMatch { + using PatternNode = Node; + using PCGNode = Node; + using PatternEdge = OpenMultiDiEdge; + using PCGEdge = MultiDiEdge; + + bidict nodeAssignment; + bidict edgeAssignment; }; struct MatchSplit { - DiGraphPatternMatch prefix_submatch; - DiGraphPatternMatch postfix_submatch; + MultiDiGraphPatternMatch prefix_submatch; + MultiDiGraphPatternMatch postfix_submatch; }; template bool pattern_matches(OpenMultiDiGraphView const &, MultiDiGraphView const &, - DiGraphPatternMatch const &, + MultiDiGraphPatternMatch const &, F const &additional_criterion); +template +std::unordered_set + find_pattern_matches(OpenMultiDiGraphView const &pattern, + MultiDiGraphView const &graph, + F const &additional_criterion); + } // namespace FlexFlow #endif diff --git a/lib/substitutions/include/substitutions/output_graph.h b/lib/substitutions/include/substitutions/output_graph.h index 7b32ca9900..f5b6328d7d 100644 --- a/lib/substitutions/include/substitutions/output_graph.h +++ b/lib/substitutions/include/substitutions/output_graph.h @@ -46,16 +46,16 @@ using GraphAttributeExpr = // NOTE(@wmdi): Not sure if it aligns with other design. Or alternatively we can // define the assignment for each operator type. struct OperatorAttrAssignment { - std::vector> assignment; + std::unordered_map assignment; }; struct ParallelTensorAttrAssignment { - std::vector> assignment; + std::unordered_map assignment; }; -struct OutputGraph +struct OutputGraphExpr : public strong_typedef< - OutputGraph, + OutputGraphExpr, OutputLabelledOpenMultiDiGraph> { using strong_typedef::strong_typedef; diff --git a/lib/substitutions/include/substitutions/parallel_tensor_pattern.h b/lib/substitutions/include/substitutions/parallel_tensor_pattern.h index c62237d0fd..2b5f4d0f58 100644 --- a/lib/substitutions/include/substitutions/parallel_tensor_pattern.h +++ b/lib/substitutions/include/substitutions/parallel_tensor_pattern.h @@ -8,15 +8,6 @@ namespace FlexFlow { enum class TensorDimensionAttribute { SIZE, DEGREE }; -struct TensorNumDimensionsConstraint { - int value; -}; - -struct TensorDimensionAttributeConstraint { - TensorDimensionAttribute attribute; - int index; -}; - enum class TensorAttributeKey { DIM_SIZES, DIM_DEGREES }; using TensorAttributeValue = variant>; diff --git a/lib/substitutions/include/substitutions/substitution.h b/lib/substitutions/include/substitutions/substitution.h index 55820da33f..a805d0dae1 100644 --- a/lib/substitutions/include/substitutions/substitution.h +++ b/lib/substitutions/include/substitutions/substitution.h @@ -7,15 +7,20 @@ namespace FlexFlow { struct Substitution { + using InputPatternInput = InputMultiDiEdge; + using InputPatternOutput = OutputMultiDiEdge; + using OutputPatternInput = InputMultiDiEdge; + using OutputPatternOutput = OutputMultiDiEdge; + GraphPattern input_graph; - OutputGraph output_graph; - bidict input_mapping; - bidict output_mapping; + OutputGraphExpr output_graph_expr; + bidict input_mapping; + bidict output_mapping; }; ParallelComputationGraph apply_substitution(ParallelComputationGraph const &, Substitution const &, - DiGraphPatternMatch const &); + MultiDiGraphPatternMatch const &); } // namespace FlexFlow diff --git a/lib/substitutions/src/graph_pattern.cc b/lib/substitutions/src/graph_pattern.cc index 9b7529ad8a..dfaf47910b 100644 --- a/lib/substitutions/src/graph_pattern.cc +++ b/lib/substitutions/src/graph_pattern.cc @@ -186,7 +186,7 @@ optional satisfies(ParallelTensor const ¶ms, bool assignment_satisfies(ParallelComputationGraph const &pcg, GraphPattern const &pattern, - DiGraphPatternMatch const &patternMatch) { + MultiDiGraphPatternMatch const &patternMatch) { bool result = true; for (auto const &kv : patternMatch.nodeAssignment) { auto patternNode = kv.first; diff --git a/lib/substitutions/src/graph_pattern_match.cc b/lib/substitutions/src/graph_pattern_match.cc index a5c185aba0..2e0150c808 100644 --- a/lib/substitutions/src/graph_pattern_match.cc +++ b/lib/substitutions/src/graph_pattern_match.cc @@ -4,9 +4,9 @@ namespace FlexFlow { -// DiGraphPatternMatch narrow_match(DiGraphPatternMatch const &match, +// MultiDiGraphPatternMatch narrow_match(MultiDiGraphPatternMatch const &match, // OpenMultiDiGraphView const &pattern) { -// DiGraphPatternMatch result; +// MultiDiGraphPatternMatch result; // std::unordered_set nodes = get_nodes(pattern); // for (auto const &kv : match.nodeAssignment) { // Node pattern_node = kv.first; @@ -47,7 +47,7 @@ std::pair Given a match and a pattern split, gets the submatches in subpatterns. */ MatchSplit apply_split(OpenMultiDiGraphView const &pattern, - DiGraphPatternMatch const &match, + MultiDiGraphPatternMatch const &match, GraphSplit const &split) { auto prefix = split.first; auto postfix = split.second; @@ -99,7 +99,7 @@ bool is_singleton_pattern(OpenMultiDiGraphView const &pattern) { template bool pattern_matches(OpenMultiDiGraphView const &pattern, MultiDiGraphView const &graph, - DiGraphPatternMatch const &match, + MultiDiGraphPatternMatch const &match, F const &additional_criterion) { if (is_singleton_pattern(pattern)) { Node pattern_node = get_only(get_nodes(pattern)); @@ -149,7 +149,7 @@ bool pattern_matches(OpenMultiDiGraphView const &pattern, additional_criterion); } -optional +optional get_candidate_singleton_match(OpenMultiDiGraphView const &pattern, MultiDiGraphView const &graph, Node const &graph_node) { @@ -157,7 +157,7 @@ optional Node pattern_node = get_only(get_nodes(pattern)); - DiGraphPatternMatch match; + MultiDiGraphPatternMatch match; match.nodeAssignment.equate(pattern_node, graph_node); auto incoming = get_incoming_edges_by_idx(graph, graph_node); @@ -185,12 +185,12 @@ optional return match; } -optional unsplit_matches( - DiGraphPatternMatch const &prefix, - DiGraphPatternMatch const &postfix, +optional unsplit_matches( + MultiDiGraphPatternMatch const &prefix, + MultiDiGraphPatternMatch const &postfix, bidict> const &edge_splits) { - DiGraphPatternMatch result; + MultiDiGraphPatternMatch result; std::unordered_set handled; for (auto const &kv : edge_splits) { MultiDiEdge standard_edge = kv.first; @@ -222,14 +222,14 @@ optional unsplit_matches( } template -std::unordered_set +std::unordered_set find_pattern_matches(OpenMultiDiGraphView const &pattern, MultiDiGraphView const &graph, F const &additional_criterion) { - std::unordered_set matches; + std::unordered_set matches; if (is_singleton_pattern(pattern)) { for (Node const &graph_node : get_nodes(graph)) { - optional candidate = + optional candidate = get_candidate_singleton_match(pattern, graph, graph_node); if (candidate.has_value() || pattern_matches(pattern, graph, candidate.value())) { @@ -244,9 +244,9 @@ std::unordered_set auto postfix_matches = find_pattern_matches(subpatterns.first, graph, additional_criterion); auto edge_splits = get_edge_splits(pattern, split); - for (DiGraphPatternMatch const &prefix_match : prefix_matches) { - for (DiGraphPatternMatch const &postfix_match : postfix_matches) { - optional unsplit = + for (MultiDiGraphPatternMatch const &prefix_match : prefix_matches) { + for (MultiDiGraphPatternMatch const &postfix_match : postfix_matches) { + optional unsplit = unsplit_matches(prefix_match, postfix_match, edge_splits); if (unsplit.has_value()) { matches.insert(unsplit.value()); diff --git a/lib/substitutions/src/substitution.cc b/lib/substitutions/src/substitution.cc index c5cc870a7a..ff8d6ef541 100644 --- a/lib/substitutions/src/substitution.cc +++ b/lib/substitutions/src/substitution.cc @@ -3,7 +3,9 @@ namespace FlexFlow { -struct GraphAttributeValueOp { +struct GraphAttributeValueOpFunctor { + AttrOpType op; + template GraphAttributeValue operator()(T const &lhs, T const &rhs) { switch (op) { @@ -23,13 +25,12 @@ struct GraphAttributeValueOp { mk_runtime_error("Unknown attribute operator type"); } } - AttrOpType op; }; GraphAttributeValue graph_attribute_value_op(AttrOpType op, GraphAttributeValue const &lhs, GraphAttributeValue const &rhs) { - visit(GraphAttributeValueOp{op}, lhs, rhs); + visit(GraphAttributeValueOpFunctor{op}, lhs, rhs); } struct EvaluateGraphAttributeExprLeaf { @@ -53,12 +54,12 @@ struct EvaluateGraphAttributeExprLeaf { } ParallelComputationGraph const &graph; - DiGraphPatternMatch const &match; + MultiDiGraphPatternMatch const &match; }; GraphAttributeValue evaluate_graph_attribute_expr_leaf(ParallelComputationGraph const &g, - DiGraphPatternMatch const &match, + MultiDiGraphPatternMatch const &match, GraphAttributeExprLeaf const &expr) { return visit(EvaluateGraphAttributeExprLeaf{g, match}, expr); } @@ -82,36 +83,37 @@ struct EvaluateGraphAttributeExpr { } EvaluateGraphAttributeExpr(ParallelComputationGraph const &graph, - DiGraphPatternMatch const &match) + MultiDiGraphPatternMatch const &match) : graph(graph), match(match) {} ParallelComputationGraph const &graph; - DiGraphPatternMatch const &match; + MultiDiGraphPatternMatch const &match; }; GraphAttributeValue evaluate_graph_attribute_expr(ParallelComputationGraph const &graph, - DiGraphPatternMatch const &match, + MultiDiGraphPatternMatch const &match, GraphAttributeExpr const &expr) { return visit(EvaluateGraphAttributeExpr(graph, match), expr); } Operator get_operator_attrs(ParallelComputationGraph const &graph, - DiGraphPatternMatch const &match, + MultiDiGraphPatternMatch const &match, OperatorAttrAssignment const &assignment) { NOT_IMPLEMENTED(); } ParallelTensor get_parallel_tensor_attrs(ParallelComputationGraph const &graph, - DiGraphPatternMatch const &match, + MultiDiGraphPatternMatch const &match, ParallelTensorAttrAssignment const &assignment) { NOT_IMPLEMENTED(); } -ParallelComputationGraph apply_substitution(ParallelComputationGraph const &pcg, - Substitution const &substitution, - DiGraphPatternMatch const &match) { +ParallelComputationGraph + apply_substitution(ParallelComputationGraph const &pcg, + Substitution const &substitution, + MultiDiGraphPatternMatch const &match) { ParallelComputationGraph new_pcg = ParallelComputationGraph::create(); bidict node_mapping; // Refactor it with global nodes @@ -128,13 +130,13 @@ ParallelComputationGraph apply_substitution(ParallelComputationGraph const &pcg, new_pcg.add_node_port()}); } } - for (Node const &output_node : get_nodes(substitution.output_graph)) { + for (Node const &output_node : get_nodes(substitution.output_graph_expr)) { Node new_node = new_pcg.add_node(get_operator_attrs( - pcg, match, substitution.output_graph.at(output_node))); + pcg, match, substitution.output_graph_expr.at(output_node))); node_mapping.equate(output_node, new_node); } for (OpenMultiDiEdge const &output_edge : - get_edges(substitution.output_graph)) { + get_edges(substitution.output_graph_expr)) { if (holds_alternative(output_edge)) { MultiDiEdge origin_edge = match.edgeAssignment.at_r( substitution.input_mapping.at_r(output_edge)); @@ -157,11 +159,12 @@ ParallelComputationGraph apply_substitution(ParallelComputationGraph const &pcg, new_pcg.add_node_port()}); } } - for (MultiDiOutput const &output : get_outputs(substitution.output_graph)) { + for (MultiDiOutput const &output : + get_outputs(substitution.output_graph_expr)) { new_pcg.add_output( MultiDiOutput{node_mapping.at_l(output.src), new_pcg.add_node_port()}, get_parallel_tensor_attrs( - pcg, match, substitution.output_graph.at(output))); + pcg, match, substitution.output_graph_expr.at(output))); } return new_pcg; From a1bffc5f1544d2e1b468b2c0ddad9675e097bacb Mon Sep 17 00:00:00 2001 From: wmdi Date: Mon, 28 Aug 2023 14:33:59 -0400 Subject: [PATCH 27/61] updates --- lib/utils/include/utils/graph/algorithms.h | 6 +- .../graph/labelled/output_labelled_open.h | 48 ++++++++++++- lib/utils/include/utils/graph/open_graphs.h | 2 + lib/utils/src/graph/algorithms.cc | 68 +++++++++++-------- 4 files changed, 93 insertions(+), 31 deletions(-) diff --git a/lib/utils/include/utils/graph/algorithms.h b/lib/utils/include/utils/graph/algorithms.h index c83fcf0fdb..9b80683ef6 100644 --- a/lib/utils/include/utils/graph/algorithms.h +++ b/lib/utils/include/utils/graph/algorithms.h @@ -256,11 +256,13 @@ using GraphSplit = std::pair split_edge(MultiDiEdge const &e); MultiDiEdge unsplit_edge(OutputMultiDiEdge const &, InputMultiDiEdge const &); +std::unordered_set get_cut_set(MultiDiGraphView const &, GraphSplit const &); +std::unordered_set get_cut_set(OpenMultiDiGraphView const &, + GraphSplit const &); + bidict> get_edge_splits(OpenMultiDiGraphView const &, GraphSplit const &); -std::unordered_set get_cut(OpenMultiDiGraphView const &, - GraphSplit const &); UndirectedGraphView get_subgraph(UndirectedGraphView const &, std::unordered_set const &); diff --git a/lib/utils/include/utils/graph/labelled/output_labelled_open.h b/lib/utils/include/utils/graph/labelled/output_labelled_open.h index 5e5bd07976..49b5b67c1e 100644 --- a/lib/utils/include/utils/graph/labelled/output_labelled_open.h +++ b/lib/utils/include/utils/graph/labelled/output_labelled_open.h @@ -13,7 +13,53 @@ struct OutputLabelledOpenMultiDiGraph { OutputLabelledOpenMultiDiGraph & operator=(OutputLabelledOpenMultiDiGraph const &) = default; - operator OpenMultiDiGraphView(); + operator OpenMultiDiGraphView() { + NOT_IMPLEMENTED(); + } + + Node add_node(NodeLabel const &) { + NOT_IMPLEMENTED(); + } + NodeLabel const &at(Node const &) const { + NOT_IMPLEMENTED(); + } + NodeLabel &at(Node const &) const { + NOT_IMPLEMENTED(); + } + + void add_edge(MultiDiEdge const &) { + NOT_IMPLEMENTED(); + } + void add_edge(InputMultiDiEdge const &) { + NOT_IMPLEMENTED(); + } + void add_edge(OutputMultiDiEdge const &) { + NOT_IMPLEMENTED(); + } + + InputLabel const &at(InputMultiDiEdge const &) const { + NOT_IMPLEMENTED(); + } + OutputLabel const &at(OutputMultiDiEdge const &) const { + NOT_IMPLEMENTED(); + } + + InputLabel &at(InputMultiDiEdge const &) { + NOT_IMPLEMENTED(); + } + OutputLabel &at(OutputMultiDiEdge const &) { + NOT_IMPLEMENTED(); + } + + void add_output(MultiDiOutput const &, OutputLabel const &) { + NOT_IMPLEMENTED(); + } + OutputLabel const &at(MultiDiOutput const &) const { + NOT_IMPLEMENTED(); + } + OutputLabel &at(MultiDiOutput const &) { + NOT_IMPLEMENTED(); + } }; } // namespace FlexFlow diff --git a/lib/utils/include/utils/graph/open_graphs.h b/lib/utils/include/utils/graph/open_graphs.h index 60c98a10bc..afb7946434 100644 --- a/lib/utils/include/utils/graph/open_graphs.h +++ b/lib/utils/include/utils/graph/open_graphs.h @@ -17,6 +17,8 @@ struct OpenMultiDiGraphView { OpenMultiDiGraphView() = delete; + operator MultiDiGraphView() const; + friend void swap(OpenMultiDiGraphView &, OpenMultiDiGraphView &); std::unordered_set query_nodes(NodeQuery const &); diff --git a/lib/utils/src/graph/algorithms.cc b/lib/utils/src/graph/algorithms.cc index 44335babcf..92e7fe16f1 100644 --- a/lib/utils/src/graph/algorithms.cc +++ b/lib/utils/src/graph/algorithms.cc @@ -22,16 +22,27 @@ std::unordered_set get_nodes(GraphView const &g) { return g.query_nodes(NodeQuery::all()); } -std::unordered_set get_nodes(OpenMultiDiEdge const &pattern_edge) { - if (is_input_edge(pattern_edge)) { - return {mpark::get(pattern_edge).dst}; - } else if (is_output_edge(pattern_edge)) { - return {mpark::get(pattern_edge).src}; - } else { - assert(is_standard_edge(pattern_edge)); - auto standard_edge = mpark::get(pattern_edge); - return {standard_edge.src, standard_edge.dst}; +std::unordered_set get_nodes(InputMultiDiEdge const &edge) { + return {edge.dst}; +} + +std::unordered_set get_nodes(OutputMultiDiEdge const &edge) { + return {edge.src}; +} + +std::unordered_set get_nodes(MultiDiEdge const &edge) { + return {edge.src, edge.src}; +} + +struct GetNodesFunctor { + template + std::unordered_set operator()(T const &t) { + return get_nodes(t); } +}; + +std::unordered_set get_nodes(OpenMultiDiEdge const &edge) { + return visit(GetNodesFunctor{}, edge); } std::unordered_set query_nodes(IGraphView const &g, @@ -480,35 +491,36 @@ MultiDiEdge unsplit_edge(OutputMultiDiEdge const &output_edge, output_edge.src, input_edge.dst, output_edge.srcIdx, input_edge.dstIdx}; } -bidict> - get_edge_splits(IOpenMultiDiGraphView const &pattern, - GraphSplit const &split) { +std::unordered_set get_cut_set(MultiDiGraphView const &graph, GraphSplit const &split) { auto prefix = split.first; auto postfix = split.second; - bidict> result; - - for (OpenMultiDiEdge const &pattern_edge : get_edges(pattern)) { - if (!is_standard_edge(pattern_edge)) { - continue; - } + std::unordered_set result; - auto standard_edge = mpark::get(pattern_edge); - if (is_subseteq_of(get_nodes(standard_edge), prefix) || - is_subseteq_of(get_nodes(standard_edge), postfix)) { - continue; + for (MultiDiEdge const &edge : get_edges(graph)) { + if (!is_subseteq_of(get_nodes(edge), prefix) && + !is_subseteq_of(get_nodes(edge), postfix)) { + result.insert(edge); } - - auto divided = split_edge(standard_edge); - result.equate(standard_edge, divided); } return result; } -std::unordered_set get_cut(OpenMultiDiGraphView const &g, - GraphSplit const &s) { - return keys(get_edge_splits(g, s)); +std::unordered_set get_cut_set(OpenMultiDiGraphView const &graph, + GraphSplit const &split) { + return get_cut_set(graph, split); +} + +bidict> + get_edge_splits(OpenMultiDiGraphView const &graph, + GraphSplit const &split) { + bidict> result; + std::unordered_set cut_set = get_cut_set(graph, split); + for (MultiDiEdge const &edge : cut_set) { + result.equate(edge, split_edge(edge)); + } + return result; } Node get_src_node(MultiDiEdge const &) { From ee9f7cac71ebba1c012b58f2a0210d91f1d4b063 Mon Sep 17 00:00:00 2001 From: wmdi Date: Tue, 29 Aug 2023 16:27:42 -0400 Subject: [PATCH 28/61] readme for substitutions --- lib/substitutions/README.md | 34 ++++++++ .../substitutions/graph_pattern_match.h | 4 +- .../include/substitutions/output_graph.h | 19 +--- lib/substitutions/src/substitution.cc | 87 +++---------------- 4 files changed, 49 insertions(+), 95 deletions(-) create mode 100644 lib/substitutions/README.md diff --git a/lib/substitutions/README.md b/lib/substitutions/README.md new file mode 100644 index 0000000000..e0eb8aff18 --- /dev/null +++ b/lib/substitutions/README.md @@ -0,0 +1,34 @@ +# subtitutions + +## `Substitution` + +A substitution is to replace a subgraph of the PCG by a new one. We refer to the subgraph to be replaced as the input graph, and the new subgraph to replace the input graph as the output graph. + +A `Substitution` object describes a substitution. It consists of +* An `input_graph` of type `GraphPattern` that describes which kind of input graphs the substitution can be applied to; +* An `output_graph` of type `OutputGraphExpr` that describes how the output graph is computed from the input graph; and +* An `input_mapping` and `output_maping` that describes how the output graph is connected to the original PCG. + +### `GraphPattern` and `MultiDiGraphPatternMatch` + +A `GraphPattern` is defined as an open graph with node label `OperatorPattern` and output label `ParallelTensorPattern`, which is refered to as the pattern graph. The graph structure of a `GraphPattern` instance defines the geometrical property of the input graph, while the node labels and output labels define the attribute property of that. + +To apply a substitution to a PCG, we should first match the pattern graph to a subgraph of the PCG. `MultiDiGraphPatternMatch` describes the match, which consists of +* `node_assignment`: a mapping from the nodes of the pattern graph to the nodes of the PCG; and +* `edge_assignment`: a mapping from the edges of the pattern graph to the nodes of the PCG. +The input graph derived by this match is then defined by `values(node_assignment)` and `values(edge_assignment)`. A match is valid if and only if +* `node_assignment` and `edge_assignment` are injections; +* For every node `n` in the pattern graph, `edge_assignment` derives a bijection between `query_edges({n})` and `query_edges({node_assignment.at_l(n)})`. + +### `OutputGraphExpr` + +An `OutputGraphExpr` is defined as an open graph with node label `OperatorAttrAssignment` and output label `ParallelTensorAttrAssignment`, which defines how the operator attributes and the parallel tensor attributes of the output graph are derived from the input graph. + +`OperatorAttrAssignment` is a collection of `OperatorAttributeKey` and `GraphAttributeExpr` pairs. It defines how the attributes of a single operator is calculated from the input graph. A pair `{operator_attribute_key, graph_attribute_expr}` in the collection means the value of `graph_attribute_expr` is assigned to the attribute named `operator_attribute_key` of the operator. + +`ParallelTensorAttrAssignment` is defined in the similar way to `OperatorAttrAssignment`. + +`GraphAttributeExpr` is defined as one of `NodeAttrAccess`, `EdgeAttrAccess` and `AttrConstant`: +* `NodeAttrAccess` consists of a node `node` and an expression `attr_expr` on the attributes of the operator associated with the node. The value of a `NodeAttrAccess` instance is the value of `attr_expr` evaluated on the operator associated with the node. +* `EdgeAttrAccess` is defined in the similar way to `NodeAttrAccess`. +* `AttrConstant` consists of a constant `value`. The value of an `AttrConstant` instance is `value`. diff --git a/lib/substitutions/include/substitutions/graph_pattern_match.h b/lib/substitutions/include/substitutions/graph_pattern_match.h index 498ec6cfd0..29d7f896a9 100644 --- a/lib/substitutions/include/substitutions/graph_pattern_match.h +++ b/lib/substitutions/include/substitutions/graph_pattern_match.h @@ -12,8 +12,8 @@ struct MultiDiGraphPatternMatch { using PatternEdge = OpenMultiDiEdge; using PCGEdge = MultiDiEdge; - bidict nodeAssignment; - bidict edgeAssignment; + bidict node_assignment; + bidict edge_assignment; }; struct MatchSplit { diff --git a/lib/substitutions/include/substitutions/output_graph.h b/lib/substitutions/include/substitutions/output_graph.h index f5b6328d7d..417217c222 100644 --- a/lib/substitutions/include/substitutions/output_graph.h +++ b/lib/substitutions/include/substitutions/output_graph.h @@ -23,25 +23,8 @@ struct AttrConstant { GraphAttributeValue value; }; -using GraphAttributeExprLeaf = - variant; - -enum class AttrOpType { ADD, SUB, MUL, DIV }; - -struct AttrUnary { - AttrOpType op_type; - GraphAttributeExprLeaf lhs; - GraphAttributeValue rhs; -}; - -struct AttrBinary { - AttrOpType op_type; - GraphAttributeExprLeaf lhs; - GraphAttributeExprLeaf rhs; -}; - using GraphAttributeExpr = - variant; + variant; // NOTE(@wmdi): Not sure if it aligns with other design. Or alternatively we can // define the assignment for each operator type. diff --git a/lib/substitutions/src/substitution.cc b/lib/substitutions/src/substitution.cc index ff8d6ef541..f1de7125bf 100644 --- a/lib/substitutions/src/substitution.cc +++ b/lib/substitutions/src/substitution.cc @@ -3,37 +3,10 @@ namespace FlexFlow { -struct GraphAttributeValueOpFunctor { - AttrOpType op; - - template - GraphAttributeValue operator()(T const &lhs, T const &rhs) { - switch (op) { - case AttrOpType::ADD: - return lhs + rhs; - break; - case AttrOpType::SUB: - return lhs - rhs; - break; - case AttrOpType::MUL: - return lhs * rhs; - break; - case AttrOpType::DIV: - return lhs / rhs; - break; - default: - mk_runtime_error("Unknown attribute operator type"); - } - } -}; - -GraphAttributeValue graph_attribute_value_op(AttrOpType op, - GraphAttributeValue const &lhs, - GraphAttributeValue const &rhs) { - visit(GraphAttributeValueOpFunctor{op}, lhs, rhs); -} +struct EvaluateGraphAttributeExpr { + ParallelComputationGraph const &graph; + MultiDiGraphPatternMatch const &match; -struct EvaluateGraphAttributeExprLeaf { template GraphAttributeValue operator()(T const &t) { return evaluate(t); @@ -41,60 +14,24 @@ struct EvaluateGraphAttributeExprLeaf { GraphAttributeValue evaluate(NodeAttrAccess const &t) { Node node_in_pattern = t.node; - Node node_in_pcg = match.nodeAssignment.at_l(node_in_pattern); + Node node_in_pcg = match.node_assignment.at_l(node_in_pattern); return widen( evaluate_attribute_expr(graph.at(node_in_pcg), t.attr_expr).value()); } GraphAttributeValue evaluate(EdgeAttrAccess const &t) { OpenMultiDiEdge output_in_pattern = t.edge; - MultiDiEdge output_in_pcg = match.edgeAssignment.at_l(output_in_pattern); + MultiDiEdge output_in_pcg = match.edge_assignment.at_l(output_in_pattern); return widen( evaluate_attribute_expr(graph.at(output_in_pcg), t.attr_expr).value()); } - - ParallelComputationGraph const &graph; - MultiDiGraphPatternMatch const &match; }; GraphAttributeValue - evaluate_graph_attribute_expr_leaf(ParallelComputationGraph const &g, + evaluate_graph_attribute_expr(ParallelComputationGraph const &g, MultiDiGraphPatternMatch const &match, - GraphAttributeExprLeaf const &expr) { - return visit(EvaluateGraphAttributeExprLeaf{g, match}, expr); -} - -struct EvaluateGraphAttributeExpr { - template - GraphAttributeValue operator()(T const &t) { - return evaluate(t); - } - - GraphAttributeValue evaluate(AttrUnary const &expr) { - auto lhs = evaluate_graph_attribute_expr_leaf(graph, match, expr.lhs); - auto rhs = expr.rhs; - return graph_attribute_value_op(expr.op_type, lhs, rhs); - } - - GraphAttributeValue evaluate(AttrBinary const &expr) { - auto lhs = evaluate_graph_attribute_expr_leaf(graph, match, expr.lhs); - auto rhs = evaluate_graph_attribute_expr_leaf(graph, match, expr.rhs); - return graph_attribute_value_op(expr.op_type, lhs, rhs); - } - - EvaluateGraphAttributeExpr(ParallelComputationGraph const &graph, - MultiDiGraphPatternMatch const &match) - : graph(graph), match(match) {} - - ParallelComputationGraph const &graph; - MultiDiGraphPatternMatch const &match; -}; - -GraphAttributeValue - evaluate_graph_attribute_expr(ParallelComputationGraph const &graph, - MultiDiGraphPatternMatch const &match, - GraphAttributeExpr const &expr) { - return visit(EvaluateGraphAttributeExpr(graph, match), expr); + GraphAttributeExpr const &expr) { + return visit(EvaluateGraphAttributeExpr{g, match}, expr); } Operator get_operator_attrs(ParallelComputationGraph const &graph, @@ -118,12 +55,12 @@ ParallelComputationGraph ParallelComputationGraph::create(); bidict node_mapping; // Refactor it with global nodes for (Node const &node : get_nodes(pcg)) { - if (!contains_r(match.nodeAssignment, node)) { + if (!contains_r(match.node_assignment, node)) { node_mapping.equate(node, new_pcg.add_node(pcg.at(node))); } } for (MultiDiEdge const &edge : get_edges(pcg)) { - if (!contains_r(match.edgeAssignment, edge)) { + if (!contains_r(match.edge_assignment, edge)) { new_pcg.add_edge(MultiDiEdge{node_mapping.at_l(edge.src), node_mapping.at_r(edge.dst), new_pcg.add_node_port(), @@ -138,14 +75,14 @@ ParallelComputationGraph for (OpenMultiDiEdge const &output_edge : get_edges(substitution.output_graph_expr)) { if (holds_alternative(output_edge)) { - MultiDiEdge origin_edge = match.edgeAssignment.at_r( + MultiDiEdge origin_edge = match.edge_assignment.at_r( substitution.input_mapping.at_r(output_edge)); new_pcg.add_edge(MultiDiEdge{node_mapping.at_l(origin_edge.src), node_mapping.at_l(output_edge.dst), new_pcg.add_node_port(), new_pcg.add_node_port()}); } else if (holds_alternative(output_edge)) { - MultiDiEdge origin_edge = match.edgeAssignment.at_r( + MultiDiEdge origin_edge = match.edge_assignment.at_r( substitution.output_mapping.at_r(output_edge)); new_pcg.add_edge(MultiDiEdge{node_mapping.at_l(output_edge.src), node_mapping.at_l(origin_edge.dst), From 08dd3fed9df2124dcc25665fecfe954cd3b4bb0d Mon Sep 17 00:00:00 2001 From: wmdi Date: Wed, 30 Aug 2023 15:13:42 -0400 Subject: [PATCH 29/61] format --- lib/substitutions/include/substitutions/substitution.h | 2 ++ lib/substitutions/src/substitution.cc | 4 ++-- lib/utils/include/utils/graph/algorithms.h | 8 ++++---- lib/utils/src/graph/algorithms.cc | 5 +++-- 4 files changed, 11 insertions(+), 8 deletions(-) diff --git a/lib/substitutions/include/substitutions/substitution.h b/lib/substitutions/include/substitutions/substitution.h index a805d0dae1..35bd03dbac 100644 --- a/lib/substitutions/include/substitutions/substitution.h +++ b/lib/substitutions/include/substitutions/substitution.h @@ -18,6 +18,8 @@ struct Substitution { bidict output_mapping; }; +bool is_valid_substitution(Substitution const &); + ParallelComputationGraph apply_substitution(ParallelComputationGraph const &, Substitution const &, MultiDiGraphPatternMatch const &); diff --git a/lib/substitutions/src/substitution.cc b/lib/substitutions/src/substitution.cc index f1de7125bf..1b56a0443f 100644 --- a/lib/substitutions/src/substitution.cc +++ b/lib/substitutions/src/substitution.cc @@ -29,8 +29,8 @@ struct EvaluateGraphAttributeExpr { GraphAttributeValue evaluate_graph_attribute_expr(ParallelComputationGraph const &g, - MultiDiGraphPatternMatch const &match, - GraphAttributeExpr const &expr) { + MultiDiGraphPatternMatch const &match, + GraphAttributeExpr const &expr) { return visit(EvaluateGraphAttributeExpr{g, match}, expr); } diff --git a/lib/utils/include/utils/graph/algorithms.h b/lib/utils/include/utils/graph/algorithms.h index 9b80683ef6..6fe6b95aa0 100644 --- a/lib/utils/include/utils/graph/algorithms.h +++ b/lib/utils/include/utils/graph/algorithms.h @@ -256,14 +256,14 @@ using GraphSplit = std::pair split_edge(MultiDiEdge const &e); MultiDiEdge unsplit_edge(OutputMultiDiEdge const &, InputMultiDiEdge const &); -std::unordered_set get_cut_set(MultiDiGraphView const &, GraphSplit const &); +std::unordered_set get_cut_set(MultiDiGraphView const &, + GraphSplit const &); std::unordered_set get_cut_set(OpenMultiDiGraphView const &, - GraphSplit const &); - + GraphSplit const &); + bidict> get_edge_splits(OpenMultiDiGraphView const &, GraphSplit const &); - UndirectedGraphView get_subgraph(UndirectedGraphView const &, std::unordered_set const &); DiGraphView get_subgraph(DiGraphView const &, std::unordered_set const &); diff --git a/lib/utils/src/graph/algorithms.cc b/lib/utils/src/graph/algorithms.cc index 92e7fe16f1..8fdd18ddd9 100644 --- a/lib/utils/src/graph/algorithms.cc +++ b/lib/utils/src/graph/algorithms.cc @@ -491,7 +491,8 @@ MultiDiEdge unsplit_edge(OutputMultiDiEdge const &output_edge, output_edge.src, input_edge.dst, output_edge.srcIdx, input_edge.dstIdx}; } -std::unordered_set get_cut_set(MultiDiGraphView const &graph, GraphSplit const &split) { +std::unordered_set get_cut_set(MultiDiGraphView const &graph, + GraphSplit const &split) { auto prefix = split.first; auto postfix = split.second; @@ -508,7 +509,7 @@ std::unordered_set get_cut_set(MultiDiGraphView const &graph, Graph } std::unordered_set get_cut_set(OpenMultiDiGraphView const &graph, - GraphSplit const &split) { + GraphSplit const &split) { return get_cut_set(graph, split); } From 82e2c2c061ed8115ab7e167efa7bd584b9854a5f Mon Sep 17 00:00:00 2001 From: wmdi Date: Wed, 30 Aug 2023 16:59:06 -0400 Subject: [PATCH 30/61] check substitution validity --- lib/substitutions/src/substitution.cc | 51 ++++++++++++++++++++++++++- 1 file changed, 50 insertions(+), 1 deletion(-) diff --git a/lib/substitutions/src/substitution.cc b/lib/substitutions/src/substitution.cc index 1b56a0443f..e83d522cd7 100644 --- a/lib/substitutions/src/substitution.cc +++ b/lib/substitutions/src/substitution.cc @@ -1,8 +1,57 @@ #include "substitutions/substitution.h" -#include namespace FlexFlow { +std::unordered_set> + get_valid_operator_attribute_exprs(OperatorPattern const &pattern) { + NOT_IMPLEMENTED(); +} + +bool is_valid_operator_attribute_expr( + OperatorPattern const &pattern, + AttributeExpr const &expr) { + return contains(get_valid_operator_attribute_exprs(pattern), expr); +} + +struct IsValidGraphAttributeExprFunctor { + GraphPattern const &graph_pattern; + + template + bool operator()(T const &t) const { + return is_valid(t); + } + + bool is_valid(NodeAttrAccess const &t) const { + return is_valid_operator_attribute_expr(graph_pattern.value().at(t.node), + t.attr_expr); + } + + bool is_valid(EdgeAttrAccess const &t) const { + NOT_IMPLEMENTED(); + } + + bool is_valid(AttrConstant const &t) const { + return true; + } +}; + +bool is_valid_graph_attribute_expr(GraphPattern const &pattern, + GraphAttributeExpr const &expr) { + return visit(IsValidGraphAttributeExprFunctor{pattern}, expr); +} + +bool is_valid_substitution(Substitution const &s) { + for (Node const &node : get_nodes(s.output_graph_expr)) { + for (GraphAttributeExpr expr : + values(s.output_graph_expr.value().at(node).assignment)) { + if (!is_valid_graph_attribute_expr(s.input_graph, expr)) { + return false; + } + } + } + return true; +} + struct EvaluateGraphAttributeExpr { ParallelComputationGraph const &graph; MultiDiGraphPatternMatch const &match; From ae97d594081ef478b94f24a550129c72cbab0f33 Mon Sep 17 00:00:00 2001 From: wmdi Date: Thu, 31 Aug 2023 16:19:04 -0400 Subject: [PATCH 31/61] initialize tests for substitutions --- lib/compiler/test/test_generator.h | 19 ++++-- lib/substitutions/src/graph_pattern_match.cc | 54 +++++++-------- lib/substitutions/src/substitution.cc | 11 +-- .../test/test_pattern_matches.cc | 67 +++++++++++++++++++ lib/utils/include/utils/graph/algorithms.h | 1 + .../include/utils/graph/labelled_graphs.h | 1 + lib/utils/include/utils/graph/multidigraph.h | 1 + 7 files changed, 116 insertions(+), 38 deletions(-) create mode 100644 lib/substitutions/test/test_pattern_matches.cc diff --git a/lib/compiler/test/test_generator.h b/lib/compiler/test/test_generator.h index b3453b014c..374bb89455 100644 --- a/lib/compiler/test/test_generator.h +++ b/lib/compiler/test/test_generator.h @@ -3,6 +3,7 @@ #include "compiler/machine_mapping.h" #include "compiler/sub_parallel_computation_graph.h" +#include "pcg/computation_graph.h" #include "rapidcheck.h" using namespace FlexFlow; @@ -38,7 +39,7 @@ ParallelComputationGraph } rc::Gen small_integer_generator() { - return gen::inRange(1, 4); + return rc::gen::inRange(1, 4); } namespace rc { @@ -51,15 +52,16 @@ Gen serialParallelMultiDiGraph() { template <> struct Arbitrary { static Gen arbitrary() { - return gen::map(serialParallelMultiDiGraph, test_computataion_graph); + return gen::map(gen::cast(serialParallelMultiDiGraph()), + test_computataion_graph); } }; template <> struct Arbitrary { static Gen arbitrary() { - return gen::map(serialParallelMultiDiGraph, - test_parallel_computataion_graph); + return gen::map(gen::cast(serialParallelMultiDiGraph()), + test_parallel_computation_graph); } }; @@ -67,7 +69,9 @@ template <> struct Arbitrary> { static Gen> arbitrary() { return gen::mapcat(gen::arbitrary(), [](bool is_node) { - return is_node ? gen::arbitrary() : gen::arbitrary(); + return is_node + ? gen::cast>(gen::arbitrary()) + : gen::cast>(gen::arbitrary()); }); } }; @@ -76,7 +80,10 @@ template <> struct Arbitrary> { static Gen> arbitrary() { return gen::mapcat(gen::arbitrary(), [](bool is_node) { - return is_node ? gen::arbitrary() : gen::arbitrary(); + return is_node + ? gen::cast>(gen::arbitrary()) + : gen::cast>( + gen::arbitrary()); }); } }; diff --git a/lib/substitutions/src/graph_pattern_match.cc b/lib/substitutions/src/graph_pattern_match.cc index 2e0150c808..81015620bd 100644 --- a/lib/substitutions/src/graph_pattern_match.cc +++ b/lib/substitutions/src/graph_pattern_match.cc @@ -8,18 +8,18 @@ namespace FlexFlow { // OpenMultiDiGraphView const &pattern) { // MultiDiGraphPatternMatch result; // std::unordered_set nodes = get_nodes(pattern); -// for (auto const &kv : match.nodeAssignment) { +// for (auto const &kv : match.node_assignment) { // Node pattern_node = kv.first; // if (contains(nodes, pattern_node)) { -// result.nodeAssignment.equate(kv.first, kv.second); +// result.node_assignment.equate(kv.first, kv.second); // } // } // std::unordered_set edges = get_edges(pattern); -// for (auto const &kv : match.edgeAssignment) { +// for (auto const &kv : match.edge_assignment) { // OpenMultiDiEdge pattern_edge = kv.first; // if (contains(edges, pattern_edge)) { -// result.edgeAssignment.equate(kv.first, kv.second); +// result.edge_assignment.equate(kv.first, kv.second); // } // } @@ -54,14 +54,14 @@ MatchSplit apply_split(OpenMultiDiGraphView const &pattern, MatchSplit result; - for (auto const &kv : match.nodeAssignment) { + for (auto const &kv : match.node_assignment) { Node pattern_node = kv.first; Node graph_node = kv.second; if (contains(split.first, pattern_node)) { - result.prefix_submatch.nodeAssignment.equate(pattern_node, graph_node); + result.prefix_submatch.node_assignment.equate(pattern_node, graph_node); } else { assert(contains(split.second, pattern_node)); - result.postfix_submatch.nodeAssignment.equate(pattern_node, graph_node); + result.postfix_submatch.node_assignment.equate(pattern_node, graph_node); } } @@ -69,12 +69,12 @@ MatchSplit apply_split(OpenMultiDiGraphView const &pattern, std::function handle_edge = [&](OpenMultiDiEdge const &pattern_edge) -> void { - MultiDiEdge graph_edge = match.edgeAssignment.at_l(pattern_edge); + MultiDiEdge graph_edge = match.edge_assignment.at_l(pattern_edge); auto edge_nodes = get_nodes(pattern_edge); if (is_subseteq_of(edge_nodes, prefix)) { - result.prefix_submatch.edgeAssignment.equate(pattern_edge, graph_edge); + result.prefix_submatch.edge_assignment.equate(pattern_edge, graph_edge); } else if (is_subseteq_of(edge_nodes, postfix)) { - result.postfix_submatch.edgeAssignment.equate(pattern_edge, graph_edge); + result.postfix_submatch.edge_assignment.equate(pattern_edge, graph_edge); } else { assert(is_standard_edge(pattern_edge)); auto standard_edge = mpark::get(pattern_edge); @@ -84,7 +84,7 @@ MatchSplit apply_split(OpenMultiDiGraphView const &pattern, } }; - for (auto const &kv : match.edgeAssignment) { + for (auto const &kv : match.edge_assignment) { OpenMultiDiEdge pattern_edge = kv.first; handle_edge(pattern_edge); } @@ -103,24 +103,24 @@ bool pattern_matches(OpenMultiDiGraphView const &pattern, F const &additional_criterion) { if (is_singleton_pattern(pattern)) { Node pattern_node = get_only(get_nodes(pattern)); - Node graph_matched_node = match.nodeAssignment.at_l(pattern_node); + Node graph_matched_node = match.node_assignment.at_l(pattern_node); if (!additional_criterion(pattern_node, graph_matched_node)) { return false; } for (OpenMultiDiEdge const &e : get_edges(pattern)) { - MultiDiEdge graph_matched_edge = match.edgeAssignment.at_l(e); + MultiDiEdge graph_matched_edge = match.edge_assignment.at_l(e); assert(is_input_edge(e) || is_output_edge(e)); if (is_input_edge(e)) { InputMultiDiEdge input_edge = mpark::get(e); - if (match.nodeAssignment.at_l(input_edge.dst) != + if (match.node_assignment.at_l(input_edge.dst) != graph_matched_edge.dst || input_edge.dstIdx != graph_matched_edge.dstIdx) { return false; } } else { OutputMultiDiEdge output_edge = mpark::get(e); - if (match.nodeAssignment.at_l(output_edge.src) != + if (match.node_assignment.at_l(output_edge.src) != graph_matched_edge.src || output_edge.srcIdx != graph_matched_edge.srcIdx) { return false; @@ -158,7 +158,7 @@ optional Node pattern_node = get_only(get_nodes(pattern)); MultiDiGraphPatternMatch match; - match.nodeAssignment.equate(pattern_node, graph_node); + match.node_assignment.equate(pattern_node, graph_node); auto incoming = get_incoming_edges_by_idx(graph, graph_node); auto outgoing = get_outgoing_edges_by_idx(graph, graph_node); @@ -169,16 +169,16 @@ optional if (!contains_key(incoming, input_edge.dstIdx)) { return nullopt; } - match.edgeAssignment.equate(input_edge, - get_only(incoming.at(input_edge.dstIdx))); + match.edge_assignment.equate(input_edge, + get_only(incoming.at(input_edge.dstIdx))); } else { OutputMultiDiEdge output_edge = mpark::get(pattern_edge); if (!contains_key(outgoing, output_edge.srcIdx)) { return nullopt; } - match.edgeAssignment.equate(output_edge, - get_only(outgoing.at(output_edge.srcIdx))); + match.edge_assignment.equate(output_edge, + get_only(outgoing.at(output_edge.srcIdx))); } } @@ -199,24 +199,24 @@ optional unsplit_matches( handled.insert(output_edge); handled.insert(input_edge); - MultiDiEdge output_graph_edge = prefix.edgeAssignment.at_l(output_edge); - MultiDiEdge input_graph_edge = postfix.edgeAssignment.at_l(input_edge); + MultiDiEdge output_graph_edge = prefix.edge_assignment.at_l(output_edge); + MultiDiEdge input_graph_edge = postfix.edge_assignment.at_l(input_edge); if (output_graph_edge == input_graph_edge) { - result.edgeAssignment.equate(standard_edge, output_graph_edge); + result.edge_assignment.equate(standard_edge, output_graph_edge); } else { return nullopt; } } for (auto const &kv : - merge_maps(prefix.edgeAssignment, postfix.edgeAssignment)) { + merge_maps(prefix.edge_assignment, postfix.edge_assignment)) { if (!contains(handled, kv.first)) { - result.edgeAssignment.equate(kv.first, kv.second); + result.edge_assignment.equate(kv.first, kv.second); } } - result.nodeAssignment = - merge_maps(prefix.nodeAssignment, postfix.nodeAssignment); + result.node_assignment = + merge_maps(prefix.node_assignment, postfix.node_assignment); return result; } diff --git a/lib/substitutions/src/substitution.cc b/lib/substitutions/src/substitution.cc index e83d522cd7..c81221daa0 100644 --- a/lib/substitutions/src/substitution.cc +++ b/lib/substitutions/src/substitution.cc @@ -22,7 +22,7 @@ struct IsValidGraphAttributeExprFunctor { } bool is_valid(NodeAttrAccess const &t) const { - return is_valid_operator_attribute_expr(graph_pattern.value().at(t.node), + return is_valid_operator_attribute_expr(graph_pattern->at(t.node), t.attr_expr); } @@ -65,14 +65,14 @@ struct EvaluateGraphAttributeExpr { Node node_in_pattern = t.node; Node node_in_pcg = match.node_assignment.at_l(node_in_pattern); return widen( - evaluate_attribute_expr(graph.at(node_in_pcg), t.attr_expr).value()); + evaluate_attribute_expr(graph->at(node_in_pcg), t.attr_expr).value()); } GraphAttributeValue evaluate(EdgeAttrAccess const &t) { OpenMultiDiEdge output_in_pattern = t.edge; MultiDiEdge output_in_pcg = match.edge_assignment.at_l(output_in_pattern); return widen( - evaluate_attribute_expr(graph.at(output_in_pcg), t.attr_expr).value()); + evaluate_attribute_expr(graph->at(output_in_pcg), t.attr_expr).value()); } }; @@ -101,11 +101,12 @@ ParallelComputationGraph Substitution const &substitution, MultiDiGraphPatternMatch const &match) { ParallelComputationGraph new_pcg = - ParallelComputationGraph::create(); + OutputLabelledMultiDiGraph::create< + UnorderedOutputLabelledMultiDiGraph>(); bidict node_mapping; // Refactor it with global nodes for (Node const &node : get_nodes(pcg)) { if (!contains_r(match.node_assignment, node)) { - node_mapping.equate(node, new_pcg.add_node(pcg.at(node))); + node_mapping.equate(node, new_pcg.add_node(pcg.value().at(node))); } } for (MultiDiEdge const &edge : get_edges(pcg)) { diff --git a/lib/substitutions/test/test_pattern_matches.cc b/lib/substitutions/test/test_pattern_matches.cc new file mode 100644 index 0000000000..be95fff8e6 --- /dev/null +++ b/lib/substitutions/test/test_pattern_matches.cc @@ -0,0 +1,67 @@ +#include "doctest.h" +#include "rapidcheck.h" +#include "graph_pattern_match.h" + +using namespace FlexFlow; + +struct AlwaysTrue { + template + bool operator()(T const &t) const { + return true; + } +}; + +MultiDiGraph construct_multidigraph(int num_nodes, std::vector> const &edges) { + MultiDiGraph g = MultiDiGraph::create(); + + std::vector nodes; + for (int i = 0; i < num_nodes; ++i) { + nodes.push_back(g.add_node()); + } + + for (std::pair e : edges) { + if (e.first > e.second) { + std::swap(e.first, e.second); + } + + g.add_edge(MultiDiEdge{nodes[e.first], nodes[e.second], g.add_node_port(), g.add_node_port()}); + } + + return g; +} + +namespace rc { + +template <> +struct Arbitrary { + static const int MAX_GRAPH_SIZE = 200; + + static Gen arbitrary() { + auto gen_edges = [](int num_nodes) { + return gen::container>>(gen::inRange(0, num_nodes)); + }; + + auto gen_graph = [&](int num_nodes) { + return gen::apply([=](std::vector> const &edges) { return construct_multidigraph(num_nodes, edges); }, gen_edges(num_nodes)); + }; + + return gen::apply(gen_graph, gen::inRange(1, MAX_GRAPH_SIZE)); + } +}; + +} + +TEST_CASE("find_pattern_matches") { + rc::check([](MultiDiGraph const &g) { + std::unordered_set subgraph_nodes = *rc::gen::container>(rc::gen::elementOf(get_nodes(g))); + OpenMultiDiGraphView subgraph = get_subgraph(as_openmultidigraph(g), subgraph_nodes); + + std::unordered_set matches = find_pattern_matches(subgraph, g, AlwaysTrue{}); + + RC_ASSERT(!matches.empty()); + + for (MultiDiGraphPatternMatch const &match : matches) { + RC_ASSERT(pattern_matches(subgraph, g, match, AlwaysTrue{})); + } + }); +} \ No newline at end of file diff --git a/lib/utils/include/utils/graph/algorithms.h b/lib/utils/include/utils/graph/algorithms.h index 6fe6b95aa0..5dc1fb15c8 100644 --- a/lib/utils/include/utils/graph/algorithms.h +++ b/lib/utils/include/utils/graph/algorithms.h @@ -286,6 +286,7 @@ UndirectedGraphView as_undirected(DiGraphView const &); MultiDiGraphView as_multidigraph(DiGraphView const &); DiGraphView as_digraph(MultiDiGraphView const &); MultiDiGraphView as_multidigraph(OpenMultiDiGraphView const &); +OpenMultiDiGraphView as_openmultidigraph(MultiDiGraphView const &); void export_as_dot( DotFile &, diff --git a/lib/utils/include/utils/graph/labelled_graphs.h b/lib/utils/include/utils/graph/labelled_graphs.h index 699369edb3..3fcc1daae5 100644 --- a/lib/utils/include/utils/graph/labelled_graphs.h +++ b/lib/utils/include/utils/graph/labelled_graphs.h @@ -11,5 +11,6 @@ #include "labelled/output_labelled_open.h" #include "labelled/standard_labelled.h" #include "labelled/unordered_labelled_graphs.h" +#include "labelled/views.h" #endif diff --git a/lib/utils/include/utils/graph/multidigraph.h b/lib/utils/include/utils/graph/multidigraph.h index 5a0ae65490..4cbf33446f 100644 --- a/lib/utils/include/utils/graph/multidigraph.h +++ b/lib/utils/include/utils/graph/multidigraph.h @@ -52,6 +52,7 @@ struct MultiDiGraph { MultiDiGraph(MultiDiGraph const &) = default; MultiDiGraph &operator=(MultiDiGraph const &) = default; + operator GraphView() const; operator MultiDiGraphView() const; friend void swap(MultiDiGraph &, MultiDiGraph &); From 31d2ca05c20881f6ffa1acb12a9efbe9b3a85c33 Mon Sep 17 00:00:00 2001 From: Colin Unger Date: Thu, 31 Aug 2023 23:45:47 -0700 Subject: [PATCH 32/61] Add partial required fix --- lib/utils/include/utils/required.h | 32 +++-- lib/utils/include/utils/required_core.h | 115 ++++++++++++++--- lib/utils/include/utils/test_types.h | 42 ++++++- lib/utils/include/utils/type_traits.h | 81 ------------ lib/utils/include/utils/type_traits_core.h | 118 ++++++++++++++++++ .../test/common/include/test/utils/doctest.h | 4 + lib/utils/test/src/test_required.cc | 92 ++++++++++++++ 7 files changed, 369 insertions(+), 115 deletions(-) create mode 100644 lib/utils/test/src/test_required.cc diff --git a/lib/utils/include/utils/required.h b/lib/utils/include/utils/required.h index 499994770a..4165b442fa 100644 --- a/lib/utils/include/utils/required.h +++ b/lib/utils/include/utils/required.h @@ -24,24 +24,34 @@ struct adl_serializer<::FlexFlow::req> { }; } // namespace nlohmann -namespace fmt { +/* namespace fmt { */ -template -struct formatter<::FlexFlow::req> : formatter { - template - auto format(::FlexFlow::req const &t, FormatContext &ctx) - -> decltype(ctx.out()) { - return formatter::format(static_cast(t), ctx); - } -}; +/* template */ +/* struct formatter<::FlexFlow::req> : formatter { */ +/* template */ +/* auto format(::FlexFlow::req const &t, FormatContext &ctx) */ +/* -> decltype(ctx.out()) { */ +/* return formatter::format(static_cast(t), ctx); */ +/* } */ +/* }; */ -} // namespace fmt +/* } // namespace fmt */ namespace FlexFlow { static_assert(is_json_serializable>::value, ""); static_assert(is_json_deserializable>::value, ""); static_assert(is_jsonable>::value, ""); -static_assert(is_fmtable>::value, ""); +CHECK_FMTABLE(req); +CHECK_FMTABLE(std::vector); +CHECK_FMTABLE(required_inheritance_impl>); +static_assert( + std::is_base_of< + required_inheritance_impl>, + req> + >::value, "" +); +CHECK_FMTABLE(req>); + } // namespace FlexFlow #endif diff --git a/lib/utils/include/utils/required_core.h b/lib/utils/include/utils/required_core.h index 3336e38243..8d20597efc 100644 --- a/lib/utils/include/utils/required_core.h +++ b/lib/utils/include/utils/required_core.h @@ -4,9 +4,22 @@ #include "hash-utils-core.h" #include "type_traits_core.h" #include +#include "test_types.h" +#include "fmt.decl.h" namespace FlexFlow { +template +struct enable_if_valid {}; + +template +struct enable_if_valid, Args...> : type_identity {}; + +/* required_wrapper_impl() + std::declval())>> */ +/* operator+(required_wrapper_impl const &lhs, required_wrapper_impl const &rhs) { */ +/* /1* return 1; *1/ */ +/* } */ + template struct required_wrapper_impl { public: @@ -34,6 +47,10 @@ struct required_wrapper_impl { return static_cast(this->m_value); } + friend T format_as(required_wrapper_impl const &r) { + return static_cast(r); + } + /* T const &operator*() const { */ /* return this->m_value; */ /* } */ @@ -42,12 +59,54 @@ struct required_wrapper_impl { /* return &this->m_value; */ /* } */ - /* bool operator==(T const &other) const { */ - /* return this->m_value == other; */ + template + enable_if_t::value, bool> + operator==(required_wrapper_impl const &rhs) const { + return this->m_value == rhs.m_value; + } + + template + enable_if_t::value, bool> + operator==(TT const &rhs) const { + return this->m_value == rhs; + } + + /* friend enable_if_t::value, bool> */ + /* operator==(required_wrapper_impl const &lhs, T const &rhs) { */ + /* return lhs.m_value == rhs; */ + /* } */ + + /* friend enable_if_t::value, bool> */ + /* operator==(T const &lhs, required_wrapper_impl const &rhs) { */ + /* return lhs == rhs.m_value; */ + /* } */ + + template + enable_if_t::value, bool> + operator!=(required_wrapper_impl const &rhs) const { + return this->m_value != rhs.m_value; + } + + /* friend enable_if_t::value, required_wrapper_impl() + std::declval())>> */ + /* operator+(required_wrapper_impl const &lhs, required_wrapper_impl const &rhs) { */ + /* /1* return 1; *1/ */ + /* } */ + /* required_wrapper_impl */ + /* operator+(required_wrapper_impl const &rhs) { */ + /* Out o = this->m_value + rhs.m_value; */ + /* return required_wrapper_impl{o}; */ /* } */ - /* bool operator!=(T const &other) const { */ - /* return this->m_value != other; */ + /* template ::value> = true> */ + /* required_wrapper_impl operator-(required_wrapper_impl const &rhs) { */ + /* return {this->m_value - rhs.m_value}; */ + /* } */ + + /* template ::value> = true> */ + /* required_wrapper_impl operator*(required_wrapper_impl const &rhs) { */ + /* return {this->m_value * rhs.m_value}; */ /* } */ /* bool operator<(T const &other) const { */ @@ -68,8 +127,31 @@ struct required_inheritance_impl : public T { using T::T; required_inheritance_impl() = delete; - required_inheritance_impl(T const &); - required_inheritance_impl(T &&t); + required_inheritance_impl(T const &t) : T(t) {} + required_inheritance_impl(T &&t) : T(t) {} + + required_inheritance_impl(required_inheritance_impl const &) = default; + required_inheritance_impl(required_inheritance_impl &&) = default; + + required_inheritance_impl &operator=(required_inheritance_impl const &) = default; + required_inheritance_impl &operator=(required_inheritance_impl &&) = default; + + friend enable_if_t::value, bool> + operator==(required_inheritance_impl const &lhs, required_inheritance_impl const &rhs) { + return static_cast(lhs) == static_cast(rhs); + } + + friend enable_if_t::value, bool> + operator!=(required_inheritance_impl const &lhs, required_inheritance_impl const &rhs) { + return static_cast(lhs) != static_cast(rhs); + } + + friend std::string format_as(required_inheritance_impl const &r) { + return ""; + /* static_assert(is_fmtable::value, ""); */ + + /* return static_cast(r); */ + } template required_inheritance_impl( @@ -78,8 +160,6 @@ struct required_inheritance_impl : public T { !std::is_same::value>::type * = 0) : required_inheritance_impl(static_cast(tt)) {} - operator T() const; - template ::value && !std::is_same::value), @@ -116,19 +196,16 @@ struct remove_req> { template using remove_req_t = typename remove_req::type; +static_assert(is_equal_comparable>>::value, ""); +CHECK_WELL_BEHAVED_VALUE_TYPE_NO_HASH(required_inheritance_impl); +CHECK_WELL_BEHAVED_VALUE_TYPE_NO_HASH(required_wrapper_impl); + +/* static_assert(std::is_same>() + std::declval>()), required_wrapper_impl>::value, ""); */ + +static_assert(std::is_copy_constructible>::value, ""); + static_assert(std::is_convertible, int>::value, ""); static_assert(is_static_castable, int *>::value, ""); -static_assert( - std::is_same< - void_t>() == std::declval())>, - void>::value, - ""); -static_assert(is_list_initializable, bool>::value, ""); -static_assert( - std::is_same< - void_t>() + std::declval())>, - void>::value, - ""); } // namespace FlexFlow diff --git a/lib/utils/include/utils/test_types.h b/lib/utils/include/utils/test_types.h index 2cac547bb6..308bd61428 100644 --- a/lib/utils/include/utils/test_types.h +++ b/lib/utils/include/utils/test_types.h @@ -1,7 +1,7 @@ #ifndef _FLEXFLOW_UTILS_INCLUDE_UTILS_TEST_TYPES_H #define _FLEXFLOW_UTILS_INCLUDE_UTILS_TEST_TYPES_H -#include "type_traits.h" +#include "type_traits_core.h" namespace FlexFlow { @@ -12,7 +12,10 @@ enum capability { EQ, CMP, DEFAULT_CONSTRUCTIBLE, - COPYABLE, + MOVE_CONSTRUCTIBLE, + MOVE_ASSIGNABLE, + COPY_CONSTRUCTIBLE, + COPY_ASSIGNABLE, PLUS, PLUSEQ, FMT @@ -51,14 +54,38 @@ struct test_type_t { typename std::enable_if::value, bool>::type = true> test_type_t() = delete; - template ::value, bool>::type = true> test_type_t(test_type_t const &); - template ::value, bool>::type = true> test_type_t(test_type_t const &) = delete; + template ::value, bool>::type = true> + test_type_t &operator=(test_type_t const &); + + template ::value, bool>::type = true> + test_type_t &operator=(test_type_t const &) = delete; + + template ::value, bool>::type = true> + test_type_t(test_type_t &&); + + template ::value, bool>::type = true> + test_type_t(test_type_t &&) = delete; + + template ::value, bool>::type = true> + test_type_t &operator=(test_type_t &&); + + template ::value, bool>::type = true> + test_type_t &operator=(test_type_t &&) = delete; + template typename std::enable_if::value, bool>::type operator==(test_type_t const &) const; @@ -102,6 +129,13 @@ using cmp = test_type_t; using hash_cmp = test_type_t; using plusable = test_type_t; using fmtable = test_type_t; +using well_behaved_value_type = test_type_t< + EQ, + COPY_CONSTRUCTIBLE, + MOVE_CONSTRUCTIBLE, + COPY_ASSIGNABLE, + MOVE_ASSIGNABLE +>; } // namespace test_types } // namespace FlexFlow diff --git a/lib/utils/include/utils/type_traits.h b/lib/utils/include/utils/type_traits.h index dc8fe2cf57..ee45e8dc2e 100644 --- a/lib/utils/include/utils/type_traits.h +++ b/lib/utils/include/utils/type_traits.h @@ -65,24 +65,6 @@ template struct is_streamable())>> : std::true_type {}; -template -struct is_equal_comparable : std::false_type {}; - -template -struct is_equal_comparable< - T, - void_t() == std::declval()))>> - : std::true_type {}; - -template -struct is_neq_comparable : std::false_type {}; - -template -struct is_neq_comparable< - T, - void_t() != std::declval()))>> - : std::true_type {}; - template struct is_lt_comparable : std::false_type {}; @@ -92,19 +74,6 @@ struct is_lt_comparable< void_t() < std::declval()))>> : std::true_type {}; -template -struct is_hashable : std::false_type {}; - -template -struct is_hashable< - T, - void_t>()(std::declval())))>> - : std::true_type {}; - -#define CHECK_HASHABLE(...) \ - static_assert(is_hashable<__VA_ARGS__>::value, \ - #__VA_ARGS__ " should be hashable (but is not)"); - template