Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 49 additions & 17 deletions lib/compiler/include/compiler/machine_mapping.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,33 +5,65 @@
#include "pcg/machine_specification.h"
#include "pcg/machine_view.h"
#include "pcg/parallel_computation_graph.h"
#include "sub_parallel_computation_graph.h"

namespace FlexFlow {

// Holds a map from graph Node to MachineView (machine_views), alongside a
// float runtime field.
// NOTE(review): this listing appears to interleave the removed and added lines
// of a diff (two FF_VISITABLE_STRUCT registrations with different field lists
// follow the struct) — confirm which members belong to the final definition.
struct MachineMapping {
// Static factories that build a MachineMapping from two others.
// Presumably used for serially / parallelly composed subgraphs — TODO confirm
// against the definitions.
static MachineMapping sequential_combine(MachineMapping const &s1,
MachineMapping const &s2);
static MachineMapping parallel_combine(MachineMapping const &s1,
MachineMapping const &s2);
// Presumably a sentinel "worst possible" mapping — TODO confirm.
static MachineMapping infinity();
static MachineMapping combine(MachineMapping const &, MachineMapping const &);
// Presumably true when m1 and m2 share no Node keys — verify against the
// definition.
static bool nodes_are_disjoint(MachineMapping const &m1,
MachineMapping const &m2);

float runtime;
// Per-node machine view assignment.
req<std::unordered_map<Node, MachineView>> machine_views;
};
// NOTE(review): two registrations of the same struct; one lists runtime and
// machine_views, the other only machine_views — confirm which one is intended.
FF_VISITABLE_STRUCT(MachineMapping, runtime, machine_views);
FF_VISITABLE_STRUCT(MachineMapping, machine_views);

// Bundles a serial-parallel decomposition, a machine specification and two
// optional machine views (source and sink). It is the key type of the
// unordered_map held by OptimalCostCache.
struct OptimalCostState {
SerialParallelDecomposition subgraph;
MachineSpecification resource;
// Either view may be absent (optional).
req<optional<MachineView>> source_machine_view, sink_machine_view;
};
// Registers all four fields with the visitable-struct machinery.
FF_VISITABLE_STRUCT(OptimalCostState,
subgraph,
resource,
source_machine_view,
sink_machine_view);

// Pairs a float runtime with a MachineMapping. It is the value type stored in
// OptimalCostCache and the return type of the OptimalCostCache-taking
// optimal_cost declaration.
struct OptimalCostResult {
// Static factories that build a result from two others.
// Presumably for serially / parallelly composed subgraphs — TODO confirm
// against the definitions.
static OptimalCostResult sequential_combine(OptimalCostResult const &s1,
OptimalCostResult const &s2);
static OptimalCostResult parallel_combine(OptimalCostResult const &s1,
OptimalCostResult const &s2);
// Presumably a sentinel result with unbounded runtime — TODO confirm.
static OptimalCostResult infinity();

float runtime;
MachineMapping machine_mapping;
};
FF_VISITABLE_STRUCT(OptimalCostResult, runtime, machine_mapping);

// Binary predicate over two OptimalCostResult values.
// Presumably orders results by their runtime field — confirm against the
// definition.
struct OptimalCostRuntimeCmp {
bool operator()(OptimalCostResult const &, OptimalCostResult const &);
};

// Cache keyed by OptimalCostState and holding OptimalCostResult values in an
// unordered_map.
class OptimalCostCache {
public:
OptimalCostCache() = default;

// Looks up a result for the given state; the optional return is presumably
// empty on a cache miss — TODO confirm against the definition.
optional<OptimalCostResult> load(OptimalCostState const &) const;
// Records a result for the given state.
void save(OptimalCostState const &, OptimalCostResult const &);

// NOTE(review): the next two lines open a second struct, yet only one closing
// `};` follows for the two opening braces. They look like removed diff lines
// interleaved into the class — confirm and drop them if so.
struct MachineMappingRuntimeCmp {
bool operator()(MachineMapping const &, MachineMapping const &);
private:
std::unordered_map<OptimalCostState, OptimalCostResult> cache;
};

// Declaration of optimal_cost returning a MachineMapping and taking a
// size_t-keyed map of MachineMapping as its cache argument.
// NOTE(review): this and the declaration below differ only in return type and
// in the type of the last parameter; they look like the before/after versions
// of the same function in a diff — confirm which is current.
MachineMapping optimal_cost(
ParallelComputationGraph const &g,
std::function<std::unordered_set<MachineView>(
Operator const &, MachineSpecification const &)> const
&allowed_machine_views,
CostEstimator const &cost_estimator,
MachineSpecification const &resources,
std::unordered_map<size_t, MachineMapping> &cached_subgraph_costs);
// Declaration of optimal_cost returning an OptimalCostResult.
//   g                      the parallel computation graph
//   allowed_machine_views  callback yielding a set of MachineViews for an
//                          Operator and a MachineSpecification
//   cost_estimator         cost estimator passed by const reference
//   resources              machine specification passed by const reference
//   cached_subgraph_costs  cache passed by non-const reference (presumably
//                          read and updated during the search — TODO confirm)
OptimalCostResult
optimal_cost(ParallelComputationGraph const &g,
std::function<std::unordered_set<MachineView>(
Operator const &, MachineSpecification const &)> const
&allowed_machine_views,
CostEstimator const &cost_estimator,
MachineSpecification const &resources,
OptimalCostCache &cached_subgraph_costs);

} // namespace FlexFlow

Expand Down
3 changes: 2 additions & 1 deletion lib/compiler/include/compiler/unity_algorithm.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,9 @@ struct Substitution {};
// Groups a ParallelComputationGraph, a MachineMapping and a float runtime.
struct Strategy {
ParallelComputationGraph pcg;
MachineMapping machine_mapping;
req<float> runtime;
};
// NOTE(review): two registrations of the same struct; one omits runtime. They
// look like the before/after lines of a diff — confirm which one is intended.
FF_VISITABLE_STRUCT(Strategy, pcg, machine_mapping);
FF_VISITABLE_STRUCT(Strategy, pcg, machine_mapping, runtime);

struct StrategyRuntimeCmp {
bool operator()(Strategy const &, Strategy const &);
Expand Down
11 changes: 10 additions & 1 deletion lib/compiler/src/graph_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,16 @@ SubParallelComputationGraph pcg_to_subpcg(ParallelComputationGraph const &g);

/**
 * @brief In-place running minimum: overwrite t with v when v is smaller.
 *
 * Only `operator<` is required of T. The assignment happens solely when
 * `v < t`, so t is left untouched (no copy-assignment) when v is not an
 * improvement.
 *
 * @param t Current best value; updated in place.
 * @param v Candidate value.
 */
template <typename T>
void minimize(T &t, T const &v) {
  if (v < t) {
    t = v;
  }
}

/**
 * @brief In-place running minimum under a caller-supplied ordering.
 *
 * Invokes `comp(v, t)` exactly once; when it returns true the candidate v
 * replaces t, otherwise t is left unchanged.
 *
 * @param t    Current best value; updated in place.
 * @param v    Candidate value.
 * @param comp Binary predicate returning true when its first argument should
 *             be preferred over its second.
 */
template <typename T, typename Compare>
void minimize(T &t, T const &v, Compare comp) {
  bool const candidate_preferred = comp(v, t);
  if (!candidate_preferred) {
    return;
  }
  t = v;
}

} // namespace FlexFlow
Expand Down
Loading