diff --git a/lib/compiler/include/compiler/machine_mapping.h b/lib/compiler/include/compiler/machine_mapping.h index aeec1362dd..4089260735 100644 --- a/lib/compiler/include/compiler/machine_mapping.h +++ b/lib/compiler/include/compiler/machine_mapping.h @@ -5,33 +5,65 @@ #include "pcg/machine_specification.h" #include "pcg/machine_view.h" #include "pcg/parallel_computation_graph.h" +#include "sub_parallel_computation_graph.h" namespace FlexFlow { struct MachineMapping { - static MachineMapping sequential_combine(MachineMapping const &s1, - MachineMapping const &s2); - static MachineMapping parallel_combine(MachineMapping const &s1, - MachineMapping const &s2); - static MachineMapping infinity(); + static MachineMapping combine(MachineMapping const &, MachineMapping const &); + static bool nodes_are_disjoint(MachineMapping const &m1, + MachineMapping const &m2); - float runtime; req> machine_views; }; -FF_VISITABLE_STRUCT(MachineMapping, runtime, machine_views); +FF_VISITABLE_STRUCT(MachineMapping, machine_views); + +struct OptimalCostState { + SerialParallelDecomposition subgraph; + MachineSpecification resource; + req> source_machine_view, sink_machine_view; +}; +FF_VISITABLE_STRUCT(OptimalCostState, + subgraph, + resource, + source_machine_view, + sink_machine_view); + +struct OptimalCostResult { + static OptimalCostResult sequential_combine(OptimalCostResult const &s1, + OptimalCostResult const &s2); + static OptimalCostResult parallel_combine(OptimalCostResult const &s1, + OptimalCostResult const &s2); + static OptimalCostResult infinity(); + + float runtime; + MachineMapping machine_mapping; +}; +FF_VISITABLE_STRUCT(OptimalCostResult, runtime, machine_mapping); + +struct OptimalCostRuntimeCmp { + bool operator()(OptimalCostResult const &, OptimalCostResult const &); +}; + +class OptimalCostCache { +public: + OptimalCostCache() = default; + + optional load(OptimalCostState const &) const; + void save(OptimalCostState const &, OptimalCostResult const &); -struct MachineMappingRuntimeCmp { - bool operator()(MachineMapping const &, MachineMapping const &); +private: + std::unordered_map cache; }; -MachineMapping optimal_cost( - ParallelComputationGraph const &g, - std::function( - Operator const &, MachineSpecification const &)> const - &allowed_machine_views, - CostEstimator const &cost_estimator, - MachineSpecification const &resources, - std::unordered_map &cached_subgraph_costs); +OptimalCostResult + optimal_cost(ParallelComputationGraph const &g, + std::function( + Operator const &, MachineSpecification const &)> const + &allowed_machine_views, + CostEstimator const &cost_estimator, + MachineSpecification const &resources, + OptimalCostCache &cached_subgraph_costs); } // namespace FlexFlow diff --git a/lib/compiler/include/compiler/unity_algorithm.h b/lib/compiler/include/compiler/unity_algorithm.h index e8c67c38ae..57f1c8c063 100644 --- a/lib/compiler/include/compiler/unity_algorithm.h +++ b/lib/compiler/include/compiler/unity_algorithm.h @@ -13,8 +13,9 @@ struct Substitution {}; struct Strategy { ParallelComputationGraph pcg; MachineMapping machine_mapping; + req runtime; }; -FF_VISITABLE_STRUCT(Strategy, pcg, machine_mapping); +FF_VISITABLE_STRUCT(Strategy, pcg, machine_mapping, runtime); struct StrategyRuntimeCmp { bool operator()(Strategy const &, Strategy const &); diff --git a/lib/compiler/src/graph_utils.h b/lib/compiler/src/graph_utils.h index 91132da680..88515ef950 100644 --- a/lib/compiler/src/graph_utils.h +++ b/lib/compiler/src/graph_utils.h @@ -16,7 +16,16 @@ SubParallelComputationGraph pcg_to_subpcg(ParallelComputationGraph const &g); template void minimize(T &t, T const &v) { - t = std::min(t, v); + if (v < t) { + t = v; + } +} + +template +void minimize(T &t, T const &v, Compare comp) { + if (comp(v, t)) { + t = v; + } } } // namespace FlexFlow diff --git a/lib/compiler/src/machine_mapping.cc b/lib/compiler/src/machine_mapping.cc index 10be9bb034..2f6af8a62b 100644 --- a/lib/compiler/src/machine_mapping.cc +++ b/lib/compiler/src/machine_mapping.cc @@ -7,28 +7,56 @@ namespace FlexFlow { -MachineMapping MachineMapping::sequential_combine(MachineMapping const &s1, - MachineMapping const &s2) { - return {s1.runtime + s2.runtime, - merge_maps(s1.machine_views, s2.machine_views)}; +MachineMapping MachineMapping::combine(MachineMapping const &s1, + MachineMapping const &s2) { + return MachineMapping{merge_maps(s1.machine_views, s2.machine_views)}; } -MachineMapping MachineMapping::parallel_combine(MachineMapping const &s1, - MachineMapping const &s2) { - return {std::max(s1.runtime, s2.runtime), - merge_maps(s1.machine_views, s2.machine_views)}; +bool MachineMapping::nodes_are_disjoint(MachineMapping const &m1, + MachineMapping const &m2) { + return are_disjoint(keys(m1.machine_views), keys(m2.machine_views)); } -MachineMapping MachineMapping::infinity() { +OptimalCostResult + OptimalCostResult::sequential_combine(OptimalCostResult const &s1, + OptimalCostResult const &s2) { + return OptimalCostResult{ + s1.runtime + s2.runtime, + MachineMapping::combine(s1.machine_mapping, s2.machine_mapping)}; +} + +OptimalCostResult + OptimalCostResult::parallel_combine(OptimalCostResult const &s1, + OptimalCostResult const &s2) { + return OptimalCostResult{ + std::max(s1.runtime, s2.runtime), + MachineMapping::combine(s1.machine_mapping, s2.machine_mapping)}; +} + +OptimalCostResult OptimalCostResult::infinity() { return {std::numeric_limits::infinity(), - std::unordered_map{}}; + MachineMapping{std::unordered_map{}}}; } -bool MachineMappingRuntimeCmp::operator()(MachineMapping const &lhs, - MachineMapping const &rhs) { +bool OptimalCostRuntimeCmp::operator()(OptimalCostResult const &lhs, + OptimalCostResult const &rhs) { return lhs.runtime < rhs.runtime; } +optional + OptimalCostCache::load(OptimalCostState const &state) const { + if (contains_key(cache, state)) { + return make_optional(cache.at(state)); + } + return nullopt; +} + +void OptimalCostCache::save(OptimalCostState const &state, + OptimalCostResult const &result) { + assert(!contains_key(cache, state)); + cache.emplace(state, result); +} + std::vector> get_resource_split(MachineSpecification const &resource) { std::vector> result; @@ -65,7 +93,7 @@ std::pair OpenMultiDiGraphView g1 = get_subgraph(g, split.first); OpenMultiDiGraphView g2 = get_subgraph(g, split.second); - if (get_cut(g, split).size() > 0) { + if (get_edge_splits(g, split).size() > 0) { // Sequential split if (get_open_sinks(g1).size() <= get_open_sources(g2).size()) { // get_open_sinks(*g1).size() should be 1 in perfect sp graphs @@ -82,13 +110,16 @@ std::pair } } -float estimate_cost( - SubParallelComputationGraph const &g, - CostEstimator const &estimator, - std::unordered_map const &device_mapping) { +float estimate_cost(SubParallelComputationGraph const &g, + CostEstimator const &estimator, + MachineMapping const &device_mapping) { NOT_IMPLEMENTED(); } +void minimize_runtime(OptimalCostResult &m1, OptimalCostResult const &m2) { + minimize(m1, m2, OptimalCostRuntimeCmp{}); +} + struct OptimalCost { OptimalCost( SubParallelComputationGraph const &g, @@ -97,9 +128,9 @@ struct OptimalCost { optional const &source_machine_view, // assume perfect SP optional const &sink_machine_view, std::function( - PCGOperatorAttrs const &, MachineSpecification const &)> const + Operator const &, MachineSpecification const &)> const &allowed_machine_views, - std::unordered_map &cached_subgraph_costs) + OptimalCostCache &cached_subgraph_costs) : g(g), cost_estimator(cost_estimator), resource(resource), source_machine_view(source_machine_view), sink_machine_view(sink_machine_view), @@ -112,51 +143,27 @@ struct OptimalCost { optional const &source_machine_view; optional const &sink_machine_view; std::function( - PCGOperatorAttrs const &, MachineSpecification const &)> const + Operator const &, MachineSpecification const &)> const &allowed_machine_views; - std::unordered_map &cached_subgraph_costs; + OptimalCostCache &cached_subgraph_costs; - // TODO: move them out of the functor template - size_t hash_state(T const &sp_decomposition) const { - size_t h = std::hash{}(sp_decomposition); - hash_combine(h, resource); - hash_combine(h, source_machine_view); - hash_combine(h, sink_machine_view); - return h; - } + OptimalCostResult operator()(T const &t) const { + OptimalCostState state{g, resource, source_machine_view, sink_machine_view}; + optional cached_result = + cached_subgraph_costs.load(state); - optional load_result_from_cache(size_t hash_value) const { - if (contains_key(cached_subgraph_costs, hash_value)) { - return make_optional(cached_subgraph_costs.at(hash_value)); - } - return nullopt; - } - - void save_result_to_cache(size_t hash_value, - MachineMapping const &strategy) const { - assert(!contains_key(cached_subgraph_costs, hash_value)); - cached_subgraph_costs.emplace(hash_value, strategy); - } - - template - MachineMapping operator()(T const &t) const { - size_t state_hash_value = hash_state(t); - optional cached_result = - load_result_from_cache(state_hash_value); if (cached_result) { return cached_result.value(); } - MachineMapping result = this->optimal_cost(t); + OptimalCostResult result = this->optimal_cost(t); - save_result_to_cache(state_hash_value, result); + cached_subgraph_costs.save(state, result); return result; } - MachineMapping optimal_cost(Serial const &serial) const { - // return sum(vector_transform([&](variant const &t) { - // return visit(*this, t); }, serial.children)); + OptimalCostResult optimal_cost(Serial const &serial) const { auto decomposed = decompose(serial); SerialParallelDecomposition pre_decompn = decomposed.first; SerialParallelDecomposition post_decompn = decomposed.second; @@ -175,7 +182,7 @@ struct OptimalCost { Node const &split_point = get_only(set_union(pre_graph_sinks, post_graph_sources)); - MachineMapping optimal_result = MachineMapping::infinity(); + OptimalCostResult optimal_result = OptimalCostResult::infinity(); for (MachineView const &mv : allowed_machine_views(g.at(split_point), resource)) { @@ -184,30 +191,30 @@ struct OptimalCost { optional post_source_mv = contains(post_graph_sources, split_point) ? make_optional(mv) : nullopt; - minimize(optimal_result, - MachineMapping::sequential_combine( - visit(OptimalCost(pre_graph, - cost_estimator, - resource, - source_machine_view, - pre_sink_mv, - allowed_machine_views, - cached_subgraph_costs), - pre_decompn), - visit(OptimalCost(post_graph, - cost_estimator, - resource, - post_source_mv, - sink_machine_view, - allowed_machine_views, - cached_subgraph_costs), - post_decompn))); + minimize_runtime(optimal_result, + OptimalCostResult::sequential_combine( + visit(OptimalCost(pre_graph, + cost_estimator, + resource, + source_machine_view, + pre_sink_mv, + allowed_machine_views, + cached_subgraph_costs), + pre_decompn), + visit(OptimalCost(post_graph, + cost_estimator, + resource, + post_source_mv, + sink_machine_view, + allowed_machine_views, + cached_subgraph_costs), + post_decompn))); } return optimal_result; } - MachineMapping optimal_cost(Parallel const ¶llel) const { + OptimalCostResult optimal_cost(Parallel const ¶llel) const { auto decomposed = decompose(parallel); SerialParallelDecomposition decompn1 = decomposed.first; SerialParallelDecomposition decompn2 = decomposed.second; @@ -215,7 +222,7 @@ struct OptimalCost { auto subgraphs = apply_split(g, get_graph_split(decompn1, decompn2)); SubParallelComputationGraph g1 = subgraphs.first, g2 = subgraphs.second; - MachineMapping optimal_result = MachineMapping::sequential_combine( + OptimalCostResult optimal_result = OptimalCostResult::sequential_combine( visit(OptimalCost(g1, cost_estimator, resource, @@ -234,64 +241,62 @@ struct OptimalCost { decompn2)); for (auto const &resource_split : get_resource_split(resource)) { - minimize(optimal_result, - MachineMapping::parallel_combine( - visit(OptimalCost(g1, - cost_estimator, - resource_split.first, - source_machine_view, - sink_machine_view, - allowed_machine_views, - cached_subgraph_costs), - decompn1), - visit(OptimalCost(g2, - cost_estimator, - resource_split.second, - source_machine_view, - sink_machine_view, - allowed_machine_views, - cached_subgraph_costs), - decompn2))); + minimize_runtime(optimal_result, + OptimalCostResult::parallel_combine( + visit(OptimalCost(g1, + cost_estimator, + resource_split.first, + source_machine_view, + sink_machine_view, + allowed_machine_views, + cached_subgraph_costs), + decompn1), + visit(OptimalCost(g2, + cost_estimator, + resource_split.second, + source_machine_view, + sink_machine_view, + allowed_machine_views, + cached_subgraph_costs), + decompn2))); } return optimal_result; } - MachineMapping optimal_cost(Node const &node) const { + OptimalCostResult optimal_cost(Node const &node) const { if (source_machine_view) { assert(get_closed_sources(g).empty()); assert(contains(allowed_machine_views(g.at(node), resource), source_machine_view.value())); - std::unordered_map mv_map{ - {node, source_machine_view.value()}}; + MachineMapping mv_map{{{node, source_machine_view.value()}}}; return {estimate_cost(g, cost_estimator, mv_map), mv_map}; } else if (sink_machine_view) { assert(get_closed_sinks(g).empty()); assert(contains(allowed_machine_views(g.at(node), resource), sink_machine_view.value())); - std::unordered_map mv_map{ - {node, sink_machine_view.value()}}; + MachineMapping mv_map{{{node, sink_machine_view.value()}}}; return {estimate_cost(g, cost_estimator, mv_map), mv_map}; } else { - MachineMapping optimal_result = MachineMapping::infinity(); + OptimalCostResult optimal_result = OptimalCostResult::infinity(); for (auto mv : allowed_machine_views(g.at(node), resource)) { - std::unordered_map mv_map{{node, mv}}; - minimize(optimal_result, - {estimate_cost(g, cost_estimator, mv_map), mv_map}); + MachineMapping mv_map{{{node, mv}}}; + minimize_runtime(optimal_result, + {estimate_cost(g, cost_estimator, mv_map), mv_map}); } return optimal_result; } } }; -MachineMapping optimal_cost( - ParallelComputationGraph const &g, - std::function( - PCGOperatorAttrs const &, MachineSpecification const &)> const - &allowed_machine_views, - CostEstimator const &cost_estimator, - MachineSpecification const &resources, - std::unordered_map &cached_subgraph_costs) { +OptimalCostResult + optimal_cost(ParallelComputationGraph const &g, + std::function( + Operator const &, MachineSpecification const &)> const + &allowed_machine_views, + CostEstimator const &cost_estimator, + MachineSpecification const &resources, + OptimalCostCache &cached_subgraph_costs) { return visit(OptimalCost(pcg_to_subpcg(g), cost_estimator, resources, diff --git a/lib/compiler/src/unity_algorithm.cc b/lib/compiler/src/unity_algorithm.cc index ed62fd941d..f5747e2058 100644 --- a/lib/compiler/src/unity_algorithm.cc +++ b/lib/compiler/src/unity_algorithm.cc @@ -6,7 +6,7 @@ namespace FlexFlow { bool StrategyRuntimeCmp::operator()(Strategy const &lhs, Strategy const &rhs) { - return lhs.machine_mapping.runtime < rhs.machine_mapping.runtime; + return lhs.runtime < rhs.runtime; } std::unordered_set @@ -29,7 +29,7 @@ Strategy std::unordered_set subs = get_all_substitutions(pcg); - std::unordered_map cached_subgraph_costs; + OptimalCostCache cached_subgraph_costs; DeduplicatedPriorityQueue, StrategyRuntimeCmp> candidates; @@ -50,20 +50,20 @@ Strategy if (StrategyRuntimeCmp(current_result, best_result)) { best_result = current_result; - } else if (current_result.machine_mapping.runtime > - best_result.machine_mapping.runtime * opt_config.alpha) { + } else if (current_result.runtime > + best_result.runtime * opt_config.alpha) { continue; } for (auto const &sub : subs) { for (auto const &new_pcg : apply_substitution(current_result.pcg, sub)) { - Strategy new_result(new_pcg, - optimal_cost(new_pcg, - allowed_machine_views, - cost_estimator, - resources, - cached_subgraph_costs)); - if (new_result.machine_mapping.runtime <= opt_config.threshold && + OptimalCostResult c = optimal_cost(new_pcg, + allowed_machine_views, + cost_estimator, + resources, + cached_subgraph_costs); + Strategy new_result(new_pcg, c.machine_mapping, c.runtime); + if (new_result.runtime <= opt_config.threshold && new_result.pcg.query_nodes({}).size() <= opt_config.max_num_ops) { candidates.push(new_result); } diff --git a/lib/compiler/test/test_cost_estimator.h b/lib/compiler/test/test_cost_estimator.h new file mode 100644 index 0000000000..9a4ea56156 --- /dev/null +++ b/lib/compiler/test/test_cost_estimator.h @@ -0,0 +1,23 @@ +#ifndef _FLEXFLOW_TEST_COST_ESTIMATOR_H +#define _FLEXFLOW_TEST_COST_ESTIMATOR_H + +#include "compiler/cost_estimate.h" + +namespace FlexFlow { + +struct TestCostEstimator : public ICostEstimator { + float estimate_cost(PCGOperatorAttrs const &op, + std::vector const &inputs, + MachineView const &mv) const override { + return 0.1; + } + float estimate_cost(ParallelTensorShape const &tensor_shape, + MachineView const &src, + MachineView const &dst) const override { + return 0.1; + } +}; + +} // namespace FlexFlow + +#endif diff --git a/lib/compiler/test/test_generator.h b/lib/compiler/test/test_generator.h new file mode 100644 index 0000000000..b3453b014c --- /dev/null +++ b/lib/compiler/test/test_generator.h @@ -0,0 +1,161 @@ +#ifndef _FLEXFLOW_TEST_GENERATOR_H +#define _FLEXFLOW_TEST_GENERATOR_H + +#include "compiler/machine_mapping.h" +#include "compiler/sub_parallel_computation_graph.h" +#include "rapidcheck.h" + +using namespace FlexFlow; + +/* + Generates computation graphs with trivial layers and tensors, which are used + for tests focusing on graph structures. +*/ +ComputationGraph test_computataion_graph(MultiDiGraphView const &g) { + return materialize_output_labelled_multidigraph_view( + ViewMultiDiGraphAsOutputLabelled( + g, + [](Layer(Node const &)) { return Layer(NoopAttrs{}); }, + [](Tensor(MultiDiOutput const &)) { + return Tensor{0, DataType::FLOAT, nullopt, false, nullopt}; + })); +} + +/* + Generates parallel computation graphs with trivial layers and tensors, which + are used for tests focusing on graph structures. +*/ +ParallelComputationGraph + test_parallel_computation_graph(MultiDiGraphView const &g) { + return materialize_output_labelled_multidigraph_view( + ViewMultiDiGraphAsOutputLabelled( + g, + [](Operator(Node const &)) { return ParallelTensor(NoopAttrs{}); }, + [](Operator(MultiDiOutput const &)) { + return ParallelTensor(ParallelTensorDims(TensorDims({})), + DataType::FLOAT); + })); +} + +rc::Gen small_integer_generator() { + return gen::inRange(1, 4); +} + +namespace rc { + +Gen serialParallelMultiDiGraph() { + return gen::map(gen::arbitrary(), + multidigraph_from_sp_decomposition); +} + +template <> +struct Arbitrary { + static Gen arbitrary() { + return gen::map(serialParallelMultiDiGraph, test_computataion_graph); + } +}; + +template <> +struct Arbitrary { + static Gen arbitrary() { + return gen::map(serialParallelMultiDiGraph, + test_parallel_computataion_graph); + } +}; + +template <> +struct Arbitrary> { + static Gen> arbitrary() { + return gen::mapcat(gen::arbitrary(), [](bool is_node) { + return is_node ? gen::arbitrary() : gen::arbitrary(); + }); + } +}; + +template <> +struct Arbitrary> { + static Gen> arbitrary() { + return gen::mapcat(gen::arbitrary(), [](bool is_node) { + return is_node ? gen::arbitrary() : gen::arbitrary(); + }); + } +}; + +template <> +struct Arbitrary { + static Gen arbitrary() { + return gen::build( + gen::set(&Serial::children, + gen::container>>( + gen::arbitrary>()))); + } +}; + +template <> +struct Arbitrary { + static Gen arbitrary() { + return gen::build( + gen::set(&Parallel::children, + gen::container>>( + gen::arbitrary>()))); + } +}; + +template <> +struct Arbitrary { + static Gen arbitrary() { + return gen::mapcat(gen::arbitrary(), [](bool is_serial) { + return is_serial ? gen::construct( + gen::arbitrary()) + : gen::construct( + gen::arbitrary()); + }); + } +}; + +template +struct Arbitrary { + static Gen< + std::enable_if, Tag>::value>::type> + arbitrary() { + return gen::construct(gen::arbitrary()); + } +}; + +template <> +struct Arbitrary { + static Gen arbitrary() { + return gen::apply(make_1d_machine_view, + gen::arbitrary, + gen::arbitrary, + small_integer_generator()); + } +} + +template <> +struct Arbitrary { + static Gen arbitrary() { + return gen::build( + gen::set(&MachineMapping::machine_views, + gen::container>( + gen::arbitrary(), gen::arbitrary()))); + } +} + +template <> +struct Arbitrary { + static Gen arbitrary() { + return gen::build( + gen::set(&MachineSpecification::num_nodes, gen::inRange(1, 64)), + gen::set(&MachineSpecification::num_cpus_per_node, gen::inRange(1, 64)), + gen::set(&MachineSpecification::num_gpus_per_node, gen::inRange(1, 16)), + gen::set(&MachineSpecification::inter_node_bandwidth, + gen::nonZero()), + gen::set(&MachineSpecification::intra_node_bandwidth, + gen::nonZero())); + } +} + +} // namespace rc + +#endif diff --git a/lib/compiler/test/test_machine_mapping.cc b/lib/compiler/test/test_machine_mapping.cc new file mode 100644 index 0000000000..4436a992d3 --- /dev/null +++ b/lib/compiler/test/test_machine_mapping.cc @@ -0,0 +1,21 @@ +#include "doctest.h" +#include "test_generator.h" + +TEST_CASE("MachineMapping::combine") { + rc::check([](MachineMapping const &m0, MachineMapping const &m1) { + RC_PRE(MachineMapping::nodes_are_disjoint(m0, m1)); + + MachineMapping comb = MachineMapping::combine(m0, m1); + + RC_ASSERT(comb.machine_views.size() == + m0.machine_views.size() + m1.machine_views.size()); + RC_ASSERT(is_submap(comb.machine_views, m0.machine_views)); + RC_ASSERT(is_submap(comb.machine_views, m1.machine_views)); + }); +} + +TEST_CASE("OptimalCostResult::infinity") { + rc::check([](OptimalCostResult const &c) { + RC_ASSERT(c.runtime <= OptimalCostResult::infinity().runtime); + }); +} diff --git a/lib/compiler/test/test_optimal_cost.cc b/lib/compiler/test/test_optimal_cost.cc new file mode 100644 index 0000000000..2d9414ba27 --- /dev/null +++ b/lib/compiler/test/test_optimal_cost.cc @@ -0,0 +1,24 @@ +#include "test_cost_estimator.h" +#include "test_generator.h" + +/* +Tests whether optimal_cost can give a valid result given random PCG, trivial +allowed machine views, trivial cost estimator and random machine specification. +*/ +TEST_CASE("optimal_cost") { + auto test_allowed_machine_views = [](Operator const &, + MachineSpecification const &) { + return std::unordered_set{make_1d_machine_view(0, 1, 1)}; + }; + rc::check([](ParallelComputationGraph const &g, + MachineSpecification const &machine_spec) { + OptimalCostCache cached_subgraph_costs; + OptimalCostResult result = optimal_cost(g, + test_allowed_machine_views, + TestCostEstimator{}, + machine_spec, + cached_subgraph_costs); + RC_ASSERT(result.runtime > 0); + RC_ASSERT(keys(result.machine_mapping.machine_views) == get_nodes(g)); + }); +} diff --git a/lib/compiler/test/test_unity_algorithm.cc b/lib/compiler/test/test_unity_algorithm.cc new file mode 100644 index 0000000000..6a0131dd77 --- /dev/null +++ b/lib/compiler/test/test_unity_algorithm.cc @@ -0,0 +1,23 @@ +#include "compiler/unity_algorithm.h" +#include "test_cost_estimator.h" +#include "test_generator.h" + +TEST_CASE("graph_optimize") { + rc::check([](ComputationGraph const &g, + float alpha, + int budget, + float threshold, + int max_num_ops) { + Strategy s = graph_optimize( + g, + TestCostEstimator{}, + MachineSpecification{1, 1, 4, 0.1, 0.2}, + [](Operator const &, MachineSpecification const &) { + return std::unordered_set{make_1d_machine_view(0, 1, 1)}; + }, + OptimizerConfig{alpha, budget, threshold, max_num_ops}); + RC_ASSERT(get_nodes(s.pcg).size() > 0); + RC_ASSERT(s.machine_mapping.runtime > 0); + RC_ASSERT(keys(s.machine_mapping.machine_views) == get_nodes(s.pcg)); + }); +} diff --git a/lib/utils/include/utils/bidict.h b/lib/utils/include/utils/bidict.h index 870afd0448..797d3c1758 100644 --- a/lib/utils/include/utils/bidict.h +++ b/lib/utils/include/utils/bidict.h @@ -1,6 +1,7 @@ #ifndef _FLEXFLOW_UTILS_BIDICT_H #define _FLEXFLOW_UTILS_BIDICT_H +#include #include namespace FlexFlow { diff --git a/lib/utils/include/utils/containers.h b/lib/utils/include/utils/containers.h index a0939f7038..dcbaced8fb 100644 --- a/lib/utils/include/utils/containers.h +++ b/lib/utils/include/utils/containers.h @@ -205,6 +205,12 @@ std::unordered_map filter_values(std::unordered_map const &m, return result; } +template +bool is_submap(std::unordered_map const &m, + std::unordered_map const &sub) { + return restrict_keys(m, keys(sub)) == sub; +} + template std::unordered_set keys(C const &c) { std::unordered_set result; diff --git a/lib/utils/include/utils/graph/cow_ptr_t.h b/lib/utils/include/utils/graph/cow_ptr_t.h index ee98dfb9e4..7b888cf6e9 100644 --- a/lib/utils/include/utils/graph/cow_ptr_t.h +++ b/lib/utils/include/utils/graph/cow_ptr_t.h @@ -10,8 +10,10 @@ namespace FlexFlow { template struct cow_ptr_t { - static_assert(is_clonable::value, - "cow_ptr_t requires the type to have a clone() method"); + // static_assert(is_clonable::value, + // "cow_ptr_t requires the type to have a clone() method"); // + // TODO: + // https://github.com/flexflow/FlexFlow/issues/909#issue-1833470024 cow_ptr_t(std::shared_ptr ptr) : ptr(std::move(ptr)) {} cow_ptr_t(std::unique_ptr ptr) : ptr(std::move(ptr)) {} diff --git a/lib/utils/include/utils/graph/labelled/views.h b/lib/utils/include/utils/graph/labelled/views.h index b0108c6e5d..85b5d3ef5c 100644 --- a/lib/utils/include/utils/graph/labelled/views.h +++ b/lib/utils/include/utils/graph/labelled/views.h @@ -21,6 +21,61 @@ struct LabelledMultiDiSubgraphView std::unordered_set const &); }; +template +struct ViewMultiDiGraphAsOutputLabelled + : public IOutputLabelledMultiDiGraphView { +public: + ViewMultiDiGraphAsOutputLabelled() = delete; + explicit ViewMultiDiGraphAsOutputLabelled( + MultiDiGraphView const &g, + std::function const &node_label, + std::function const &output_label) + : g(g), node_label(node_label), output_label(output_label) {} + + virtual std::unordered_set + query_nodes(NodeQuery const &q) const override { + return g.query_nodes(q); + } + + virtual std::unordered_set + query_edges(MultiDiEdgeQuery const &q) const override { + return g.query_edges(q); + } + + virtual NodeLabel const &at(Node const &n) const override { + return node_label(n); + } + + virtual OutputLabel &at(MultiDiOutput const &o) override { + return output_label(o); + } + +private: + MultiDiGraphView g; + std::function node_label; + std::function output_label; +}; + +CHECK_NOT_ABSTRACT(ViewMultiDiGraphAsOutputLabelled); + +template +Impl materialize_output_labelled_multidigraph_view( + IOutputLabelledMultiDiGraphView const &g) { + Impl result; + for (Node const &n : get_nodes(g)) { + result.add_node_unsafe(n); + result.at(n) = g.at(n); + } + for (auto const &e : get_edges(g)) { + result.add_edge(e); + } + for (MultiDiOutput const &o : get_outputs(g)) { + result.add_output(o, g.at(o)); + } + return result; +} + } // namespace FlexFlow #endif diff --git a/lib/utils/include/utils/graph/serialparallel.h b/lib/utils/include/utils/graph/serialparallel.h index e7d8bd2ca9..d68cebac7a 100644 --- a/lib/utils/include/utils/graph/serialparallel.h +++ b/lib/utils/include/utils/graph/serialparallel.h @@ -35,6 +35,20 @@ SerialParallelDecomposition std::unordered_set get_nodes(SerialParallelDecomposition const &sp); +std::unordered_map parallel_extend(MultiDiGraph &g, + MultiDiGraph const &ext); + +std::unordered_map serial_extend(MultiDiGraph &g, + MultiDiGraph const &ext); + +MultiDiGraph serial_composition(MultiDiGraph const &g1, MultiDiGraph const &g2); + +MultiDiGraph parallel_composition(MultiDiGraph const &g1, + MultiDiGraph const &g2); + +MultiDiGraph multidigraph_from_sp_decomposition( + SerialParallelDecomposition const &sp_decomposition); + } // namespace FlexFlow #endif diff --git a/lib/utils/include/utils/test_types.h b/lib/utils/include/utils/test_types.h index 514d030b6f..6002d763a6 100644 --- a/lib/utils/include/utils/test_types.h +++ b/lib/utils/include/utils/test_types.h @@ -79,11 +79,7 @@ using hash_cmp = test_type_t; namespace std { -template < - ::FlexFlow::test_types:: - capability... CAPABILITIES> //, typename = typename - // std::enable_if<::FlexFlow::test_types::has_capability<::FlexFlow::test_types::HASHABLE>::value, - // bool>::type> +template <::FlexFlow::test_types::capability... CAPABILITIES> struct hash<::FlexFlow::test_types::test_type_t> { typename std::enable_if< ::FlexFlow::test_types::has_capability<::FlexFlow::test_types::HASHABLE, diff --git a/lib/utils/src/graph/serialparallel.cc b/lib/utils/src/graph/serialparallel.cc index 8a034ad809..5484171a20 100644 --- a/lib/utils/src/graph/serialparallel.cc +++ b/lib/utils/src/graph/serialparallel.cc @@ -1,6 +1,7 @@ #include "utils/graph/serialparallel.h" #include "serialparallel_internal.h" #include "utils/containers.h" +#include "utils/graph/adjacency_multidigraph.h" #include "utils/graph/algorithms.h" #include "utils/graph/digraph.h" @@ -217,4 +218,94 @@ std::unordered_set get_nodes(Node const &node) { return {node}; } +std::unordered_map parallel_extend(MultiDiGraph &g, + MultiDiGraph const &ext) { + std::unordered_map node_map; + std::unordered_map node_port_map; + for (Node const &node : get_nodes(MultiDiGraphView(ext))) { + node_map.emplace(node, g.add_node()); + } + for (NodePort const &node_port : get_node_ports(ext)) { + node_port_map.emplace(node_port, g.add_node_port()); + } + for (MultiDiEdge const &edge : get_edges(ext)) { + g.add_edge(MultiDiEdge{node_map.at(edge.src), + node_map.at(edge.dst), + node_port_map.at(edge.srcIdx), + node_port_map.at(edge.dstIdx)}); + } + return node_map; +} + +std::unordered_map serial_extend(MultiDiGraph &g, + MultiDiGraph const &ext) { + std::unordered_set original_sinks = get_sinks(g); + std::unordered_map node_map = parallel_extend(g, ext); + for (Node const &node1 : original_sinks) { + for (Node const &node2 : get_sources(ext)) { + g.add_edge(MultiDiEdge{ + node1, node_map.at(node2), g.add_node_port(), g.add_node_port()}); + } + } + return node_map; +} + +MultiDiGraph serial_composition(MultiDiGraph const &g1, + MultiDiGraph const &g2) { + MultiDiGraph g = g1; + serial_extend(g, g2); + return g; +} + +MultiDiGraph parallel_composition(MultiDiGraph const &g1, + MultiDiGraph const &g2) { + MultiDiGraph g = g1; + parallel_extend(g, g2); + return g; +} + +struct MultiDiGraphFromSPDecompositionFunctor { + template + MultiDiGraph operator()(T const &t) { + return multidigraph_from_sp_decomposition(t); + } +}; + +MultiDiGraph multidigraph_from_sp_decomposition( + SerialParallelDecomposition const &sp_decomposition) { + return visit(MultiDiGraphFromSPDecompositionFunctor{}, sp_decomposition); +} + +MultiDiGraph multidigraph_from_sp_decomposition( + variant const &sp_decomposition) { + return visit(MultiDiGraphFromSPDecompositionFunctor{}, sp_decomposition); +} + +MultiDiGraph multidigraph_from_sp_decomposition( + variant const &sp_decomposition) { + return visit(MultiDiGraphFromSPDecompositionFunctor{}, sp_decomposition); +} + +MultiDiGraph multidigraph_from_sp_decomposition(Serial const &serial) { + MultiDiGraph g = MultiDiGraph::create(); + for (variant const &child : serial.children) { + serial_extend(g, multidigraph_from_sp_decomposition(child)); + } + return g; +} + +MultiDiGraph multidigraph_from_sp_decomposition(Parallel const ¶llel) { + MultiDiGraph g = MultiDiGraph::create(); + for (variant const &child : parallel.children) { + parallel_extend(g, multidigraph_from_sp_decomposition(child)); + } + return g; +} + +MultiDiGraph multidigraph_from_sp_decomposition(Node const &Node) { + MultiDiGraph g = MultiDiGraph::create(); + g.add_node(); + return g; +} + } // namespace FlexFlow