diff --git a/.github/workflows/helpers/build_libs.sh b/.github/workflows/helpers/build_target.sh similarity index 100% rename from .github/workflows/helpers/build_libs.sh rename to .github/workflows/helpers/build_target.sh diff --git a/.github/workflows/helpers/test_libs.sh b/.github/workflows/helpers/test_target.sh similarity index 100% rename from .github/workflows/helpers/test_libs.sh rename to .github/workflows/helpers/test_target.sh diff --git a/.github/workflows/per-lib-check.yml b/.github/workflows/per-lib-check.yml index 639f4d82b5..a5ac6fd29f 100644 --- a/.github/workflows/per-lib-check.yml +++ b/.github/workflows/per-lib-check.yml @@ -62,71 +62,79 @@ jobs: - name: Build utils run: | - build_libs.sh utils + build_target.sh utils - name: Build op-attrs run: | - build_libs.sh op-attrs + build_target.sh op-attrs - name: Build pcg run: | - build_libs.sh pcg + build_target.sh pcg - name: Build kernels run: | - build_libs.sh kernels + build_target.sh kernels - name: Build substitutions run: | - build_libs.sh substitutions + build_target.sh substitutions - name: Build compiler run: | - build_libs.sh compiler + build_target.sh compiler - name: Build substitution-generator run: | - build_libs.sh substitution-generator + build_target.sh substitution-generator - name: Build local-execution run: | - build_libs.sh local-execution + build_target.sh local-execution - name: Build models run: | - build_libs.sh models + build_target.sh models + + - name: Build substitution-to-dot + run: | + build_target.sh substitution-to-dot + + - name: Build export-model-arch + run: | + build_target.sh export-model-arch - name: Test utils run: | - test_libs.sh utils + test_target.sh utils - name: Test op-attrs run: | - test_libs.sh op-attrs + test_target.sh op-attrs - name: Test pcg run: | - test_libs.sh pcg + test_target.sh pcg - name: Test substitutions run: | - test_libs.sh substitutions + test_target.sh substitutions # - name: Test compiler # run: | - # test_libs.sh compiler + # test_target.sh compiler - name: Test substitution-generator run: | - test_libs.sh substitution-generator + test_target.sh substitution-generator - name: Test local-execution run: | - test_libs.sh local-execution + test_target.sh local-execution - name: Test models run: | - test_libs.sh models + test_target.sh models - name: Generate code coverage run: | diff --git a/.proj.toml b/.proj.toml index 721d212e31..5592f184ad 100644 --- a/.proj.toml +++ b/.proj.toml @@ -13,6 +13,8 @@ build_targets = [ "substitution-generator", "local-execution", "models", + "export-model-arch", + "substitution-to-dot", ] test_targets = [ diff --git a/CMakeLists.txt b/CMakeLists.txt index a518931ac5..792126449b 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -74,6 +74,7 @@ option(FF_BUILD_UNIT_TESTS "build non-operator unit tests" OFF) option(FF_BUILD_SUBSTITUTION_TOOL "build substitution conversion tool" OFF) option(FF_BUILD_VISUALIZATION_TOOL "build substitution visualization tool" ON) option(FF_BUILD_ARG_PARSER "build command line argument parser" OFF) +option(FF_BUILD_BIN_EXPORT_MODEL_ARCH "build export-model-arch utility" ON) set(FF_CUDA_ARCH "autodetect" CACHE STRING "Target CUDA Arch") if (FF_CUDA_ARCH STREQUAL "") diff --git a/bin/CMakeLists.txt b/bin/CMakeLists.txt index fcc19b33b9..1cd7068cfd 100644 --- a/bin/CMakeLists.txt +++ b/bin/CMakeLists.txt @@ -7,9 +7,13 @@ if(FF_BUILD_SUBSTITUTION_TOOL) endif() if(FF_BUILD_VISUALIZATION_TOOL) - add_subdirectory(substitutions-to-dot) + add_subdirectory(substitution-to-dot) endif() if(FF_BUILD_ARG_PARSER) add_subdirectory(arg_parser) endif() + +if(FF_BUILD_BIN_EXPORT_MODEL_ARCH) + add_subdirectory(export-model-arch) +endif() diff --git a/bin/export-model-arch/CMakeLists.txt b/bin/export-model-arch/CMakeLists.txt new file mode 100644 index 0000000000..b931668594 --- /dev/null +++ b/bin/export-model-arch/CMakeLists.txt @@ -0,0 +1,12 @@ +ff_add_executable( + NAME + export-model-arch + SRC_PATTERNS + src/*.cc + PRIVATE_INCLUDE + include/ + DEPS + utils + models + compiler +) diff --git a/bin/export-model-arch/include/export_model_arch/json_sp_model_export.struct.toml b/bin/export-model-arch/include/export_model_arch/json_sp_model_export.struct.toml new file mode 100644 index 0000000000..efaf368bc8 --- /dev/null +++ b/bin/export-model-arch/include/export_model_arch/json_sp_model_export.struct.toml @@ -0,0 +1,27 @@ +namespace = "FlexFlow" +name = "JsonSPModelExport" +features = [ + "eq", + "hash", + "json", + "fmt", +] + +includes = [ + "pcg/file_format/v1/v1_computation_graph.dtg.h", + "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree.h", +] + +src_includes = [ + "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/hash.h", + "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/fmt.h", + "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/json.h", +] + +[[fields]] +name = "sp_decomposition" +type = "::FlexFlow::GenericBinarySPDecompositionTree" + +[[fields]] +name = "computation_graph" +type = "::FlexFlow::V1ComputationGraph" diff --git a/bin/export-model-arch/src/export_model_arch.cc b/bin/export-model-arch/src/export_model_arch.cc new file mode 100644 index 0000000000..ccc720ed14 --- /dev/null +++ b/bin/export-model-arch/src/export_model_arch.cc @@ -0,0 +1,208 @@ +#include "compiler/series_parallel/computation_graph_binary_sp_decomposition.h" +#include "compiler/series_parallel/get_computation_graph_series_parallel_decomposition.h" +#include "export_model_arch/json_sp_model_export.dtg.h" +#include "models/split_test/split_test.h" +#include "models/transformer/transformer.h" +#include "op-attrs/computation_graph_op_attrs.h" +#include "pcg/computation_graph.h" +#include "pcg/file_format/v1/v1_computation_graph.h" +#include "utils/cli/cli_get_help_message.h" +#include "utils/cli/cli_parse.h" +#include "utils/cli/cli_parse_result.h" +#include "utils/cli/cli_spec.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/transform.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/right_associative_binary_sp_tree_from_nary.h" +#include "utils/graph/series_parallel/get_series_parallel_decomposition.h" + +using namespace ::FlexFlow; + +ComputationGraph get_single_operator_computation_graph() { + ComputationGraphBuilder b; + + size_t batch_size = 8; + size_t in_channels = 16; + size_t out_channels = 12; + TensorShape input_shape = TensorShape{ + TensorDims{FFOrdered{ + batch_size, + in_channels, + out_channels, + }}, + DataType::FLOAT, + }; + + tensor_guid_t input = b.create_input(input_shape, CreateGrad::YES); + + InitializerAttrs kernel_initializer = + InitializerAttrs{GlorotUniformAttrs{/*seed=*/12}}; + InitializerAttrs bias_initializer = InitializerAttrs{ZeroInitializerAttrs{}}; + tensor_guid_t output = b.dense(input, + in_channels, + Activation::RELU, + /*use_bias=*/true, + DataType::FLOAT, + kernel_initializer, + bias_initializer, + "my_example_operator"); + + return b.computation_graph; +} + +ComputationGraph get_default_transformer_computation_graph() { + TransformerConfig config = get_default_transformer_config(); + ComputationGraph cg = get_transformer_computation_graph(config); + + return cg; +} + +tl::expected + get_model_computation_graph(std::string const &model_name) { + if (model_name == "transformer") { + return get_default_transformer_computation_graph(); + } else if (model_name == "split_test") { + int batch_size = 8; + return get_split_test_computation_graph(batch_size); + } else if (model_name == "single_operator") { + return get_single_operator_computation_graph(); + } else { + return tl::unexpected(fmt::format("Unknown model name: {}", model_name)); + } +} + +tl::expected + get_sp_model_export(std::string const &model_name) { + ComputationGraph computation_graph = ({ + tl::expected result = + get_model_computation_graph(model_name); + if (!result.has_value()) { + return tl::unexpected(result.error()); + } + result.value(); + }); + + ComputationGraphBinarySPDecomposition sp_decomposition = ({ + std::optional result = + get_computation_graph_right_assoc_binary_sp_decomposition( + computation_graph); + if (!result.has_value()) { + return tl::unexpected("Failed to generate series-parallel decomposition " + "of computation graph."); + } + result.value(); + }); + + std::pair> v1_result = + to_v1_including_node_numbering(computation_graph); + V1ComputationGraph v1_cg = v1_result.first; + bidict layer_numbering = v1_result.second; + GenericBinarySPDecompositionTree v1_sp_decomposition = + transform(sp_decomposition.raw_tree, + [&](layer_guid_t const &l) { return layer_numbering.at_r(l); }); + + return JsonSPModelExport{ + v1_sp_decomposition, + v1_cg, + }; +} + +int main(int argc, char **argv) { + CLISpec cli = empty_cli_spec(); + + CLIArgumentKey arg_key_help = cli_add_help_flag(cli); + + CLIArgumentKey key_sp_decomposition = + cli_add_flag(cli, + CLIFlagSpec{"sp-decomposition", + std::nullopt, + "also output a series parallel decomposition of " + "the model's computation graph"}); + + CLIArgumentKey key_dot = cli_add_flag( + cli, + CLIFlagSpec{ + "dot", + std::nullopt, + "output a dot representation of the model's computation graph"}); + + CLIArgumentKey key_preprocessed_dot = cli_add_flag( + cli, + CLIFlagSpec{"preprocessed-dot", + std::nullopt, + "output a dot representation of model's computation graph " + "for preprocessed to help check series-parallel structure"}); + + std::vector model_options = { + "transformer", "split_test", "single_operator"}; + CLIArgumentKey key_model_name = cli_add_positional_argument( + cli, + CLIPositionalArgumentSpec{ + "model", model_options, "name of the model to export"}); + + assert(argc >= 1); + std::string prog_name = argv[0]; + + CLIParseResult parsed = ({ + tl::expected result = + cli_parse(cli, argc, argv); + if (!result.has_value()) { + std::string error_msg = result.error(); + std::cerr << cli_get_help_message(prog_name, cli); + std::cerr << std::endl; + std::cerr << "error: " << error_msg << std::endl; + return 1; + } + + result.value(); + }); + + bool help = cli_get_flag(parsed, arg_key_help); + if (help) { + std::cerr << cli_get_help_message(prog_name, cli); + return 1; + } + + std::string model_name = cli_get_argument(parsed, key_model_name); + bool sp_decompositition = cli_get_flag(parsed, key_sp_decomposition); + bool dot = cli_get_flag(parsed, key_dot); + bool preprocessed_dot = cli_get_flag(parsed, key_preprocessed_dot); + + auto handle_error = [](auto const &result) { + if (!result.has_value()) { + std::cerr << "error: " << result.error() << std::endl; + exit(1); + } + + return result.value(); + }; + + if (dot) { + ComputationGraph cg = handle_error(get_model_computation_graph(model_name)); + + std::cout << as_dot(cg) << std::endl; + return 0; + } + + if (preprocessed_dot) { + ComputationGraph cg = handle_error(get_model_computation_graph(model_name)); + std::string rendered = + render_preprocessed_computation_graph_for_sp_decomposition(cg); + + std::cout << rendered << std::endl; + return 0; + } + + nlohmann::json json_output; + if (sp_decompositition) { + JsonSPModelExport model_export = + handle_error(get_sp_model_export(model_name)); + + json_output = model_export; + } else { + ComputationGraph cg = handle_error(get_model_computation_graph(model_name)); + + json_output = to_v1(cg); + } + std::cout << json_output.dump(2) << std::endl; + + return 0; +} diff --git a/bin/substitutions-to-dot/CMakeLists.txt b/bin/substitution-to-dot/CMakeLists.txt similarity index 100% rename from bin/substitutions-to-dot/CMakeLists.txt rename to bin/substitution-to-dot/CMakeLists.txt diff --git a/bin/substitutions-to-dot/substitution_to_dot.cc b/bin/substitution-to-dot/substitution_to_dot.cc similarity index 89% rename from bin/substitutions-to-dot/substitution_to_dot.cc rename to bin/substitution-to-dot/substitution_to_dot.cc index 49a199ddd3..1b5f715bcd 100644 --- a/bin/substitutions-to-dot/substitution_to_dot.cc +++ b/bin/substitution-to-dot/substitution_to_dot.cc @@ -1,4 +1,4 @@ -#include "substitution-generator/json.h" +#include "substitution-generator/legacy_rules.h" #include "utils/dot_file.h" #include #include @@ -24,10 +24,11 @@ int main(int argc, char **argv) { std::string json_path(argv[1]); std::string rule_name(argv[2]); - RuleCollection rule_collection = load_rule_collection_from_path(json_path); + LegacyRuleCollection rule_collection = + load_rule_collection_from_path(json_path); - std::optional found = std::nullopt; - for (Rule const &r : rule_collection.rules) { + std::optional found = std::nullopt; + for (LegacyRule const &r : rule_collection.rules) { if (r.name == rule_name) { found = r; break; @@ -39,7 +40,7 @@ int main(int argc, char **argv) { return 1; } - Rule r = found.value(); + LegacyRule r = found.value(); using Node = std::tuple; @@ -82,14 +83,14 @@ int main(int argc, char **argv) { }; for (int i = 0; i < r.srcOp.size(); i++) { - Operator const &o = r.srcOp[i]; + LegacyOperator const &o = r.srcOp[i]; Node srcOpNode = {NodeType::SRC, i, 0}; { dot.add_node(srcOpNode, label_map(fmt::to_string(o.op_type), srcOpNode)); dot.add_node_to_subgraph(srcOpNode, src_body_subgraph); } - for (Tensor const &t : o.input) { + for (LegacyTensor const &t : o.input) { if (t.opId < 0) { assert(t.tsId == 0); Node inputOpNode = {NodeType::SRC_INPUT_TENSOR, t.opId, 0}; @@ -106,14 +107,14 @@ int main(int argc, char **argv) { } } for (int j = 0; j < r.dstOp.size(); j++) { - Operator const &o = r.dstOp[j]; + LegacyOperator const &o = r.dstOp[j]; Node dstOpNode = {NodeType::DST, j, 0}; { dot.add_node(dstOpNode, label_map(fmt::to_string(o.op_type), dstOpNode)); dot.add_node_to_subgraph(dstOpNode, dst_body_subgraph); } - for (Tensor const &t : o.input) { + for (LegacyTensor const &t : o.input) { if (t.opId < 0) { assert(t.tsId == 0); Node inputOpNode = {NodeType::DST_INPUT_TENSOR, t.opId, 0}; @@ -128,7 +129,7 @@ int main(int argc, char **argv) { } } } - for (MapOutput const &mo : r.mappedOutput) { + for (LegacyMapOutput const &mo : r.mappedOutput) { Node srcOutputNode = {NodeType::SRC_OUTPUT_TENSOR, mo.srcOpId, mo.srcTsId}; Node dstOutputNode = {NodeType::DST_OUTPUT_TENSOR, mo.dstOpId, mo.dstTsId}; { diff --git a/cmake/flexflow-utils.cmake b/cmake/flexflow-utils.cmake index 1dbd16bdb1..90e100bb1b 100644 --- a/cmake/flexflow-utils.cmake +++ b/cmake/flexflow-utils.cmake @@ -149,6 +149,11 @@ function(ff_add_executable) ${FF_EXEC_NAME} ${SRC}) + target_include_directories( + ${FF_EXEC_NAME} + PRIVATE + ${FF_EXEC_PRIVATE_INCLUDE}) + target_link_libraries( ${FF_EXEC_NAME} ${FF_EXEC_DEPS}) diff --git a/lib/compiler/include/compiler/graph_utils.h b/lib/compiler/include/compiler/graph_utils.h index 1370357837..75fd369434 100644 --- a/lib/compiler/include/compiler/graph_utils.h +++ b/lib/compiler/include/compiler/graph_utils.h @@ -5,12 +5,12 @@ #include "pcg/computation_graph.dtg.h" #include "pcg/parallel_computation_graph/parallel_computation_graph.dtg.h" #include "substitutions/sub_parallel_computation_graph.dtg.h" -#include "utils/graph/serial_parallel/serial_parallel_decomposition.dtg.h" +#include "utils/graph/series_parallel/series_parallel_decomposition.dtg.h" namespace FlexFlow { -SerialParallelDecomposition - get_serial_parallel_decomposition(ParallelComputationGraph const &pcg); +SeriesParallelDecomposition + get_series_parallel_decomposition(ParallelComputationGraph const &pcg); ParallelComputationGraph cg_to_pcg(ComputationGraph const &g); SubParallelComputationGraph pcg_to_subpcg(ParallelComputationGraph const &g); diff --git a/lib/compiler/include/compiler/machine_mapping.h b/lib/compiler/include/compiler/machine_mapping.h index 5d17cbb373..3774f2cd52 100644 --- a/lib/compiler/include/compiler/machine_mapping.h +++ b/lib/compiler/include/compiler/machine_mapping.h @@ -9,7 +9,8 @@ #include "pcg/machine_view.h" #include "pcg/parallel_computation_graph/parallel_computation_graph.h" #include "substitutions/sub_parallel_computation_graph.h" -#include "utils/graph/serial_parallel/serial_parallel_decomposition.dtg.h" +#include "utils/graph/series_parallel/series_parallel_decomposition.dtg.h" +#include "utils/visitable.h" namespace FlexFlow { diff --git a/lib/compiler/include/compiler/optimal_cost_state.struct.toml b/lib/compiler/include/compiler/optimal_cost_state.struct.toml index 50496f661b..036647c0b1 100644 --- a/lib/compiler/include/compiler/optimal_cost_state.struct.toml +++ b/lib/compiler/include/compiler/optimal_cost_state.struct.toml @@ -10,7 +10,7 @@ features = [ ] includes = [ - "utils/graph/serial_parallel/serial_parallel_decomposition.dtg.h", + "utils/graph/series_parallel/series_parallel_decomposition.dtg.h", "pcg/machine_specification.dtg.h", "pcg/machine_view.dtg.h", "utils/graph/node/node.dtg.h", @@ -21,7 +21,7 @@ includes = [ [[fields]] name = "subgraph" -type = "::FlexFlow::SerialParallelDecomposition" +type = "::FlexFlow::SeriesParallelDecomposition" [[fields]] name = "resource" @@ -33,4 +33,4 @@ type = "std::unordered_map<::FlexFlow::Node, ::FlexFlow::MachineView>" [[fields]] name = "frontier_machine_views" -type = "std::unordered_map<::FlexFlow::OpenDataflowEdge, ::FlexFlow::MachineView>" \ No newline at end of file +type = "std::unordered_map<::FlexFlow::OpenDataflowEdge, ::FlexFlow::MachineView>" diff --git a/lib/compiler/include/compiler/series_parallel/computation_graph_binary_sp_decomposition.h b/lib/compiler/include/compiler/series_parallel/computation_graph_binary_sp_decomposition.h new file mode 100644 index 0000000000..3032e3efe9 --- /dev/null +++ b/lib/compiler/include/compiler/series_parallel/computation_graph_binary_sp_decomposition.h @@ -0,0 +1,30 @@ +#ifndef _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_SERIES_PARALLEL_COMPUTATION_GRAPH_BINARY_SP_DECOMPOSITION_H +#define _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_SERIES_PARALLEL_COMPUTATION_GRAPH_BINARY_SP_DECOMPOSITION_H + +#include "compiler/series_parallel/computation_graph_binary_sp_decomposition.dtg.h" +#include "pcg/computation_graph.dtg.h" +#include "utils/graph/series_parallel/sp_decomposition_tree_node_type.dtg.h" + +namespace FlexFlow { + +SPDecompositionTreeNodeType + get_node_type(ComputationGraphBinarySPDecomposition const &); +ComputationGraphBinarySPDecomposition + get_left_child(ComputationGraphBinarySPDecomposition const &); +ComputationGraphBinarySPDecomposition + get_right_child(ComputationGraphBinarySPDecomposition const &); +layer_guid_t require_node(ComputationGraphBinarySPDecomposition const &); +std::optional + get_computation_graph_left_assoc_binary_sp_decomposition( + ComputationGraph const &); +std::optional + get_computation_graph_right_assoc_binary_sp_decomposition( + ComputationGraph const &); +bool is_left_associative(ComputationGraphBinarySPDecomposition const &); +bool is_right_associative(ComputationGraphBinarySPDecomposition const &); +std::unordered_multiset + get_layers(ComputationGraphBinarySPDecomposition const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/compiler/include/compiler/series_parallel/computation_graph_binary_sp_decomposition.struct.toml b/lib/compiler/include/compiler/series_parallel/computation_graph_binary_sp_decomposition.struct.toml new file mode 100644 index 0000000000..147b1e3acf --- /dev/null +++ b/lib/compiler/include/compiler/series_parallel/computation_graph_binary_sp_decomposition.struct.toml @@ -0,0 +1,22 @@ +namespace = "FlexFlow" +name = "ComputationGraphBinarySPDecomposition" +features = [ + "eq", + "ord", + "hash", + "fmt", +] + +includes = [ + "pcg/layer_guid_t.dtg.h", + "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree.h", +] + +src_includes = [ + "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/hash.h", + "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/fmt.h", +] + +[[fields]] +name = "raw_tree" +type = "::FlexFlow::GenericBinarySPDecompositionTree<::FlexFlow::layer_guid_t>" diff --git a/lib/compiler/include/compiler/series_parallel/get_computation_graph_series_parallel_decomposition.h b/lib/compiler/include/compiler/series_parallel/get_computation_graph_series_parallel_decomposition.h new file mode 100644 index 0000000000..e85843ed26 --- /dev/null +++ b/lib/compiler/include/compiler/series_parallel/get_computation_graph_series_parallel_decomposition.h @@ -0,0 +1,17 @@ +#ifndef _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_SERIES_PARALLEL_GET_COMPUTATION_GRAPH_SERIES_PARALLEL_DECOMPOSITION_H +#define _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_SERIES_PARALLEL_GET_COMPUTATION_GRAPH_SERIES_PARALLEL_DECOMPOSITION_H + +#include "pcg/computation_graph.dtg.h" +#include "utils/graph/series_parallel/series_parallel_decomposition.dtg.h" + +namespace FlexFlow { + +std::string render_preprocessed_computation_graph_for_sp_decomposition( + ComputationGraph const &); +std::optional + get_computation_graph_series_parallel_decomposition( + ComputationGraph const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/compiler/src/compiler/series_parallel/computation_graph_binary_sp_decomposition.cc b/lib/compiler/src/compiler/series_parallel/computation_graph_binary_sp_decomposition.cc new file mode 100644 index 0000000000..63054385ac --- /dev/null +++ b/lib/compiler/src/compiler/series_parallel/computation_graph_binary_sp_decomposition.cc @@ -0,0 +1,90 @@ +#include "compiler/series_parallel/computation_graph_binary_sp_decomposition.h" +#include "compiler/series_parallel/get_computation_graph_series_parallel_decomposition.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_leaves.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_left_child.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_node_type.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_right_child.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_left_associative.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_right_associative.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/require.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/transform.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/left_associative_binary_sp_tree_from_nary.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/right_associative_binary_sp_tree_from_nary.h" + +namespace FlexFlow { + +SPDecompositionTreeNodeType + get_node_type(ComputationGraphBinarySPDecomposition const &d) { + return get_node_type(d.raw_tree); +} + +ComputationGraphBinarySPDecomposition + get_left_child(ComputationGraphBinarySPDecomposition const &d) { + return ComputationGraphBinarySPDecomposition{ + get_left_child(d.raw_tree), + }; +} + +ComputationGraphBinarySPDecomposition + get_right_child(ComputationGraphBinarySPDecomposition const &d) { + return ComputationGraphBinarySPDecomposition{ + get_right_child(d.raw_tree), + }; +} + +layer_guid_t require_node(ComputationGraphBinarySPDecomposition const &d) { + return require_node(d.raw_tree); +} + +std::optional + get_computation_graph_left_assoc_binary_sp_decomposition( + ComputationGraph const &cg) { + SeriesParallelDecomposition sp_decomposition = ({ + std::optional result = + get_computation_graph_series_parallel_decomposition(cg); + if (!result.has_value()) { + return std::nullopt; + } + result.value(); + }); + + BinarySPDecompositionTree raw_binary_tree = + left_associative_binary_sp_tree_from_nary(sp_decomposition); + + return ComputationGraphBinarySPDecomposition{transform( + raw_binary_tree.raw_tree, [](Node const &n) { return layer_guid_t{n}; })}; +} + +std::optional + get_computation_graph_right_assoc_binary_sp_decomposition( + ComputationGraph const &cg) { + SeriesParallelDecomposition sp_decomposition = ({ + std::optional result = + get_computation_graph_series_parallel_decomposition(cg); + if (!result.has_value()) { + return std::nullopt; + } + result.value(); + }); + + BinarySPDecompositionTree raw_binary_tree = + right_associative_binary_sp_tree_from_nary(sp_decomposition); + + return ComputationGraphBinarySPDecomposition{transform( + raw_binary_tree.raw_tree, [](Node const &n) { return layer_guid_t{n}; })}; +} + +bool is_left_associative(ComputationGraphBinarySPDecomposition const &d) { + return is_binary_sp_tree_left_associative(d.raw_tree); +} + +bool is_right_associative(ComputationGraphBinarySPDecomposition const &d) { + return is_binary_sp_tree_right_associative(d.raw_tree); +} + +std::unordered_multiset + get_layers(ComputationGraphBinarySPDecomposition const &d) { + return get_leaves(d.raw_tree); +} + +} // namespace FlexFlow diff --git a/lib/compiler/src/compiler/series_parallel/get_computation_graph_series_parallel_decomposition.cc b/lib/compiler/src/compiler/series_parallel/get_computation_graph_series_parallel_decomposition.cc new file mode 100644 index 0000000000..184ad93f4d --- /dev/null +++ b/lib/compiler/src/compiler/series_parallel/get_computation_graph_series_parallel_decomposition.cc @@ -0,0 +1,98 @@ +#include "compiler/series_parallel/get_computation_graph_series_parallel_decomposition.h" +#include "op-attrs/computation_graph_op_attrs.h" +#include "pcg/computation_graph.h" +#include "pcg/computation_graph/computation_graph_edge.h" +#include "utils/graph/digraph/algorithms/digraph_as_dot.h" +#include "utils/graph/digraph/algorithms/materialize_digraph_view.h" +#include "utils/graph/digraph/algorithms/transitive_reduction.h" +#include "utils/graph/instances/adjacency_digraph.h" +#include "utils/graph/series_parallel/get_series_parallel_decomposition.h" +#include "utils/record_formatter.h" + +namespace FlexFlow { + +std::string render_preprocessed_computation_graph_for_sp_decomposition( + ComputationGraph const &cg) { + std::unordered_set weight_and_input_layers = + filter(get_layers(cg), [&](layer_guid_t const &l) { + ComputationGraphOpAttrs op_attrs = get_layer_attrs(cg, l).attrs; + return op_attrs.has() || op_attrs.has(); + }); + + std::unordered_set weight_and_input_layer_successors = + get_subgraph_successors(cg, weight_and_input_layers); + + // dot has is incapable of rendering the number of edges in the all-to-all + // connection, so for visualization purposes we instead insert a "fake" node + // to reduce the n^2 edges to 2*n edges + DiGraph preprocessed_digraph = + materialize_digraph_view(cg.raw_graph); + Node fake_node = preprocessed_digraph.add_node(); + for (layer_guid_t const &src : weight_and_input_layers) { + preprocessed_digraph.add_edge(DirectedEdge{src.raw_node, fake_node}); + } + for (layer_guid_t const &dst : weight_and_input_layer_successors) { + preprocessed_digraph.add_edge(DirectedEdge{fake_node, dst.raw_node}); + } + + std::function get_node_label = + [&](Node const &n) -> std::string { + if (n == fake_node) { + return "FAKE"; + } + LayerAttrs a = cg.raw_graph.at(n); + RecordFormatter r = as_dot(a.attrs); + + if (a.name.has_value()) { + RecordFormatter rr; + rr << "Name" << a.name.value(); + r << rr; + } + + std::ostringstream oss; + oss << r; + return oss.str(); + }; + std::string preprocessed_dot = digraph_as_dot( + transitive_reduction(preprocessed_digraph), get_node_label); + + return preprocessed_dot; +} + +std::optional + get_computation_graph_series_parallel_decomposition( + ComputationGraph const &cg) { + + { + DiGraphView unpreprocessed_digraph = cg.raw_graph; + std::optional unpreprocessed_sp_decomposition = + get_series_parallel_decomposition(unpreprocessed_digraph); + if (unpreprocessed_sp_decomposition.has_value()) { + return unpreprocessed_sp_decomposition.value(); + } + } + + DiGraphView preprocessed_digraph = [&] { + std::unordered_set weight_and_input_layers = + filter(get_layers(cg), [&](layer_guid_t const &l) { + ComputationGraphOpAttrs op_attrs = get_layer_attrs(cg, l).attrs; + return op_attrs.has() || op_attrs.has(); + }); + + std::unordered_set weight_and_input_layer_successors = + get_subgraph_successors(cg, weight_and_input_layers); + + DiGraph digraph = materialize_digraph_view(cg.raw_graph); + for (layer_guid_t const &src : weight_and_input_layers) { + for (layer_guid_t const &dst : weight_and_input_layer_successors) { + digraph.add_edge(DirectedEdge{src.raw_node, dst.raw_node}); + } + } + + return digraph; + }(); + + return get_series_parallel_decomposition(preprocessed_digraph); +} + +} // namespace FlexFlow diff --git a/lib/compiler/src/graph_utils.cc b/lib/compiler/src/graph_utils.cc index 08db219a21..a19c5e8597 100644 --- a/lib/compiler/src/graph_utils.cc +++ b/lib/compiler/src/graph_utils.cc @@ -4,13 +4,13 @@ #include "pcg/parallel_computation_graph/parallel_computation_graph.h" #include "substitutions/sub_parallel_computation_graph.dtg.h" #include "utils/containers/without_order.h" -#include "utils/graph/serial_parallel/serial_parallel_decomposition.dtg.h" +#include "utils/graph/series_parallel/series_parallel_decomposition.dtg.h" namespace FlexFlow { -SerialParallelDecomposition - get_serial_parallel_decomposition(ParallelComputationGraph const &pcg) { +SeriesParallelDecomposition + get_series_parallel_decomposition(ParallelComputationGraph const &pcg) { NOT_IMPLEMENTED(); - // return get_serial_parallel_decomposition(pcg.raw_graph); + // return get_series_parallel_decomposition(pcg.raw_graph); } ParallelComputationGraph cg_to_pcg(ComputationGraph const &g) { @@ -126,11 +126,11 @@ SubParallelComputationGraph pcg_to_subpcg(ParallelComputationGraph const &pcg) { // } // }; -// std::unordered_set get_nodes(SerialParallelDecomposition const &sp) { +// std::unordered_set get_nodes(SeriesParallelDecomposition const &sp) { // return std::visit(GetNodes{}, sp.raw_variant); // } -// std::unordered_set get_nodes(SerialSplit const &serial) { +// std::unordered_set get_nodes(SeriesSplit const &serial) { // return set_union( // transform(serial.children, [](std::variant const // child) { @@ -140,7 +140,7 @@ SubParallelComputationGraph pcg_to_subpcg(ParallelComputationGraph const &pcg) { // std::unordered_set get_nodes(ParallelSplit const ¶llel) { // return set_union( -// transform(parallel.children, [](std::variant const +// transform(parallel.children, [](std::variant const // child) { // return std::visit(GetNodes{}, child); // })); diff --git a/lib/compiler/src/machine_mapping.cc b/lib/compiler/src/machine_mapping.cc index af7756c635..fddd825109 100644 --- a/lib/compiler/src/machine_mapping.cc +++ b/lib/compiler/src/machine_mapping.cc @@ -8,18 +8,19 @@ #include "pcg/parallel_computation_graph/parallel_computation_graph.h" #include "utils/containers.h" #include "utils/containers/are_disjoint.h" -#include "utils/containers/as_vector.h" #include "utils/containers/contains_key.h" #include "utils/containers/get_only.h" #include "utils/containers/keys.h" #include "utils/containers/merge_maps.h" +#include "utils/containers/require_no_duplicates.h" +#include "utils/containers/vector_of.h" #include "utils/exception.h" #include "utils/graph/graph_split.dtg.h" #include "utils/graph/node/algorithms.h" #include "utils/graph/open_dataflow_graph/algorithms/get_subgraph.h" -#include "utils/graph/serial_parallel/serial_parallel_decomposition.dtg.h" -#include "utils/graph/serial_parallel/serial_parallel_decomposition.h" -#include "utils/graph/serial_parallel/serial_parallel_splits.h" +#include "utils/graph/series_parallel/series_parallel_decomposition.dtg.h" +#include "utils/graph/series_parallel/series_parallel_decomposition.h" +#include "utils/graph/series_parallel/series_parallel_splits.h" namespace FlexFlow { @@ -83,39 +84,43 @@ std::vector> } // We may replace this by having unflattened AST -std::pair - decompose(SerialSplit const &serial) { +std::pair + decompose(SeriesSplit const &serial) { if (serial.children.size() == 2) { - return {widen(serial.children[0]), - widen(serial.children[1])}; + return {widen(serial.children[0]), + widen(serial.children[1])}; } - SerialSplit decompn1 = serial; + SeriesSplit decompn1 = serial; decompn1.children.pop_back(); - return {SerialParallelDecomposition(decompn1), - widen(serial.children.back())}; + return {SeriesParallelDecomposition(decompn1), + widen(serial.children.back())}; } -std::pair +std::pair decompose(ParallelSplit const ¶llel) { if (parallel.children.size() == 2) { - std::vector children = - transform(as_vector(parallel.children), [&](auto const &child) { - return widen(child); + std::vector children = + transform(vector_of(parallel.children), [&](auto const &child) { + return widen(child); }); return {children[0], children[1]}; } ParallelSplit decompn1 = parallel; - std::variant child = *parallel.children.begin(); + std::variant child = *parallel.children.begin(); decompn1.children.erase(child); - return {SerialParallelDecomposition(decompn1), - widen(child)}; + return {SeriesParallelDecomposition(decompn1), + widen(child)}; } GraphSplit - get_graph_split(SerialParallelDecomposition const &pre_decomposition, - SerialParallelDecomposition const &post_decomposition) { - return GraphSplit{get_nodes(pre_decomposition), - get_nodes(post_decomposition)}; + get_graph_split(SeriesParallelDecomposition const &pre_decomposition, + SeriesParallelDecomposition const &post_decomposition) { + std::unordered_set pre_nodes = + require_no_duplicates(get_nodes(pre_decomposition)); + std::unordered_set post_nodes = + require_no_duplicates(get_nodes(post_decomposition)); + assert(are_disjoint(pre_nodes, post_nodes)); + return GraphSplit{pre_nodes, post_nodes}; } float estimate_cost(SubParallelComputationGraph const &g, @@ -181,7 +186,7 @@ struct MachineMappingSearcher { template OptimalCostResult operator()(T const &t) { - OptimalCostState state{SerialParallelDecomposition{t}, + OptimalCostState state{SeriesParallelDecomposition{t}, resource, given_machine_views, frontier_machine_views}; @@ -202,13 +207,13 @@ struct MachineMappingSearcher { OptimalCostResult optimal_cost(SubParallelComputationGraph const &g, MachineSpecification resource, - SerialParallelDecomposition const &sp_decomposition) { + SeriesParallelDecomposition const &sp_decomposition) { return std::visit(OptimalCostFunctor(this, g, resource, {}, {}), sp_decomposition.raw_variant); } OptimalCostResult optimal_cost( - SerialSplit const &serial, + SeriesSplit const &serial, SubParallelComputationGraph const &g, MachineSpecification const &resource, std::unordered_map const &given_machine_views, @@ -218,8 +223,8 @@ struct MachineMappingSearcher { // OptimalCostResult optimal_result = OptimalCostResult::infinity(); // auto decomposed = decompose(serial); - // SerialParallelDecomposition pre_decompn = decomposed.first; - // SerialParallelDecomposition post_decompn = decomposed.second; + // SeriesParallelDecomposition pre_decompn = decomposed.first; + // SeriesParallelDecomposition post_decompn = decomposed.second; // GraphSplit graph_split = get_graph_split(pre_decompn, post_decompn); // SubParallelComputationGraph pre_graph = @@ -273,8 +278,8 @@ struct MachineMappingSearcher { NOT_IMPLEMENTED(); // auto decomposed = decompose(parallel); - // SerialParallelDecomposition decompn1 = decomposed.first; - // SerialParallelDecomposition decompn2 = decomposed.second; + // SeriesParallelDecomposition decompn1 = decomposed.first; + // SeriesParallelDecomposition decompn2 = decomposed.second; // GraphSplit graph_split = get_graph_split(decompn1, decompn2); // SubParallelComputationGraph g1 = get_subgraph(g, graph_split.first), @@ -350,8 +355,8 @@ OptimalCostResult optimal_cost( CostEstimator const &cost_estimator, MachineSpecification const &resources, OptimalCostCache &cached_subgraph_costs) { - SerialParallelDecomposition sp_decomposition = - get_serial_parallel_decomposition(g); + SeriesParallelDecomposition sp_decomposition = + get_series_parallel_decomposition(g); SubParallelComputationGraph subpcg = pcg_to_subpcg(g); MachineMappingSearcher searcher( cost_estimator, allowed_machine_views, cached_subgraph_costs); diff --git a/lib/compiler/test/CMakeLists.txt b/lib/compiler/test/CMakeLists.txt index 13b1fd3b83..3399a45f0f 100644 --- a/lib/compiler/test/CMakeLists.txt +++ b/lib/compiler/test/CMakeLists.txt @@ -10,4 +10,5 @@ ff_add_test_executable( compiler doctest utils-test-common + models ) diff --git a/lib/compiler/test/src/compiler/series_parallel/get_computation_graph_series_parallel_decomposition.cc b/lib/compiler/test/src/compiler/series_parallel/get_computation_graph_series_parallel_decomposition.cc new file mode 100644 index 0000000000..ab537e73de --- /dev/null +++ b/lib/compiler/test/src/compiler/series_parallel/get_computation_graph_series_parallel_decomposition.cc @@ -0,0 +1,340 @@ +#include "compiler/series_parallel/get_computation_graph_series_parallel_decomposition.h" +#include "models/split_test/split_test.h" +#include "models/transformer/transformer.h" +#include "pcg/computation_graph.h" +#include "pcg/computation_graph_builder.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE( + "get_computation_graph_series_parallel_decomposition(ComputationGraph)") { + SUBCASE("empty computation graph") { + ComputationGraph cg = make_empty_computation_graph(); + + std::optional result = + get_computation_graph_series_parallel_decomposition(cg); + // technically an empty graph is non-SP + std::optional correct = std::nullopt; + + CHECK(result == correct); + } + + SUBCASE("just a single input") { + std::string input_layer_name = "my input"; + ComputationGraph cg = [&] { + ComputationGraphBuilder b; + + TensorShape input_shape = TensorShape{TensorDims{FFOrdered{ + 10, + 12, + }}, + DataType::FLOAT}; + b.create_input(input_shape, CreateGrad::YES, input_layer_name); + + return b.computation_graph; + }(); + + layer_guid_t input_layer = get_layer_by_name(cg, input_layer_name); + + std::optional result = + get_computation_graph_series_parallel_decomposition(cg); + std::optional correct = + SeriesParallelDecomposition{input_layer.raw_node}; + + CHECK(result == correct); + } + + SUBCASE("single operator plus inputs and weights") { + std::string input_layer_name = "my input"; + std::string projection_weights_layer_name = "my projection weights"; + std::string bias_weights_layer_name = "my bias weights"; + std::string operator_name = "my operator"; + ComputationGraph cg = [&] { + ComputationGraphBuilder b; + + TensorShape input_shape = TensorShape{TensorDims{FFOrdered{ + 10, + 12, + }}, + DataType::FLOAT}; + tensor_guid_t input = + b.create_input(input_shape, CreateGrad::YES, input_layer_name); + + b.dense(input, + /*outDim=*/14, + /*activation=*/std::nullopt, + /*use_bias=*/true, + /*data_type=*/DataType::FLOAT, + /*projection_initializer=*/std::nullopt, + /*bias_initializer=*/std::nullopt, + /*name=*/operator_name, + /*projection_name=*/projection_weights_layer_name, + /*bias_name=*/bias_weights_layer_name); + + return b.computation_graph; + }(); + + layer_guid_t input_layer = get_layer_by_name(cg, input_layer_name); + layer_guid_t projection_weights_layer = + get_layer_by_name(cg, projection_weights_layer_name); + layer_guid_t bias_weights_layer = + get_layer_by_name(cg, bias_weights_layer_name); + layer_guid_t operator_layer = get_layer_by_name(cg, operator_name); + + std::optional result = + get_computation_graph_series_parallel_decomposition(cg); + std::optional correct = + SeriesParallelDecomposition{SeriesSplit{ + ParallelSplit{ + input_layer.raw_node, + projection_weights_layer.raw_node, + bias_weights_layer.raw_node, + }, + operator_layer.raw_node, + }}; + + CHECK(result == correct); + } + + SUBCASE("SP without weight nodes but non-SP with weight nodes") { + // A minimal computation graph where without weights (w1 and w2) the + // computation graph is series-parallel, but with weight nodes it is not + // + // w1 input w2 + // \ / \ / + // op1 op2 + + std::string w1_name = "w1"; + std::string input_name = "input"; + std::string w2_name = "w2"; + std::string op1_name = "op1"; + std::string op2_name = "op2"; + + ComputationGraph cg = [&] { + ComputationGraphBuilder b; + + TensorShape input_shape = TensorShape{ + TensorDims{FFOrdered{ + 10, + 12, + }}, + DataType::FLOAT, + }; + tensor_guid_t input = + b.create_input(input_shape, CreateGrad::YES, input_name); + + b.dense(input, + /*outDim=*/14, + /*activation=*/std::nullopt, + /*use_bias=*/false, + /*data_type=*/DataType::FLOAT, + /*projection_initializer=*/std::nullopt, + /*bias_initializer=*/std::nullopt, + /*name=*/op1_name, + /*projection_name=*/w1_name); + b.dense(input, + /*outDim=*/14, + /*activation=*/std::nullopt, + /*use_bias=*/false, + /*data_type=*/DataType::FLOAT, + /*projection_initializer=*/std::nullopt, + /*bias_initializer=*/std::nullopt, + /*name=*/op2_name, + /*projection_name=*/w2_name); + + return b.computation_graph; + }(); + + layer_guid_t w1 = get_layer_by_name(cg, w1_name); + layer_guid_t input = get_layer_by_name(cg, input_name); + layer_guid_t w2 = get_layer_by_name(cg, w2_name); + layer_guid_t op1 = get_layer_by_name(cg, op1_name); + layer_guid_t op2 = get_layer_by_name(cg, op2_name); + + std::optional result = + get_computation_graph_series_parallel_decomposition(cg); + std::optional correct = + SeriesParallelDecomposition{SeriesSplit{ + ParallelSplit{ + w1.raw_node, + input.raw_node, + w2.raw_node, + }, + ParallelSplit{ + op1.raw_node, + op2.raw_node, + }, + }}; + } + + SUBCASE("SP with or without preprocessing, but preprocessing would SP " + "decomposition") { + // computation graph: + // + // input1 input2 + // | | + // op1 op2 + + std::string input1_name = "input1"; + std::string input2_name = "input2"; + std::string op1_name = "op1"; + std::string op2_name = "op2"; + + ComputationGraph cg = [&] { + ComputationGraphBuilder b; + + TensorShape input_shape = TensorShape{ + TensorDims{FFOrdered{ + 10, + 12, + }}, + DataType::FLOAT, + }; + tensor_guid_t input1 = + b.create_input(input_shape, CreateGrad::YES, input1_name); + tensor_guid_t input2 = + b.create_input(input_shape, CreateGrad::YES, input2_name); + + b.relu(input1, op1_name); + b.relu(input2, op2_name); + + return b.computation_graph; + }(); + + layer_guid_t input1 = get_layer_by_name(cg, input1_name); + layer_guid_t input2 = get_layer_by_name(cg, input2_name); + layer_guid_t op1 = get_layer_by_name(cg, op1_name); + layer_guid_t op2 = get_layer_by_name(cg, op2_name); + + std::optional result = + get_computation_graph_series_parallel_decomposition(cg); + std::optional correct = + SeriesParallelDecomposition{ParallelSplit{ + SeriesSplit{ + input1.raw_node, + op1.raw_node, + }, + SeriesSplit{ + input2.raw_node, + op2.raw_node, + }, + }}; + } + + SUBCASE("not SP with or without weight nodes") { + // computation graph: + // + // input1 + // / \ + // op1 op2 + // | \ | + // | \ | + // op3 op4 + + std::string input1_name = "input1"; + std::string op1_name = "op1"; + std::string op2_name = "op2"; + std::string op3_name = "op3"; + std::string op4_name = "op4"; + + ComputationGraph cg = [&] { + ComputationGraphBuilder b; + + TensorShape input_shape = TensorShape{ + TensorDims{FFOrdered{ + 10, + 12, + }}, + DataType::FLOAT, + }; + tensor_guid_t input1 = + b.create_input(input_shape, CreateGrad::YES, input1_name); + + tensor_guid_t op1_output = b.relu(input1, op1_name); + tensor_guid_t op2_output = b.relu(input1, op2_name); + b.relu(op1_output, op3_name); + b.add(op1_output, op2_output, op4_name); + + return b.computation_graph; + }(); + + layer_guid_t input1 = get_layer_by_name(cg, input1_name); + layer_guid_t op1 = get_layer_by_name(cg, op1_name); + layer_guid_t op2 = get_layer_by_name(cg, op2_name); + layer_guid_t op3 = get_layer_by_name(cg, op3_name); + layer_guid_t op4 = get_layer_by_name(cg, op4_name); + + std::optional result = + get_computation_graph_series_parallel_decomposition(cg); + std::optional correct = std::nullopt; + } + + SUBCASE("real models") { + SUBCASE("split_test") { + ComputationGraph cg = + get_split_test_computation_graph(/*batch_size=*/8); + + std::optional sp_decomposition = + get_computation_graph_series_parallel_decomposition(cg); + + CHECK(sp_decomposition.has_value()); + } + + SUBCASE("transformer") { + ComputationGraph cg = + get_transformer_computation_graph(get_default_transformer_config()); + + std::optional sp_decomposition = + get_computation_graph_series_parallel_decomposition(cg); + + CHECK(sp_decomposition.has_value()); + } + } + } + + TEST_CASE("render_preprocessed_computation_graph_for_sp_decomposition(" + "ComputationGraph)") { + // currently there's not really a good way to test this, and its arguable + // how much its output really should be validated as its primarily for + // visualization and so there's not really a strict definition of + // correctness, so for now we just run it on some models and make sure it + // doesn't crash. Don't use this as an example. + + SUBCASE("basic single-operator model") { + ComputationGraph cg = [&] { + ComputationGraphBuilder b; + + TensorShape input_shape = TensorShape{TensorDims{FFOrdered{ + 10, + 12, + }}, + DataType::FLOAT}; + tensor_guid_t input = b.create_input(input_shape, CreateGrad::YES); + + b.dense(input, /*outDim=*/14); + + return b.computation_graph; + }(); + + std::string result = + render_preprocessed_computation_graph_for_sp_decomposition(cg); + } + + SUBCASE("split_test") { + ComputationGraph cg = get_split_test_computation_graph(/*batch_size=*/8); + + std::string result = + render_preprocessed_computation_graph_for_sp_decomposition(cg); + } + + SUBCASE("transformer") { + ComputationGraph cg = + get_transformer_computation_graph(get_default_transformer_config()); + + std::string result = + render_preprocessed_computation_graph_for_sp_decomposition(cg); + } + } +} diff --git a/lib/compiler/test/src/test_generator.h b/lib/compiler/test/src/test_generator.h index d6b8222968..9f5a768b27 100644 --- a/lib/compiler/test/src/test_generator.h +++ b/lib/compiler/test/src/test_generator.h @@ -46,7 +46,7 @@ using namespace FlexFlow; // namespace rc { // Gen serialParallelMultiDiGraph() { -// return gen::map(gen::arbitrary(), +// return gen::map(gen::arbitrary(), // multidigraph_from_sp_decomposition); // } @@ -113,12 +113,12 @@ using namespace FlexFlow; // }; // template <> -// struct Arbitrary { -// static Gen arbitrary() { +// struct Arbitrary { +// static Gen arbitrary() { // return gen::mapcat(gen::arbitrary(), [](bool is_serial) { -// return is_serial ? gen::construct( +// return is_serial ? gen::construct( // gen::arbitrary()) -// : gen::construct( +// : gen::construct( // gen::arbitrary()); // }); // } diff --git a/lib/kernels/include/kernels/accessor.h b/lib/kernels/include/kernels/accessor.h index d3221474c0..f523520f9f 100644 --- a/lib/kernels/include/kernels/accessor.h +++ b/lib/kernels/include/kernels/accessor.h @@ -16,7 +16,7 @@ class GenericTensorAccessorW { template typename data_type_enum_to_class
::type *get() const { if (this->data_type == DT) { - return static_cast *>(this->ptr); + return static_cast *>(this->ptr); } else { throw mk_runtime_error( "Invalid access data type ({} != {})", this->data_type, DT); @@ -47,7 +47,7 @@ class GenericTensorAccessorR { template typename data_type_enum_to_class
::type const *get() const { if (this->data_type == DT) { - return static_cast const *>(this->ptr); + return static_cast const *>(this->ptr); } else { throw mk_runtime_error( "Invalid access data type ({} != {})", this->data_type, DT); @@ -94,7 +94,7 @@ template typename data_type_enum_to_class
::type * get(GenericTensorAccessorW const &a) { if (a.data_type == DT) { - return static_cast *>(a.ptr); + return static_cast *>(a.ptr); } else { throw mk_runtime_error( "Invalid access data type ({} != {})", a.data_type, DT); @@ -102,9 +102,9 @@ typename data_type_enum_to_class
::type * } template -std::vector *> +std::vector *> get(std::vector const &accs) { - std::vector *> out; + std::vector *> out; for (auto acc : accs) { out.push_back(get
(acc)); } @@ -115,7 +115,7 @@ template typename data_type_enum_to_class
::type const * get(GenericTensorAccessorR const &a) { if (a.data_type == DT) { - return static_cast const *>(a.ptr); + return static_cast const *>(a.ptr); } else { throw mk_runtime_error( "Invalid access data type ({} != {})", a.data_type, DT); @@ -139,9 +139,9 @@ std::vector get_half_ptrs(std::vector const &); template -std::vector const *> +std::vector const *> get(std::vector const &accs) { - std::vector const *> out; + std::vector const *> out; for (auto acc : accs) { out.push_back(get
(acc)); } diff --git a/lib/kernels/include/kernels/array_shape.h b/lib/kernels/include/kernels/array_shape.h index 5de9fae7ad..96a3b3b281 100644 --- a/lib/kernels/include/kernels/array_shape.h +++ b/lib/kernels/include/kernels/array_shape.h @@ -2,7 +2,7 @@ #define _FLEXFLOW_KERNELS_ARRAY_SHAPE_H #include "legion_dim.h" -#include "op-attrs/tensor_shape.h" +#include "op-attrs/tensor_shape.dtg.h" #include "utils/stack_vector.h" #include "utils/visitable.h" #include diff --git a/lib/kernels/include/kernels/attention_kernels.h b/lib/kernels/include/kernels/attention_kernels.h index 575de57f09..eb5a1b8198 100644 --- a/lib/kernels/include/kernels/attention_kernels.h +++ b/lib/kernels/include/kernels/attention_kernels.h @@ -5,7 +5,6 @@ #include "kernels/allocation.h" #include "kernels/device.h" #include "kernels/ff_handle.h" -#include "op-attrs/ops/attention.h" #include namespace FlexFlow { diff --git a/lib/kernels/include/kernels/batch_matmul_kernels.h b/lib/kernels/include/kernels/batch_matmul_kernels.h index 38be2118fa..bfd72647b0 100644 --- a/lib/kernels/include/kernels/batch_matmul_kernels.h +++ b/lib/kernels/include/kernels/batch_matmul_kernels.h @@ -4,7 +4,6 @@ #include "device.h" #include "kernels/allocation.h" #include "kernels/ff_handle.h" -#include "utils/visitable.h" namespace FlexFlow { namespace Kernels { diff --git a/lib/kernels/include/kernels/initializer_kernels.h b/lib/kernels/include/kernels/initializer_kernels.h index 14bb9d2cd2..52609a303f 100644 --- a/lib/kernels/include/kernels/initializer_kernels.h +++ b/lib/kernels/include/kernels/initializer_kernels.h @@ -3,6 +3,7 @@ #include "accessor.h" #include "kernels/cpu.h" +#include "op-attrs/datatype_value.dtg.h" #include "utils/variant.h" namespace FlexFlow { diff --git a/lib/kernels/src/allocation.cc b/lib/kernels/src/allocation.cc index a892e14a54..ccd88580db 100644 --- a/lib/kernels/src/allocation.cc +++ b/lib/kernels/src/allocation.cc @@ -1,4 +1,5 @@ #include "kernels/allocation.h" +#include "op-attrs/tensor_shape.h" namespace FlexFlow { diff --git a/lib/kernels/src/cpu/initializer_kernels.cc b/lib/kernels/src/cpu/initializer_kernels.cc index f3b4c9b8fd..91f4f46ef8 100644 --- a/lib/kernels/src/cpu/initializer_kernels.cc +++ b/lib/kernels/src/cpu/initializer_kernels.cc @@ -24,7 +24,7 @@ struct ConstantInitKernel { void operator()(GenericTensorAccessorW const &tensor, DataTypeValue value) const { auto arr = get
(tensor); - auto unwrapped_value = get>(value); + auto unwrapped_value = value.get>(); for (size_t i = 0; i < get_volume(tensor.shape); i++) { arr[i] = unwrapped_value; } diff --git a/lib/kernels/src/cuda/embedding_kernels.cu b/lib/kernels/src/cuda/embedding_kernels.cu index 371b45f760..e6a614ba70 100644 --- a/lib/kernels/src/cuda/embedding_kernels.cu +++ b/lib/kernels/src/cuda/embedding_kernels.cu @@ -358,7 +358,7 @@ struct ForwardKernel { weight.data_type == DataType::DOUBLE); if (!aggr.has_value()) { - embed_forward_no_aggr, real_type> + embed_forward_no_aggr, real_type_t> <<, real_type> + embed_forward_with_aggr, real_type_t> <<, real_type> + embed_backward_no_aggr, real_type_t> <<, real_type> + embed_backward_with_aggr, real_type_t> <<> + add_kernel> <<>>( input_grad.get
(), output_grad.get
(), num_elements); } diff --git a/lib/kernels/src/cuda/ops/element_unary_kernels.cu b/lib/kernels/src/cuda/ops/element_unary_kernels.cu index 3eb9c486f2..a35d28fa8c 100644 --- a/lib/kernels/src/cuda/ops/element_unary_kernels.cu +++ b/lib/kernels/src/cuda/ops/element_unary_kernels.cu @@ -267,16 +267,16 @@ struct ForwardKernel { } else if (use_scalar(op_type)) { assert(scalar.has_value()); size_t num_elements = input.shape.num_elements(); - elewise_scalar_unary_forward_kernel> + elewise_scalar_unary_forward_kernel> <<>>( num_elements, - static_cast>(scalar.value()), + static_cast>(scalar.value()), op_type, input.get(), output.get()); } else { size_t num_elements = input.shape.num_elements(); - elewise_unary_forward_kernel> + elewise_unary_forward_kernel> <<>>( num_elements, op_type, input.get(), output.get()); } @@ -313,10 +313,10 @@ struct BackwardKernel { } else if (use_scalar(op_type)) { assert(scalar.has_value()); size_t num_elements = input.shape.num_elements(); - elewise_scalar_unary_backward_kernel> + elewise_scalar_unary_backward_kernel> <<>>( num_elements, - static_cast>(scalar.value()), + static_cast>(scalar.value()), op_type, output.get(), output_grad.get(), @@ -324,7 +324,7 @@ struct BackwardKernel { input_grad.get()); } else { size_t num_elements = input.shape.num_elements(); - elewise_unary_backward_kernel> + elewise_unary_backward_kernel> <<>>( num_elements, op_type, diff --git a/lib/kernels/src/cuda/ops/partition_kernels.cu b/lib/kernels/src/cuda/ops/partition_kernels.cu index e356f83d2a..1d07efb5fa 100644 --- a/lib/kernels/src/cuda/ops/partition_kernels.cu +++ b/lib/kernels/src/cuda/ops/partition_kernels.cu @@ -41,12 +41,12 @@ struct BackwardKernel { RepartitionPerDeviceState const &m, GenericTensorAccessorW const &input_grad, GenericTensorAccessorR const &output_grad) { - add_kernel><<>>(input_grad.get(), - output_grad.get(), - input_grad.shape.num_elements()); + add_kernel><<>>(input_grad.get(), + output_grad.get(), + input_grad.shape.num_elements()); } }; diff --git a/lib/kernels/src/cuda/ops/reduction_kernels.cu b/lib/kernels/src/cuda/ops/reduction_kernels.cu index 992d27fe60..0c6ba7d8e3 100644 --- a/lib/kernels/src/cuda/ops/reduction_kernels.cu +++ b/lib/kernels/src/cuda/ops/reduction_kernels.cu @@ -42,7 +42,7 @@ struct ForwardKernel { size_t num_replicas) { size_t total_elements = input.shape.num_elements() * num_replicas; - reduction_forward_kernel> + reduction_forward_kernel> <<>>( input.get(), output.get(), diff --git a/lib/kernels/src/cuda/ops/replicate_kernels.cu b/lib/kernels/src/cuda/ops/replicate_kernels.cu index 0c87418f58..76bfbe2658 100644 --- a/lib/kernels/src/cuda/ops/replicate_kernels.cu +++ b/lib/kernels/src/cuda/ops/replicate_kernels.cu @@ -54,7 +54,7 @@ struct BackwardKernel { GenericTensorAccessorR const &output, size_t num_replicas) { size_t total_elements = input.shape.num_elements() * num_replicas; - replicate_backward_kernel> + replicate_backward_kernel> <<>>( input.get(), output.get(), diff --git a/lib/kernels/src/cuda/ops/reshape_kernels.cu b/lib/kernels/src/cuda/ops/reshape_kernels.cu index c4da408952..5b7843a3a5 100644 --- a/lib/kernels/src/cuda/ops/reshape_kernels.cu +++ b/lib/kernels/src/cuda/ops/reshape_kernels.cu @@ -45,14 +45,14 @@ struct BackwardKernel { GenericTensorAccessorW const &input, GenericTensorAccessorR const &output) { float alpha = 1.0f; - apply_add_with_scale> + apply_add_with_scale> <<>>(input.get(), output.get(), input.shape.num_elements(), - static_cast>(alpha)); + static_cast>(alpha)); } }; diff --git a/lib/kernels/src/hip/ops/replicate_kernels.cpp b/lib/kernels/src/hip/ops/replicate_kernels.cpp index 9a5fc813c3..8d27bb1908 100644 --- a/lib/kernels/src/hip/ops/replicate_kernels.cpp +++ b/lib/kernels/src/hip/ops/replicate_kernels.cpp @@ -55,15 +55,16 @@ struct BackwardKernel { GenericTensorAccessorR const &output, size_t num_replicas) { size_t total_elements = input.shape.num_elements() * num_replicas; - hipLaunchKernelGGL(HIP_KERNEL_NAME(replicate_backward_kernel>), - GET_BLOCKS(total_elements), - CUDA_NUM_THREADS, - 0, - stream, - input.get(), - output.get(), - input.shape.num_elements(), - num_replicas); + hipLaunchKernelGGL( + HIP_KERNEL_NAME(replicate_backward_kernel>), + GET_BLOCKS(total_elements), + CUDA_NUM_THREADS, + 0, + stream, + input.get(), + output.get(), + input.shape.num_elements(), + num_replicas); } } diff --git a/lib/kernels/src/hip/ops/reshape_kernels.cpp b/lib/kernels/src/hip/ops/reshape_kernels.cpp index 941495c0fd..47978a5f4a 100644 --- a/lib/kernels/src/hip/ops/reshape_kernels.cpp +++ b/lib/kernels/src/hip/ops/reshape_kernels.cpp @@ -47,7 +47,7 @@ struct BackwardKernel { GenericTensorAccessorW const &input, GenericTensorAccessorR const &output) { float alpha = 1.0f; - hipLaunchKernelGGL(HIP_KERNEL_NAME(apply_add_with_scale>), + hipLaunchKernelGGL(HIP_KERNEL_NAME(apply_add_with_scale>), GET_BLOCKS(input.shape.num_elements()), CUDA_NUM_THREADS, 0, @@ -55,7 +55,7 @@ struct BackwardKernel { input.get(), output.get(), input.shape.num_elements(), - static_cast> alpha); + static_cast> alpha); } } diff --git a/lib/local-execution/include/local-execution/cost_estimate.h b/lib/local-execution/include/local-execution/cost_estimate.h index 33954827bd..31503e0da9 100644 --- a/lib/local-execution/include/local-execution/cost_estimate.h +++ b/lib/local-execution/include/local-execution/cost_estimate.h @@ -4,8 +4,8 @@ #include "local-execution/cost_details.dtg.h" #include "local-execution/local_training_backing.h" -#include "op-attrs/operator_attrs.h" -#include "op-attrs/parallel_tensor_shape.h" +#include "op-attrs/parallel_tensor_shape.dtg.h" +#include "op-attrs/pcg_operator_attrs.dtg.h" #include "pcg/machine_view.h" #include "pcg/parallel_computation_graph/parallel_tensor_attrs.dtg.h" diff --git a/lib/local-execution/include/local-execution/legion_tensor_shape.h b/lib/local-execution/include/local-execution/legion_tensor_shape.h index f1d2ad252a..2f2ed50d41 100644 --- a/lib/local-execution/include/local-execution/legion_tensor_shape.h +++ b/lib/local-execution/include/local-execution/legion_tensor_shape.h @@ -4,8 +4,9 @@ #include "kernels/legion_dim.h" #include "op-attrs/datatype.h" #include "op-attrs/ff_dim.h" -#include "op-attrs/tensor_shape.h" +#include "op-attrs/tensor_shape.dtg.h" #include "utils/stack_vector.h" +#include "utils/visitable.h" #include namespace FlexFlow { diff --git a/lib/local-execution/include/local-execution/local_slots_backing.h b/lib/local-execution/include/local-execution/local_slots_backing.h index 6a0c28e988..5b826c7022 100644 --- a/lib/local-execution/include/local-execution/local_slots_backing.h +++ b/lib/local-execution/include/local-execution/local_slots_backing.h @@ -7,6 +7,9 @@ #include "local-execution/op_task_invocation.h" #include "local-execution/per_device_op_state.h" #include "local-execution/runtime_arg_config.h" +#include "pcg/computation_graph.dtg.h" +#include "pcg/layer_guid_t.dtg.h" +#include "pcg/tensor_guid_t.dtg.h" namespace FlexFlow { diff --git a/lib/local-execution/include/local-execution/local_training_backing.h b/lib/local-execution/include/local-execution/local_training_backing.h index b398bb8cc3..6789624076 100644 --- a/lib/local-execution/include/local-execution/local_training_backing.h +++ b/lib/local-execution/include/local-execution/local_training_backing.h @@ -3,6 +3,7 @@ #include "local-execution/local_slots_backing.h" #include "local-execution/task_registry.h" +#include "pcg/computation_graph.dtg.h" namespace FlexFlow { diff --git a/lib/local-execution/include/local-execution/op_arg_ref.h b/lib/local-execution/include/local-execution/op_arg_ref.h index 20d6ccb1c5..102a8d4362 100644 --- a/lib/local-execution/include/local-execution/op_arg_ref.h +++ b/lib/local-execution/include/local-execution/op_arg_ref.h @@ -5,7 +5,7 @@ #include "local-execution/device_specific.h" #include "local-execution/op_arg_ref_type.dtg.h" #include "local-execution/per_device_op_state.h" -#include "op-attrs/parallel_tensor_shape.h" +#include "op-attrs/parallel_tensor_shape.dtg.h" namespace FlexFlow { diff --git a/lib/local-execution/include/local-execution/op_task_invocation.h b/lib/local-execution/include/local-execution/op_task_invocation.h index 73a0460554..0f351c3a0e 100644 --- a/lib/local-execution/include/local-execution/op_task_invocation.h +++ b/lib/local-execution/include/local-execution/op_task_invocation.h @@ -13,10 +13,6 @@ #include "local-execution/slot_grad_id.dtg.h" #include "local-execution/task_id_t.dtg.h" #include "local-execution/variadic_tensor_ref.h" -#include "op-attrs/computation_graph_op_attrs.h" -#include "pcg/computation_graph.h" -#include "utils/bidict/bidict.h" -#include "utils/stack_map.h" #include #include #include diff --git a/lib/local-execution/include/local-execution/sim_environment.h b/lib/local-execution/include/local-execution/sim_environment.h index 3ba17ea3ff..7c81cba408 100644 --- a/lib/local-execution/include/local-execution/sim_environment.h +++ b/lib/local-execution/include/local-execution/sim_environment.h @@ -7,7 +7,7 @@ #include "local-execution/op_task_invocation.h" #include "local-execution/task_argument_accessor.h" #include "local-execution/task_signature_impl.h" -#include "op-attrs/parallel_tensor_shape.h" +#include "op-attrs/parallel_tensor_shape.dtg.h" #include "pcg/machine_view.h" #include diff --git a/lib/local-execution/include/local-execution/task_registry.struct.toml b/lib/local-execution/include/local-execution/task_registry.struct.toml index 308527efac..ada467a67d 100644 --- a/lib/local-execution/include/local-execution/task_registry.struct.toml +++ b/lib/local-execution/include/local-execution/task_registry.struct.toml @@ -15,6 +15,7 @@ includes = [ src_includes = [ "utils/hash/unordered_map.h", "utils/fmt/unordered_map.h", + "utils/fmt/optional.h", ] [[fields]] diff --git a/lib/local-execution/src/legion_tensor_shape.cc b/lib/local-execution/src/legion_tensor_shape.cc index b3a045bab4..bce29fafeb 100644 --- a/lib/local-execution/src/legion_tensor_shape.cc +++ b/lib/local-execution/src/legion_tensor_shape.cc @@ -1,4 +1,5 @@ #include "local-execution/legion_tensor_shape.h" +#include "op-attrs/tensor_shape.h" namespace FlexFlow { diff --git a/lib/local-execution/src/local_cost_estimator.cc b/lib/local-execution/src/local_cost_estimator.cc index d4e0467cbf..5203991f25 100644 --- a/lib/local-execution/src/local_cost_estimator.cc +++ b/lib/local-execution/src/local_cost_estimator.cc @@ -51,7 +51,7 @@ CostDetails LocalCostEstimator::estimate_cost( for (ParallelTensorShape const &input : inputs) { TensorShape tensor_shape = get_piece_shape(input); tensor_guid_t tensor_id = - cg_builder.create_tensor(tensor_shape, CreateGrad::YES); + cg_builder.create_input(tensor_shape, CreateGrad::YES); GenericTensorAccessorW tensor_backing = allocator.allocate_tensor(tensor_shape); tensor_backing_map.insert({tensor_id, tensor_backing}); diff --git a/lib/local-execution/src/local_slots_backing.cc b/lib/local-execution/src/local_slots_backing.cc index 0ec9068c6a..ac35d63c0b 100644 --- a/lib/local-execution/src/local_slots_backing.cc +++ b/lib/local-execution/src/local_slots_backing.cc @@ -1,4 +1,6 @@ #include "local-execution/local_slots_backing.h" +#include "op-attrs/parallel_tensor_shape.h" +#include "pcg/computation_graph.h" #include "utils/containers/contains_key.h" #include "utils/overload.h" diff --git a/lib/local-execution/src/local_training_backing.cc b/lib/local-execution/src/local_training_backing.cc index a2ee06a95a..0fdf1761e3 100644 --- a/lib/local-execution/src/local_training_backing.cc +++ b/lib/local-execution/src/local_training_backing.cc @@ -1,5 +1,6 @@ #include "local-execution/local_training_backing.h" #include "local-execution/task_signature_impl.h" +#include "pcg/computation_graph.h" #include "utils/containers/reversed.h" #include "utils/exception.h" diff --git a/lib/local-execution/src/op_task_signature.cc b/lib/local-execution/src/op_task_signature.cc index 36a1dd708d..932b330453 100644 --- a/lib/local-execution/src/op_task_signature.cc +++ b/lib/local-execution/src/op_task_signature.cc @@ -1,4 +1,5 @@ #include "local-execution/op_task_signature.h" +#include "utils/fmt/optional.h" #include "utils/fmt/unordered_map.h" #include "utils/fmt/unordered_set.h" diff --git a/lib/local-execution/src/ops/element_unary.cc b/lib/local-execution/src/ops/element_unary.cc index a52ebb8089..4ee609bd6c 100644 --- a/lib/local-execution/src/ops/element_unary.cc +++ b/lib/local-execution/src/ops/element_unary.cc @@ -1,6 +1,7 @@ #include "element_unary.h" #include "kernels/element_unary_kernels.h" #include "op-attrs/get_output_shapes.h" +#include "op-attrs/parallel_tensor_shape.h" #include "utils/hash-utils.h" namespace FlexFlow { diff --git a/lib/local-execution/test/src/test_local_cost_estimator.cc b/lib/local-execution/test/src/test_local_cost_estimator.cc index 2bd0acc222..4c01df53e9 100644 --- a/lib/local-execution/test/src/test_local_cost_estimator.cc +++ b/lib/local-execution/test/src/test_local_cost_estimator.cc @@ -2,10 +2,12 @@ #include "kernels/local_cuda_allocator.h" #include "kernels/managed_per_device_ff_handle.h" #include "local-execution/local_cost_estimator.h" +#include "op-attrs/ops/attention.h" +#include "op-attrs/parallel_tensor_shape.h" #include "pcg/computation_graph_builder.h" #include "test_utils.h" -namespace FlexFlow { +using namespace ::FlexFlow; TEST_SUITE(FF_CUDA_TEST_SUITE) { TEST_CASE("Local Cost Estimator") { @@ -73,5 +75,3 @@ TEST_SUITE(FF_CUDA_TEST_SUITE) { } } } - -} // namespace FlexFlow diff --git a/lib/local-execution/test/src/test_local_slots_backing.cc b/lib/local-execution/test/src/test_local_slots_backing.cc index 542aa66087..1ec441fbca 100644 --- a/lib/local-execution/test/src/test_local_slots_backing.cc +++ b/lib/local-execution/test/src/test_local_slots_backing.cc @@ -1,15 +1,19 @@ -#include "doctest/doctest.h" #include "kernels/attention_kernels.h" #include "local-execution/local_cost_estimator.h" #include "local-execution/local_cpu_allocator.h" #include "local-execution/local_slots_backing.h" +#include "op-attrs/ops/attention.h" +#include "op-attrs/parallel_tensor_shape.h" +#include "pcg/computation_graph.h" #include "pcg/computation_graph_builder.h" +#include "test/utils/doctest/fmt/pair.h" +#include "test/utils/doctest/fmt/unordered_map.h" +#include "test/utils/doctest/fmt/variant.h" +#include "test/utils/doctest/fmt/vector.h" #include "test_utils.h" -#include "utils/fmt/unordered_map.h" -#include "utils/fmt/variant.h" -#include "utils/fmt/vector.h" +#include -namespace FlexFlow { +using namespace ::FlexFlow; TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("LocalSlotsBacking -- Attention Op") { @@ -37,11 +41,11 @@ TEST_SUITE(FF_TEST_SUITE) { // build graph ComputationGraphBuilder cg_builder; tensor_guid_t query_guid = - cg_builder.create_tensor(query_shape, CreateGrad::YES); + cg_builder.create_input(query_shape, CreateGrad::YES); tensor_guid_t key_guid = - cg_builder.create_tensor(key_shape, CreateGrad::YES); + cg_builder.create_input(key_shape, CreateGrad::YES); tensor_guid_t value_guid = - cg_builder.create_tensor(value_shape, CreateGrad::YES); + cg_builder.create_input(value_shape, CreateGrad::YES); std::string layer_name = "attn1"; tensor_guid_t output_guid = @@ -269,5 +273,3 @@ TEST_SUITE(FF_TEST_SUITE) { } } } - -} // namespace FlexFlow diff --git a/lib/local-execution/test/src/test_local_task_arg_accessor.cc b/lib/local-execution/test/src/test_local_task_arg_accessor.cc index 0637faaf1c..f52fccb1ed 100644 --- a/lib/local-execution/test/src/test_local_task_arg_accessor.cc +++ b/lib/local-execution/test/src/test_local_task_arg_accessor.cc @@ -4,7 +4,7 @@ #include "local-execution/task_signature_impl.h" #include "utils/fmt/variant.h" -namespace FlexFlow { +using namespace ::FlexFlow; TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("LocalTaskArgumentAccessor") { @@ -140,5 +140,3 @@ TEST_SUITE(FF_TEST_SUITE) { } } } - -} // namespace FlexFlow diff --git a/lib/local-execution/test/src/test_task_registry.cc b/lib/local-execution/test/src/test_task_registry.cc index fa3b068425..e18b7ea2de 100644 --- a/lib/local-execution/test/src/test_task_registry.cc +++ b/lib/local-execution/test/src/test_task_registry.cc @@ -7,7 +7,7 @@ #include "utils/fmt/optional.h" #include "utils/fmt/unordered_map.h" -namespace FlexFlow { +using namespace ::FlexFlow; TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("Task Registry") { @@ -127,5 +127,3 @@ TEST_SUITE(FF_TEST_SUITE) { } } } - -} // namespace FlexFlow diff --git a/lib/models/CMakeLists.txt b/lib/models/CMakeLists.txt index 7dd7f48700..4f4b22ed47 100644 --- a/lib/models/CMakeLists.txt +++ b/lib/models/CMakeLists.txt @@ -11,6 +11,7 @@ ff_add_library( op-attrs utils pcg + rapidcheck ) -add_subdirectory(test) \ No newline at end of file +add_subdirectory(test) diff --git a/lib/models/include/models/split_test/split_test.h b/lib/models/include/models/split_test/split_test.h new file mode 100644 index 0000000000..b03e45b2d2 --- /dev/null +++ b/lib/models/include/models/split_test/split_test.h @@ -0,0 +1,19 @@ +#ifndef _FLEXFLOW_LIB_MODELS_INCLUDE_MODELS_SPLIT_TEST_SPLIT_TEST_H +#define _FLEXFLOW_LIB_MODELS_INCLUDE_MODELS_SPLIT_TEST_SPLIT_TEST_H + +#include "pcg/computation_graph.dtg.h" + +namespace FlexFlow { + +/** + * @brief Get the computation graph of the old FlexFlow test model + * split_test + * + * @note This is a tiny model developed for testing the original Unity + * implementation. It is not a "real" model and has never been trained. + */ +ComputationGraph get_split_test_computation_graph(int batch_size); + +} // namespace FlexFlow + +#endif diff --git a/lib/models/include/models/transformer.h b/lib/models/include/models/transformer/transformer.h similarity index 90% rename from lib/models/include/models/transformer.h rename to lib/models/include/models/transformer/transformer.h index e50fa37709..385100a4c9 100644 --- a/lib/models/include/models/transformer.h +++ b/lib/models/include/models/transformer/transformer.h @@ -1,7 +1,7 @@ -#ifndef _FLEXFLOW_LIB_MODELS_INCLUDE_MODELS_TRANSFORMER_H -#define _FLEXFLOW_LIB_MODELS_INCLUDE_MODELS_TRANSFORMER_H +#ifndef _FLEXFLOW_LIB_MODELS_INCLUDE_MODELS_TRANSFORMER_TRANSFORMER_H +#define _FLEXFLOW_LIB_MODELS_INCLUDE_MODELS_TRANSFORMER_TRANSFORMER_H -#include "models/transformer_config.dtg.h" +#include "models/transformer/transformer_config.dtg.h" #include "pcg/computation_graph_builder.h" namespace FlexFlow { diff --git a/lib/models/include/models/transformer_config.struct.toml b/lib/models/include/models/transformer/transformer_config.struct.toml similarity index 100% rename from lib/models/include/models/transformer_config.struct.toml rename to lib/models/include/models/transformer/transformer_config.struct.toml diff --git a/lib/models/src/models/split_test/split_test.cc b/lib/models/src/models/split_test/split_test.cc new file mode 100644 index 0000000000..118f94ec06 --- /dev/null +++ b/lib/models/src/models/split_test/split_test.cc @@ -0,0 +1,39 @@ +#include "models/split_test/split_test.h" +#include "pcg/computation_graph_builder.h" +#include "utils/integer_conversions.h" + +namespace FlexFlow { + +ComputationGraph get_split_test_computation_graph(int batch_size) { + ComputationGraphBuilder cgb; + + int layer_dim1 = 256; + int layer_dim2 = 128; + int layer_dim3 = 64; + int layer_dim4 = 32; + + TensorShape input_shape = TensorShape{ + TensorDims{FFOrdered{ + size_t_from_int(batch_size), + size_t_from_int(layer_dim1), + }}, + DataType::FLOAT, + }; + + tensor_guid_t t = cgb.create_input(input_shape, CreateGrad::YES); + t = cgb.dense(t, layer_dim2); + t = cgb.relu(t); + tensor_guid_t t1 = cgb.dense(t, layer_dim3); + tensor_guid_t t2 = cgb.dense(t, layer_dim3); + t = cgb.add(t1, t2); + t = cgb.relu(t); + t1 = cgb.dense(t, layer_dim4); + t2 = cgb.dense(t, layer_dim4); + t = cgb.add(t1, t2); + t = cgb.relu(t); + t = cgb.softmax(t); + + return cgb.computation_graph; +} + +} // namespace FlexFlow diff --git a/lib/models/src/models/transformer.cc b/lib/models/src/models/transformer/transformer.cc similarity index 95% rename from lib/models/src/models/transformer.cc rename to lib/models/src/models/transformer/transformer.cc index 874cd85787..e179359940 100644 --- a/lib/models/src/models/transformer.cc +++ b/lib/models/src/models/transformer/transformer.cc @@ -1,4 +1,4 @@ -#include "models/transformer.h" +#include "models/transformer/transformer.h" #include "pcg/computation_graph.h" namespace FlexFlow { @@ -100,7 +100,7 @@ tensor_guid_t assert(are_tensor_guid_shapes_equivalent( cgb.computation_graph, input, self_attention_normalized)); - tensor_guid_t mha = cgb.multihead_attention(input, + tensor_guid_t mha = cgb.multihead_attention(self_attention_normalized, encoder_output, encoder_output, config.num_features, @@ -149,11 +149,13 @@ ComputationGraph config.batch_size, config.sequence_length, config.num_features}}, DataType::FLOAT, }; - tensor_guid_t input = cgb.create_tensor(input_shape, CreateGrad::YES); + tensor_guid_t input = cgb.create_input(input_shape, CreateGrad::YES, "input"); + tensor_guid_t target = + cgb.create_input(input_shape, CreateGrad::YES, "target"); tensor_guid_t encoder_output = create_transformer_encoder(cgb, config, input); tensor_guid_t decoder_output = - create_transformer_decoder(cgb, config, input, encoder_output); + create_transformer_decoder(cgb, config, target, encoder_output); tensor_guid_t out_prob = cgb.softmax(cgb.dense(decoder_output, /*outDim=*/config.vocab_size, diff --git a/lib/models/test/src/models/transformer.cc b/lib/models/test/src/models/transformer.cc index 2133e9965b..20274c4151 100644 --- a/lib/models/test/src/models/transformer.cc +++ b/lib/models/test/src/models/transformer.cc @@ -1,4 +1,4 @@ -#include "models/transformer.h" +#include "models/transformer/transformer.h" #include "pcg/computation_graph.h" #include @@ -12,7 +12,7 @@ TEST_SUITE(FF_TEST_SUITE) { SUBCASE("num layers") { int result_num_layers = get_layers(result).size(); - int correct_num_layers = 317; + int correct_num_layers = 258; CHECK(result_num_layers == correct_num_layers); } } diff --git a/lib/op-attrs/include/op-attrs/datatype.h b/lib/op-attrs/include/op-attrs/datatype.h index 6204b9ca49..5af00fb510 100644 --- a/lib/op-attrs/include/op-attrs/datatype.h +++ b/lib/op-attrs/include/op-attrs/datatype.h @@ -47,14 +47,7 @@ typename data_type_enum_to_class
::type cast_to(T t) { } template -using real_type = typename data_type_enum_to_class
::type; - -using DataTypeValue = std::variant, - real_type, - real_type, - real_type, - /* real_type, */ - real_type>; +using real_type_t = typename data_type_enum_to_class
::type; size_t size_of_datatype(DataType); diff --git a/lib/op-attrs/include/op-attrs/datatype_value.variant.toml b/lib/op-attrs/include/op-attrs/datatype_value.variant.toml new file mode 100644 index 0000000000..3386e9d131 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/datatype_value.variant.toml @@ -0,0 +1,25 @@ +namespace = "FlexFlow" +name = "DataTypeValue" +features = [ + "eq", + "ord", + "hash", + "json", + "rapidcheck", + "fmt", +] + +[[values]] +type = "float" + +[[values]] +type = "double" + +[[values]] +type = "int32_t" + +[[values]] +type = "int64_t" + +[[values]] +type = "bool" diff --git a/lib/op-attrs/include/op-attrs/dim_ordered.h b/lib/op-attrs/include/op-attrs/dim_ordered.h index 6868ba083f..34d186e74e 100644 --- a/lib/op-attrs/include/op-attrs/dim_ordered.h +++ b/lib/op-attrs/include/op-attrs/dim_ordered.h @@ -3,8 +3,8 @@ #include "op-attrs/ff_dim.dtg.h" #include "utils/fmt/vector.h" -#include "utils/json.h" #include "utils/stack_vector.h" +#include namespace FlexFlow { @@ -202,11 +202,12 @@ FFOrdered const &outer_to_inner(FFOrdered const &ff_ordered) { namespace nlohmann { template struct adl_serializer<::FlexFlow::DimOrdered> { - static ::FlexFlow::DimOrdered from_json(json const &j) { + static ::FlexFlow::DimOrdered from_json(nlohmann::json const &j) { return {j.template get>()}; } - static void to_json(json &j, ::FlexFlow::DimOrdered const &x) { + static void to_json(nlohmann::json &j, + ::FlexFlow::DimOrdered const &x) { j = std::vector{x.cbegin(), x.cend()}; } }; diff --git a/lib/op-attrs/include/op-attrs/dim_ordered/slice.h b/lib/op-attrs/include/op-attrs/dim_ordered/slice.h index f3dfe5d199..d39bac1bde 100644 --- a/lib/op-attrs/include/op-attrs/dim_ordered/slice.h +++ b/lib/op-attrs/include/op-attrs/dim_ordered/slice.h @@ -2,9 +2,9 @@ #define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_DIM_ORDERED_SLICE_H #include "op-attrs/dim_ordered.h" -#include "utils/containers/as_vector.h" #include "utils/containers/subvec.h" #include "utils/containers/transform.h" +#include "utils/containers/vector_of.h" #include "utils/optional.h" namespace FlexFlow { @@ -18,7 +18,7 @@ DimOrdered nonoverloaded_slice(DimOrdered const &d, }; return DimOrdered{ - subvec(as_vector(d), to_raw_idx(start), to_raw_idx(end))}; + subvec(vector_of(d), to_raw_idx(start), to_raw_idx(end))}; } template diff --git a/lib/op-attrs/include/op-attrs/dim_ordered/transform.h b/lib/op-attrs/include/op-attrs/dim_ordered/transform.h index 3a31ea511d..ae6e552243 100644 --- a/lib/op-attrs/include/op-attrs/dim_ordered/transform.h +++ b/lib/op-attrs/include/op-attrs/dim_ordered/transform.h @@ -2,7 +2,7 @@ #define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_DIM_ORDERED_TRANSFORM_H #include "op-attrs/dim_ordered.h" -#include "utils/containers/as_vector.h" +#include "utils/containers/vector_of.h" #include "utils/containers/vector_transform.h" namespace FlexFlow { @@ -12,7 +12,7 @@ DimOrdered> transform(DimOrdered const &d, F f) { using Out = std::invoke_result_t; - return DimOrdered{vector_transform(as_vector(d), f)}; + return DimOrdered{vector_transform(vector_of(d), f)}; } } // namespace FlexFlow diff --git a/lib/op-attrs/include/op-attrs/dim_ordered/zip.h b/lib/op-attrs/include/op-attrs/dim_ordered/zip.h index 54554afb81..023dcfc586 100644 --- a/lib/op-attrs/include/op-attrs/dim_ordered/zip.h +++ b/lib/op-attrs/include/op-attrs/dim_ordered/zip.h @@ -2,7 +2,7 @@ #define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_DIM_ORDERED_ZIP_H #include "op-attrs/dim_ordered.h" -#include "utils/containers/as_vector.h" +#include "utils/containers/vector_of.h" #include "utils/containers/zip.h" namespace FlexFlow { @@ -11,7 +11,7 @@ template DimOrdered> zip(DimOrdered const &lhs, DimOrdered const &rhs) { return DimOrdered>{ - zip(as_vector(lhs), as_vector(rhs))}; + zip(vector_of(lhs), vector_of(rhs))}; } } // namespace FlexFlow diff --git a/lib/op-attrs/include/op-attrs/ops/broadcast.h b/lib/op-attrs/include/op-attrs/ops/broadcast.h index 0a5f057578..4fd7d49234 100644 --- a/lib/op-attrs/include/op-attrs/ops/broadcast.h +++ b/lib/op-attrs/include/op-attrs/ops/broadcast.h @@ -5,11 +5,14 @@ #include "op-attrs/ops/core.h" #include "op-attrs/parallel_tensor_shape.dtg.h" #include "op-attrs/tensor_shape.dtg.h" +#include "utils/record_formatter.h" namespace FlexFlow { CHECK_VALID_OP_ATTR(BroadcastAttrs); +RecordFormatter as_dot(BroadcastAttrs const &); + tl::expected get_output_shape(BroadcastAttrs const &, TensorShape const &); ParallelTensorShape get_output_shape(BroadcastAttrs const &, diff --git a/lib/op-attrs/include/op-attrs/ops/conv_2d_attrs.struct.toml b/lib/op-attrs/include/op-attrs/ops/conv_2d_attrs.struct.toml index 2fb385b64d..5bef144cd9 100644 --- a/lib/op-attrs/include/op-attrs/ops/conv_2d_attrs.struct.toml +++ b/lib/op-attrs/include/op-attrs/ops/conv_2d_attrs.struct.toml @@ -12,11 +12,12 @@ features = [ includes = [ "", "op-attrs/activation.dtg.h", - "utils/json.h", ] src_includes = [ "utils/fmt/optional.h", + "utils/json/optional.h", + "utils/rapidcheck/optional.h", ] fields = [ diff --git a/lib/op-attrs/include/op-attrs/ops/element_unary_attrs.struct.toml b/lib/op-attrs/include/op-attrs/ops/element_unary_attrs.struct.toml index 4b9c8a9f45..403bb87592 100644 --- a/lib/op-attrs/include/op-attrs/ops/element_unary_attrs.struct.toml +++ b/lib/op-attrs/include/op-attrs/ops/element_unary_attrs.struct.toml @@ -11,12 +11,14 @@ features = [ ] includes = [ - "utils/json.h", - "op-attrs/operator_type.h", + "op-attrs/operator_type.dtg.h", + "", ] src_includes = [ "utils/fmt/optional.h", + "utils/json/optional.h", + "utils/rapidcheck/optional.h", ] [[fields]] diff --git a/lib/op-attrs/include/op-attrs/ops/embedding_attrs.struct.toml b/lib/op-attrs/include/op-attrs/ops/embedding_attrs.struct.toml index 38d5a4371e..66d6f99253 100644 --- a/lib/op-attrs/include/op-attrs/ops/embedding_attrs.struct.toml +++ b/lib/op-attrs/include/op-attrs/ops/embedding_attrs.struct.toml @@ -17,6 +17,8 @@ includes = [ src_includes = [ "utils/fmt/optional.h", + "utils/json/optional.h", + "utils/rapidcheck/optional.h", ] [[fields]] diff --git a/lib/op-attrs/include/op-attrs/ops/linear_attrs.struct.toml b/lib/op-attrs/include/op-attrs/ops/linear_attrs.struct.toml index eaa34cc496..0a35a6c5ec 100644 --- a/lib/op-attrs/include/op-attrs/ops/linear_attrs.struct.toml +++ b/lib/op-attrs/include/op-attrs/ops/linear_attrs.struct.toml @@ -13,11 +13,13 @@ includes = [ "op-attrs/datatype.dtg.h", "op-attrs/activation.dtg.h", "op-attrs/regularizer_attrs.dtg.h", - "utils/json.h", + "", ] src_includes = [ "utils/fmt/optional.h", + "utils/json/optional.h", + "utils/rapidcheck/optional.h", ] [[fields]] diff --git a/lib/op-attrs/include/op-attrs/tensor_shape.h b/lib/op-attrs/include/op-attrs/tensor_shape.h index 108df58dce..14ee637f92 100644 --- a/lib/op-attrs/include/op-attrs/tensor_shape.h +++ b/lib/op-attrs/include/op-attrs/tensor_shape.h @@ -11,11 +11,6 @@ size_t &dim_at_idx(TensorShape &, ff_dim_t); size_t get_num_elements(TensorShape const &); size_t get_size_in_bytes(TensorShape const &); -bool tensor_shape_is_broadcastable_to(TensorShape const &curr, - TensorShape const &goal); -std::optional - get_broadcast_target_shape(std::unordered_set const &); - } // namespace FlexFlow #endif diff --git a/lib/op-attrs/src/op-attrs/computation_graph_op_attrs.cc b/lib/op-attrs/src/op-attrs/computation_graph_op_attrs.cc index 166416cbad..054930cebd 100644 --- a/lib/op-attrs/src/op-attrs/computation_graph_op_attrs.cc +++ b/lib/op-attrs/src/op-attrs/computation_graph_op_attrs.cc @@ -1,5 +1,8 @@ #include "op-attrs/computation_graph_op_attrs.h" #include "op-attrs/get_op_type.h" +#include "op-attrs/ops/broadcast.h" +#include "op-attrs/ops/linear.h" +#include "utils/overload.h" namespace FlexFlow { @@ -8,4 +11,16 @@ OperatorType get_op_type(ComputationGraphOpAttrs const &attrs) { [](auto const &x) { return get_op_type(x); }); } +RecordFormatter as_dot(ComputationGraphOpAttrs const &attrs) { + return attrs.visit(overload{ + [](LinearAttrs const &l) { return as_dot(l); }, + [](BroadcastAttrs const &a) { return as_dot(a); }, + [&](auto const &) { + RecordFormatter r; + r << fmt::to_string(get_op_type(attrs)); + return r; + }, + }); +} + } // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/ops/broadcast.cc b/lib/op-attrs/src/op-attrs/ops/broadcast.cc index bd69864aff..aa3c95f551 100644 --- a/lib/op-attrs/src/op-attrs/ops/broadcast.cc +++ b/lib/op-attrs/src/op-attrs/ops/broadcast.cc @@ -1,8 +1,26 @@ #include "op-attrs/ops/broadcast.h" #include "op-attrs/tensor_dims.h" +#include "utils/record_formatter.h" namespace FlexFlow { +RecordFormatter as_dot(BroadcastAttrs const &attrs) { + RecordFormatter r; + + auto kv = [](std::string const &label, auto const &val) { + RecordFormatter rr; + rr << label << fmt::to_string(val); + return rr; + }; + + for (int i = 0; i < num_dims(attrs.target_dims); i++) { + r << kv(fmt::format("target_dims[{}]", i), + dim_at_idx(attrs.target_dims, ff_dim_t{i})); + } + + return r; +} + tl::expected get_output_shape(BroadcastAttrs const &attrs, TensorShape const &input_shape) { diff --git a/lib/op-attrs/src/op-attrs/parallel_tensor_dims.cc b/lib/op-attrs/src/op-attrs/parallel_tensor_dims.cc index 73c0068826..4bce5449f4 100644 --- a/lib/op-attrs/src/op-attrs/parallel_tensor_dims.cc +++ b/lib/op-attrs/src/op-attrs/parallel_tensor_dims.cc @@ -4,9 +4,9 @@ #include "op-attrs/replica_parallel_dim_set.h" #include "op-attrs/shard_parallel_dim.h" #include "utils/containers/all_of.h" -#include "utils/containers/as_vector.h" #include "utils/containers/product.h" #include "utils/containers/transform.h" +#include "utils/containers/vector_of.h" #include "utils/integer_conversions.h" namespace FlexFlow { @@ -35,7 +35,7 @@ int total_replica_degree(ParallelTensorDims const &dims) { } int total_shard_degree(ParallelTensorDims const &dims) { - return product(transform(as_vector(dims.shard_dims), + return product(transform(vector_of(dims.shard_dims), [](ShardParallelDim const &d) { return d.degree; })); } diff --git a/lib/op-attrs/src/op-attrs/tensor_dims.cc b/lib/op-attrs/src/op-attrs/tensor_dims.cc index e716793a8f..ba7d6e8357 100644 --- a/lib/op-attrs/src/op-attrs/tensor_dims.cc +++ b/lib/op-attrs/src/op-attrs/tensor_dims.cc @@ -3,9 +3,9 @@ #include "op-attrs/replica_parallel_dim_set.h" #include "op-attrs/shard_parallel_dim.dtg.h" #include "utils/containers/all_of.h" -#include "utils/containers/as_vector.h" #include "utils/containers/reversed.h" #include "utils/containers/transform.h" +#include "utils/containers/vector_of.h" #include "utils/containers/zip.h" #include "utils/integer_conversions.h" @@ -33,8 +33,8 @@ bool tensor_dims_is_broadcastable_to(TensorDims const &curr, return false; } - std::vector curr_dims = as_vector(curr.ff_ordered); - std::vector goal_dims = as_vector(goal.ff_ordered); + std::vector curr_dims = vector_of(curr.ff_ordered); + std::vector goal_dims = vector_of(goal.ff_ordered); for (auto const &[curr_dim, goal_dim] : zip(reversed(curr_dims), reversed(goal_dims))) { @@ -72,7 +72,7 @@ ParallelTensorDims DiscardCopyDegree discard_copy_degree, FFOrdered const &shard_degrees) { std::vector lifted = - transform(zip(as_vector(dims.ff_ordered), as_vector(shard_degrees)), + transform(zip(vector_of(dims.ff_ordered), vector_of(shard_degrees)), [](std::pair const &p) { size_t size = p.first; int degree = p.second; diff --git a/lib/op-attrs/src/op-attrs/tensor_shape.cc b/lib/op-attrs/src/op-attrs/tensor_shape.cc index b604d442cb..07508e3065 100644 --- a/lib/op-attrs/src/op-attrs/tensor_shape.cc +++ b/lib/op-attrs/src/op-attrs/tensor_shape.cc @@ -27,35 +27,4 @@ size_t get_size_in_bytes(TensorShape const &s) { return get_num_elements(s) * size_of_datatype(s.data_type); } -bool tensor_shape_is_broadcastable_to(TensorShape const &curr, - TensorShape const &goal) { - return tensor_dims_is_broadcastable_to(curr.dims, goal.dims) && - curr.data_type == goal.data_type; -} - -std::optional - get_broadcast_target_shape(std::unordered_set const &shapes) { - std::unordered_set datatypes = - transform(shapes, [](TensorShape const &s) { return s.data_type; }); - - if (datatypes.size() != 1) { - return std::nullopt; - } - - std::unordered_set shapes_dims = - transform(shapes, [](TensorShape const &s) { return s.dims; }); - - std::optional maybe_result_dims = - get_broadcast_target_dims(shapes_dims); - std::optional result = - transform(maybe_result_dims, [&](TensorDims const &result_dims) { - return TensorShape{ - result_dims, - get_only(datatypes), - }; - }); - - return result; -} - } // namespace FlexFlow diff --git a/lib/op-attrs/test/src/datatype.cc b/lib/op-attrs/test/src/datatype.cc index cc7e496c60..d45c156d59 100644 --- a/lib/op-attrs/test/src/datatype.cc +++ b/lib/op-attrs/test/src/datatype.cc @@ -1,6 +1,8 @@ #include "op-attrs/datatype.h" -#include "test/utils/doctest.h" #include "test/utils/rapidcheck.h" +#include + +using namespace ::FlexFlow; TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("can_promote_datatype_from_to(DataType, DataType)") { diff --git a/lib/op-attrs/test/src/dim_ordered/slice.cc b/lib/op-attrs/test/src/dim_ordered/slice.cc index 8640b077dc..8d5f247756 100644 --- a/lib/op-attrs/test/src/dim_ordered/slice.cc +++ b/lib/op-attrs/test/src/dim_ordered/slice.cc @@ -1,5 +1,7 @@ #include "op-attrs/dim_ordered/slice.h" -#include "test/utils/doctest.h" +#include + +using namespace ::FlexFlow; TEST_SUITE(FF_TEST_SUITE) { TEST_CASE( diff --git a/lib/op-attrs/test/src/op-attrs/dim_ordered/enumerate.cc b/lib/op-attrs/test/src/op-attrs/dim_ordered/enumerate.cc index d2c758a05f..180bc2a01f 100644 --- a/lib/op-attrs/test/src/op-attrs/dim_ordered/enumerate.cc +++ b/lib/op-attrs/test/src/op-attrs/dim_ordered/enumerate.cc @@ -1,5 +1,5 @@ #include "op-attrs/dim_ordered/enumerate.h" -#include "utils/fmt/map.h" +#include "test/utils/doctest/fmt/map.h" #include using namespace ::FlexFlow; diff --git a/lib/op-attrs/test/src/op-attrs/dim_ordered/zip.cc b/lib/op-attrs/test/src/op-attrs/dim_ordered/zip.cc index 11e09dc43f..8e3d0f1b80 100644 --- a/lib/op-attrs/test/src/op-attrs/dim_ordered/zip.cc +++ b/lib/op-attrs/test/src/op-attrs/dim_ordered/zip.cc @@ -1,6 +1,6 @@ #include "op-attrs/dim_ordered/zip.h" #include "op-attrs/ff_dim.dtg.h" -#include "utils/fmt/pair.h" +#include "test/utils/doctest/fmt/pair.h" #include using namespace ::FlexFlow; diff --git a/lib/op-attrs/test/src/op-attrs/ops/dropout.cc b/lib/op-attrs/test/src/op-attrs/ops/dropout.cc index 17a68ccbc8..7580de24e5 100644 --- a/lib/op-attrs/test/src/op-attrs/ops/dropout.cc +++ b/lib/op-attrs/test/src/op-attrs/ops/dropout.cc @@ -1,5 +1,6 @@ #include "op-attrs/ops/dropout.h" #include "op-attrs/parallel_tensor_shape.h" +#include "test/utils/doctest/fmt/optional.h" #include "utils/expected.h" #include "utils/fmt/expected.h" #include diff --git a/lib/op-attrs/test/src/op-attrs/ops/layer_norm.cc b/lib/op-attrs/test/src/op-attrs/ops/layer_norm.cc index b9dd66df5d..cbcebdbce1 100644 --- a/lib/op-attrs/test/src/op-attrs/ops/layer_norm.cc +++ b/lib/op-attrs/test/src/op-attrs/ops/layer_norm.cc @@ -1,9 +1,9 @@ #include "op-attrs/ops/layer_norm.h" #include "op-attrs/parallel_tensor_shape.h" -#include "test/utils/doctest.h" #include "utils/expected.h" #include "utils/fmt/expected.h" #include "utils/fmt/optional.h" +#include using namespace ::FlexFlow; diff --git a/lib/op-attrs/test/src/op-attrs/ops/softmax.cc b/lib/op-attrs/test/src/op-attrs/ops/softmax.cc index f6a8da016f..65a74932cb 100644 --- a/lib/op-attrs/test/src/op-attrs/ops/softmax.cc +++ b/lib/op-attrs/test/src/op-attrs/ops/softmax.cc @@ -1,5 +1,6 @@ #include "op-attrs/ops/softmax.h" #include "op-attrs/parallel_tensor_shape.h" +#include "test/utils/doctest/fmt/optional.h" #include "utils/expected.h" #include "utils/fmt/expected.h" #include diff --git a/lib/op-attrs/test/src/op-attrs/tensor_dims.cc b/lib/op-attrs/test/src/op-attrs/tensor_dims.cc index 25c7eb036f..60d87300c1 100644 --- a/lib/op-attrs/test/src/op-attrs/tensor_dims.cc +++ b/lib/op-attrs/test/src/op-attrs/tensor_dims.cc @@ -1,4 +1,5 @@ #include "op-attrs/tensor_dims.h" +#include "test/utils/doctest/fmt/optional.h" #include using namespace ::FlexFlow; diff --git a/lib/op-attrs/test/src/op-attrs/tensor_shape.cc b/lib/op-attrs/test/src/op-attrs/tensor_shape.cc deleted file mode 100644 index bc715c183a..0000000000 --- a/lib/op-attrs/test/src/op-attrs/tensor_shape.cc +++ /dev/null @@ -1,64 +0,0 @@ -#include "op-attrs/tensor_shape.h" -#include - -using namespace ::FlexFlow; - -TEST_SUITE(FF_TEST_SUITE) { - TEST_CASE("get_broadcast_target_shape(std::unordered_set)") { - SUBCASE("target exists in inputs") { - DataType datatype = DataType::FLOAT; - - TensorShape s1 = TensorShape{ - TensorDims{FFOrdered{ - 1, - }}, - datatype, - }; - - TensorShape s2 = TensorShape{ - TensorDims{FFOrdered{10, 4, 3}}, - datatype, - }; - - TensorShape s3 = TensorShape{ - TensorDims{FFOrdered{ - 4, - 1, - }}, - datatype, - }; - - std::optional result = - get_broadcast_target_shape({s1, s2, s3}); - std::optional correct = s2; - - CHECK(result == correct); - } - - SUBCASE("datatypes don't match") { - TensorDims dims = TensorDims{FFOrdered{10, 4, 3}}; - - TensorShape s1 = TensorShape{ - dims, - DataType::FLOAT, - }; - - TensorShape s2 = TensorShape{ - dims, - DataType::DOUBLE, - }; - - std::optional result = get_broadcast_target_shape({s1, s2}); - std::optional correct = std::nullopt; - - CHECK(result == correct); - } - - SUBCASE("inputs is empty") { - std::optional result = get_broadcast_target_shape({}); - std::optional correct = std::nullopt; - - CHECK(result == correct); - } - } -} diff --git a/lib/op-attrs/test/src/ops/attention.cc b/lib/op-attrs/test/src/ops/attention.cc index ade219a6a9..2fb804ca8c 100644 --- a/lib/op-attrs/test/src/ops/attention.cc +++ b/lib/op-attrs/test/src/ops/attention.cc @@ -1,7 +1,10 @@ #include "op-attrs/ops/attention.h" #include "op-attrs/parallel_tensor_shape.h" -#include "test/utils/doctest.h" +#include "test/utils/doctest/fmt/expected.h" #include "utils/integer_conversions.h" +#include + +using namespace ::FlexFlow; TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("get_output_shape(MultiHeadAttentionAttrs, TensorShape, " diff --git a/lib/op-attrs/test/src/ops/batch_matmul.cc b/lib/op-attrs/test/src/ops/batch_matmul.cc index 3ff02ccece..56a2e3fa52 100644 --- a/lib/op-attrs/test/src/ops/batch_matmul.cc +++ b/lib/op-attrs/test/src/ops/batch_matmul.cc @@ -1,5 +1,8 @@ #include "op-attrs/ops/batch_matmul.h" -#include "test/utils/doctest.h" +#include "test/utils/doctest/fmt/expected.h" +#include + +using namespace ::FlexFlow; TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("get_output_shape(BatchMatmulAttrs, TensorShape)") { diff --git a/lib/op-attrs/test/src/ops/cast.cc b/lib/op-attrs/test/src/ops/cast.cc index 31030ca0f9..c7395316ad 100644 --- a/lib/op-attrs/test/src/ops/cast.cc +++ b/lib/op-attrs/test/src/ops/cast.cc @@ -1,6 +1,9 @@ #include "op-attrs/ops/cast.h" #include "op-attrs/parallel_tensor_shape.h" -#include "test/utils/doctest.h" +#include "test/utils/doctest/fmt/expected.h" +#include + +using namespace ::FlexFlow; TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("Cast shape inference") { diff --git a/lib/op-attrs/test/src/ops/combine.cc b/lib/op-attrs/test/src/ops/combine.cc index ac18bbc798..bf74a072e0 100644 --- a/lib/op-attrs/test/src/ops/combine.cc +++ b/lib/op-attrs/test/src/ops/combine.cc @@ -1,5 +1,8 @@ #include "op-attrs/ops/combine.h" -#include "test/utils/doctest.h" +#include "test/utils/doctest/fmt/expected.h" +#include + +using namespace ::FlexFlow; TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("Combine shape inference") { diff --git a/lib/op-attrs/test/src/ops/element_binary.cc b/lib/op-attrs/test/src/ops/element_binary.cc index 0ed695eb89..b091833f10 100644 --- a/lib/op-attrs/test/src/ops/element_binary.cc +++ b/lib/op-attrs/test/src/ops/element_binary.cc @@ -1,6 +1,9 @@ #include "op-attrs/ops/element_binary.h" #include "op-attrs/parallel_tensor_shape.h" -#include "test/utils/doctest.h" +#include "test/utils/doctest/fmt/expected.h" +#include + +using namespace ::FlexFlow; TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("EWAdd shape inference") { diff --git a/lib/op-attrs/test/src/ops/element_unary.cc b/lib/op-attrs/test/src/ops/element_unary.cc index 4239782d55..94c382356e 100644 --- a/lib/op-attrs/test/src/ops/element_unary.cc +++ b/lib/op-attrs/test/src/ops/element_unary.cc @@ -1,6 +1,9 @@ #include "op-attrs/ops/element_unary.h" #include "op-attrs/parallel_tensor_shape.h" -#include "test/utils/doctest.h" +#include "test/utils/doctest/fmt/expected.h" +#include + +using namespace ::FlexFlow; TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("ReLU shape inference") { diff --git a/lib/op-attrs/test/src/ops/embedding.cc b/lib/op-attrs/test/src/ops/embedding.cc index 9180f7055d..134737f6c0 100644 --- a/lib/op-attrs/test/src/ops/embedding.cc +++ b/lib/op-attrs/test/src/ops/embedding.cc @@ -1,7 +1,10 @@ #include "op-attrs/ops/embedding.h" #include "op-attrs/parallel_tensor_shape.h" -#include "test/utils/doctest.h" +#include "test/utils/doctest/fmt/expected.h" #include "utils/integer_conversions.h" +#include + +using namespace ::FlexFlow; TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("Sum embedding shape inference") { diff --git a/lib/op-attrs/test/src/ops/linear.cc b/lib/op-attrs/test/src/ops/linear.cc index 0d23dc35df..f838ff4285 100644 --- a/lib/op-attrs/test/src/ops/linear.cc +++ b/lib/op-attrs/test/src/ops/linear.cc @@ -1,7 +1,10 @@ #include "op-attrs/ops/linear.h" #include "op-attrs/parallel_tensor_shape.h" -#include "test/utils/doctest.h" +#include "test/utils/doctest/fmt/expected.h" #include "utils/integer_conversions.h" +#include + +using namespace ::FlexFlow; TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("Linear shape inference") { diff --git a/lib/op-attrs/test/src/ops/reduction.cc b/lib/op-attrs/test/src/ops/reduction.cc index 59ed5bb5ee..0d1c8bdf98 100644 --- a/lib/op-attrs/test/src/ops/reduction.cc +++ b/lib/op-attrs/test/src/ops/reduction.cc @@ -1,5 +1,8 @@ #include "op-attrs/ops/reduction.h" -#include "test/utils/doctest.h" +#include "test/utils/doctest/fmt/expected.h" +#include + +using namespace ::FlexFlow; TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("Reduction shape inference") { diff --git a/lib/op-attrs/test/src/ops/repartition.cc b/lib/op-attrs/test/src/ops/repartition.cc index af28a6d471..8bc8205183 100644 --- a/lib/op-attrs/test/src/ops/repartition.cc +++ b/lib/op-attrs/test/src/ops/repartition.cc @@ -1,5 +1,8 @@ #include "op-attrs/ops/repartition.h" -#include "test/utils/doctest.h" +#include "test/utils/doctest/fmt/expected.h" +#include + +using namespace ::FlexFlow; TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("Repartition shape inference") { diff --git a/lib/op-attrs/test/src/ops/replicate.cc b/lib/op-attrs/test/src/ops/replicate.cc index a0ec40cc14..60a1018479 100644 --- a/lib/op-attrs/test/src/ops/replicate.cc +++ b/lib/op-attrs/test/src/ops/replicate.cc @@ -1,5 +1,7 @@ #include "op-attrs/ops/replicate.h" -#include "test/utils/doctest.h" +#include + +using namespace ::FlexFlow; TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("Replicate shape inference") { diff --git a/lib/op-attrs/test/src/test_operator_attrs.cc b/lib/op-attrs/test/src/test_operator_attrs.cc index f485b07b02..20825f5d73 100644 --- a/lib/op-attrs/test/src/test_operator_attrs.cc +++ b/lib/op-attrs/test/src/test_operator_attrs.cc @@ -1,8 +1,8 @@ #include "op-attrs/computation_graph_op_attrs.dtg.h" #include "op-attrs/pcg_operator_attrs.dtg.h" -#include "utils/json.h" #include #include +#include #include using namespace ::FlexFlow; @@ -10,16 +10,16 @@ using namespace ::FlexFlow; TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("BatchNormAttrs to/from json") { BatchNormAttrs correct = BatchNormAttrs{true}; - json j = correct; - auto result = j.get(); + nlohmann::json j = correct; + BatchNormAttrs result = j.get(); CHECK(result == correct); } TEST_CASE("ComputationGraphAttrs to/from json") { ComputationGraphOpAttrs correct = ComputationGraphOpAttrs{BatchNormAttrs{true}}; - json j = correct; - auto result = j.get(); + nlohmann::json j = correct; + ComputationGraphOpAttrs result = j.get(); CHECK(result == correct); } @@ -29,8 +29,8 @@ TEST_SUITE(FF_TEST_SUITE) { /*repartition_dim=*/ff_dim_t{1}, /*repartition_degree=*/4, }}; - json j = correct; - auto result = j.get(); + nlohmann::json j = correct; + PCGOperatorAttrs result = j.get(); CHECK(result == correct); } diff --git a/lib/op-attrs/test/src/test_regularizer_attrs.cc b/lib/op-attrs/test/src/test_regularizer_attrs.cc index 35851463bb..6e172d1e8e 100644 --- a/lib/op-attrs/test/src/test_regularizer_attrs.cc +++ b/lib/op-attrs/test/src/test_regularizer_attrs.cc @@ -1,6 +1,8 @@ #include "op-attrs/regularizer_attrs.dtg.h" -#include "test/utils/doctest.h" #include "test/utils/rapidcheck.h" +#include + +using namespace ::FlexFlow; TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("Arbitrary") { diff --git a/lib/pcg/CMakeLists.txt b/lib/pcg/CMakeLists.txt index e1875ca694..e6eb182740 100644 --- a/lib/pcg/CMakeLists.txt +++ b/lib/pcg/CMakeLists.txt @@ -10,6 +10,7 @@ ff_add_library( DEPS op-attrs utils + rapidcheck ) add_subdirectory(ffi) diff --git a/lib/pcg/include/pcg/computation_graph.h b/lib/pcg/include/pcg/computation_graph.h index 088139a0f3..499b26af89 100644 --- a/lib/pcg/include/pcg/computation_graph.h +++ b/lib/pcg/include/pcg/computation_graph.h @@ -2,6 +2,7 @@ #define _FLEXFLOW_PCG_INCLUDE_PCG_COMPUTATION_GRAPH_H #include "pcg/computation_graph.dtg.h" +#include "pcg/computation_graph/computation_graph_edge.dtg.h" #include "pcg/computation_graph/layer_added_result.dtg.h" #include "pcg/layer_guid_t.dtg.h" #include "pcg/tensor_attrs.dtg.h" @@ -30,11 +31,24 @@ std::vector get_outgoing_tensors(ComputationGraph const &cg, std::vector get_incoming_tensors(ComputationGraph const &cg, layer_guid_t n); +std::unordered_set + get_subgraph_incoming_edges(ComputationGraph const &, + std::unordered_set const &); +std::unordered_set + get_subgraph_outgoing_edges(ComputationGraph const &, + std::unordered_set const &); +std::unordered_set + get_subgraph_successors(ComputationGraph const &, + std::unordered_set const &); + LayerAttrs get_layer_attrs(ComputationGraph const &cg, layer_guid_t const &n); layer_guid_t get_layer_by_name(ComputationGraph const &cg, std::string const &name); +std::string as_dot(ComputationGraph const &); +void debug_print_dot(ComputationGraph const &); + } // namespace FlexFlow #endif diff --git a/lib/pcg/include/pcg/computation_graph/computation_graph_edge.h b/lib/pcg/include/pcg/computation_graph/computation_graph_edge.h new file mode 100644 index 0000000000..2a9a9ee04a --- /dev/null +++ b/lib/pcg/include/pcg/computation_graph/computation_graph_edge.h @@ -0,0 +1,14 @@ +#ifndef _FLEXFLOW_LIB_PCG_INCLUDE_PCG_COMPUTATION_GRAPH_COMPUTATION_GRAPH_EDGE_H +#define _FLEXFLOW_LIB_PCG_INCLUDE_PCG_COMPUTATION_GRAPH_COMPUTATION_GRAPH_EDGE_H + +#include "pcg/computation_graph/computation_graph_edge.dtg.h" +#include "pcg/layer_guid_t.dtg.h" + +namespace FlexFlow { + +layer_guid_t get_computation_graph_edge_src_layer(ComputationGraphEdge const &); +layer_guid_t get_computation_graph_edge_dst_layer(ComputationGraphEdge const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/pcg/include/pcg/computation_graph/computation_graph_edge.struct.toml b/lib/pcg/include/pcg/computation_graph/computation_graph_edge.struct.toml new file mode 100644 index 0000000000..311c47d277 --- /dev/null +++ b/lib/pcg/include/pcg/computation_graph/computation_graph_edge.struct.toml @@ -0,0 +1,16 @@ +namespace = "FlexFlow" +name = "ComputationGraphEdge" +features = [ + "eq", + "ord", + "hash", + "fmt", +] + +includes = [ + "utils/graph/dataflow_graph/dataflow_edge.dtg.h", +] + +[[fields]] +name = "raw_edge" +type = "::FlexFlow::DataflowEdge" diff --git a/lib/pcg/include/pcg/computation_graph_builder.h b/lib/pcg/include/pcg/computation_graph_builder.h index c641aed6a4..a35763cacc 100644 --- a/lib/pcg/include/pcg/computation_graph_builder.h +++ b/lib/pcg/include/pcg/computation_graph_builder.h @@ -159,9 +159,12 @@ struct ComputationGraphBuilder { std::optional activation = std::nullopt, bool use_bias = true, DataType data_type = DataType::FLOAT, - std::optional const &kernel_initializer = std::nullopt, + std::optional const &projection_initializer = + std::nullopt, std::optional const &bias_initializer = std::nullopt, - std::optional const &name = std::nullopt); + std::optional const &name = std::nullopt, + std::optional const &projection_name = std::nullopt, + std::optional const &bias_name = std::nullopt); // Add a cast layer tensor_guid_t cast(tensor_guid_t const &input, DataType dtype, @@ -225,12 +228,16 @@ struct ComputationGraphBuilder { bool add_zero_attn = false, std::optional initializer = std::nullopt, std::optional const &maybe_name = std::nullopt); - tensor_guid_t create_tensor(TensorShape const &, CreateGrad); + tensor_guid_t + create_input(TensorShape const &, + CreateGrad, + std::optional const &maybe_name = std::nullopt); tensor_guid_t create_weight( TensorShape const &, - bool create_grad = true, + CreateGrad create_grad = CreateGrad::YES, std::optional const &initializer = std::nullopt, - std::optional sync_type = std::nullopt); + std::optional sync_type = std::nullopt, + std::optional const &name = std::nullopt); std::vector get_outputs(LayerAttrs const &) const; tensor_guid_t get_output(LayerAttrs const &, int idx) const; @@ -243,9 +250,8 @@ struct ComputationGraphBuilder { private: TensorShape get_shape(tensor_guid_t const &) const; - tensor_guid_t broadcast(tensor_guid_t const &, - TensorShape const &, - std::string const &); + tensor_guid_t + broadcast(tensor_guid_t const &, TensorDims const &, std::string const &); tensor_guid_t as_type(tensor_guid_t const &, DataType, std::string const &); @@ -259,13 +265,22 @@ struct ComputationGraphBuilder { std::vector const &weights, std::vector const &outputs); + tensor_guid_t add_layer(LayerAttrs const &layer, + std::vector const &inputs, + TensorShape const &output); + tensor_guid_t add_layer(LayerAttrs const &layer, std::vector const &inputs, std::vector const &weights, TensorShape const &output); - TensorShape get_broadcast_target_shape(std::vector const &); - TensorShape get_broadcast_target_shape(std::vector const &); + tensor_guid_t add_layer(LayerAttrs const &layer, + std::vector const &inputs, + std::vector const &weights, + TensorShape const &output); + + TensorDims get_broadcast_target_dims(std::vector const &); + TensorDims get_broadcast_target_dims(std::vector const &); tensor_guid_t element_binary(OperatorType, diff --git a/lib/pcg/include/pcg/file_format/file_format.h b/lib/pcg/include/pcg/file_format/file_format.h deleted file mode 100644 index 823846754c..0000000000 --- a/lib/pcg/include/pcg/file_format/file_format.h +++ /dev/null @@ -1,19 +0,0 @@ -#ifndef _FLEXFLOW_PCG_INCLUDE_PCG_FILE_FORMAT_FILE_FORMAT_H -#define _FLEXFLOW_PCG_INCLUDE_PCG_FILE_FORMAT_FILE_FORMAT_H - -#include "graphs.h" -#include "utils/json.h" - -namespace FlexFlow { - -enum class FileFormatVersion { - V1, - UNSTABLE, -}; - -json to_json(ComputationGraph const &, FileFormatVersion); -ComputationGraph from_json(json const &); - -} // namespace FlexFlow - -#endif diff --git a/lib/pcg/include/pcg/file_format/keyed_variant.h b/lib/pcg/include/pcg/file_format/keyed_variant.h index 11044de12b..5e29d8c252 100644 --- a/lib/pcg/include/pcg/file_format/keyed_variant.h +++ b/lib/pcg/include/pcg/file_format/keyed_variant.h @@ -1,10 +1,11 @@ #ifndef _FLEXFLOW_PCG_INCLUDE_PCG_FILE_FORMAT_KEYED_VARIANT_H #define _FLEXFLOW_PCG_INCLUDE_PCG_FILE_FORMAT_KEYED_VARIANT_H -#include "utils/json.h" +#include "utils/json/is_jsonable.h" #include "utils/sequence.h" #include "utils/strong_typedef.h" #include "utils/variant.h" +#include namespace FlexFlow { @@ -29,9 +30,9 @@ struct KeyedVariant { }; struct ToJsonFunctor { - ToJsonFunctor(json &j) : j(j) {} + ToJsonFunctor(nlohmann::json &j) : j(j) {} - json &j; + nlohmann::json &j; template void operator()(T const &t) { @@ -42,20 +43,20 @@ struct ToJsonFunctor { }; template -void to_json(json &j, KeyedVariant const &v) { +void to_json(nlohmann::json &j, KeyedVariant const &v) { static_assert(is_jsonable::value, ""); K key = static_cast(v.value.index()); j["type"] = key; - json &jj = j["value"]; + nlohmann::json &jj = j["value"]; visit(ToJsonFunctor{j["value"]}, v.value); } template struct FromJsonFunctor { - FromJsonFunctor(json const &j, int idx) : j(j), idx(idx) {} + FromJsonFunctor(nlohmann::json const &j, int idx) : j(j), idx(idx) {} - json const &j; + nlohmann::json const &j; int idx; template @@ -68,31 +69,31 @@ struct FromJsonFunctor { template std::string get_json_name(T const &t) { - return json{t}.get(); + return nlohmann::json{t}.get(); } template struct FromJsonMoveOnlyFunctor { - FromJsonMoveOnlyFunctor(json const &j, Key const &key) : j(j) {} + FromJsonMoveOnlyFunctor(nlohmann::json const &j, Key const &key) : j(j) {} - json const &j; + nlohmann::json const &j; Key const &key; template Variant operator()(std::integral_constant const &) const { - return j.get::type>(); + return j.get::type>(); } }; template -Variant from_json_moveonly(json const &j, K const &key) { +Variant from_json_moveonly(nlohmann::json const &j, K const &key) { FromJsonMoveOnlyFunctor func(j); return seq_get(func, idx, seq_count_t::value>{}); } template typename std::enable_if::value>::type - from_json(json const &j, KeyedVariant &v) { + from_json(nlohmann::json const &j, KeyedVariant &v) { K key = j.at("type").get(); std::string key_string = j.at("type").get(); @@ -100,7 +101,7 @@ typename std::enable_if::value>::type } template -KeyedVariant keyed_variant_from_json(json const &j) { +KeyedVariant keyed_variant_from_json(nlohmann::json const &j) { K key = j.at("type").get(); return KeyedVariant{ diff --git a/lib/pcg/include/pcg/file_format/v1/data_type_value.h b/lib/pcg/include/pcg/file_format/v1/data_type_value.h index 6e4e5abc54..ec3910aab3 100644 --- a/lib/pcg/include/pcg/file_format/v1/data_type_value.h +++ b/lib/pcg/include/pcg/file_format/v1/data_type_value.h @@ -2,7 +2,7 @@ #define _FLEXFLOW_PCG_INCLUDE_PCG_FILE_FORMAT_V1_DATA_TYPE_H #include "utils/fp16.h" -#include "utils/json.h" +#include namespace FlexFlow { diff --git a/lib/pcg/include/pcg/file_format/v1/graphs.h b/lib/pcg/include/pcg/file_format/v1/graphs.h deleted file mode 100644 index 702c79c2b6..0000000000 --- a/lib/pcg/include/pcg/file_format/v1/graphs.h +++ /dev/null @@ -1,26 +0,0 @@ -#ifndef _FLEXFLOW_PCG_INCLUDE_PCG_FILE_FORMAT_V1_GRAPHS_H -#define _FLEXFLOW_PCG_INCLUDE_PCG_FILE_FORMAT_V1_GRAPHS_H - -#include "pcg/computation_graph.dtg.h" -#include "pcg/file_format/v1/graphs/v1_labelled_dataflow_graph.dtg.h" -#include "pcg/layer_attrs.dtg.h" -#include "pcg/parallel_computation_graph/parallel_computation_graph.dtg.h" -#include "pcg/parallel_computation_graph/parallel_layer_attrs.dtg.h" -#include "pcg/parallel_computation_graph/parallel_tensor_attrs.dtg.h" -#include "pcg/tensor_attrs.dtg.h" -#include "utils/json.h" - -namespace FlexFlow { - -using V1ComputationGraph = V1LabelledDataflowGraph; -CHECK_IS_JSONABLE(V1ComputationGraph); -V1ComputationGraph to_v1(ComputationGraph const &); - -using V1ParallelComputationGraph = - V1LabelledDataflowGraph; -CHECK_IS_JSONABLE(V1ParallelComputationGraph); -V1ParallelComputationGraph to_v1(ParallelComputationGraph const &); - -} // namespace FlexFlow - -#endif diff --git a/lib/pcg/include/pcg/file_format/v1/graphs/v1_dataflow_graph.struct.toml b/lib/pcg/include/pcg/file_format/v1/graphs/v1_dataflow_graph.struct.toml index d9aade739c..c332b6b41d 100644 --- a/lib/pcg/include/pcg/file_format/v1/graphs/v1_dataflow_graph.struct.toml +++ b/lib/pcg/include/pcg/file_format/v1/graphs/v1_dataflow_graph.struct.toml @@ -1,9 +1,9 @@ namespace = "FlexFlow" name = "V1DataflowGraph" features = [ - # "eq", + "eq", # "ord", - # "hash", + "hash", "json", # "rapidcheck", "fmt", @@ -13,8 +13,13 @@ includes = [ "", "", "pcg/file_format/v1/graphs/v1_graph_edge.dtg.h", +] + +src_includes = [ "utils/fmt/vector.h", + "utils/hash/vector.h", "utils/fmt/unordered_set.h", + "utils/hash/unordered_set.h", ] [[fields]] diff --git a/lib/pcg/include/pcg/file_format/v1/graphs/v1_labelled_dataflow_graph.h b/lib/pcg/include/pcg/file_format/v1/graphs/v1_labelled_dataflow_graph.h index 48203d73ae..fc9dfcef9a 100644 --- a/lib/pcg/include/pcg/file_format/v1/graphs/v1_labelled_dataflow_graph.h +++ b/lib/pcg/include/pcg/file_format/v1/graphs/v1_labelled_dataflow_graph.h @@ -13,8 +13,9 @@ namespace FlexFlow { template -V1LabelledDataflowGraph - to_v1(LabelledDataflowGraphView const &g) { +std::pair, bidict> + to_v1_including_node_numbering( + LabelledDataflowGraphView const &g) { bidict nodes = bidict_from_enumerating(get_nodes(g)); @@ -29,8 +30,17 @@ V1LabelledDataflowGraph [&](DataflowOutput const &o) { return g.at(o); }); }); - return V1LabelledDataflowGraph{ - node_labels, output_labels, unlabelled}; + return { + V1LabelledDataflowGraph{ + node_labels, output_labels, unlabelled}, + nodes, + }; +} + +template +V1LabelledDataflowGraph + to_v1(LabelledDataflowGraphView const &g) { + return to_v1_including_node_numbering(g).first; } } // namespace FlexFlow diff --git a/lib/pcg/include/pcg/file_format/v1/graphs/v1_labelled_dataflow_graph.struct.toml b/lib/pcg/include/pcg/file_format/v1/graphs/v1_labelled_dataflow_graph.struct.toml index fd8d4c39c4..b440d0f03d 100644 --- a/lib/pcg/include/pcg/file_format/v1/graphs/v1_labelled_dataflow_graph.struct.toml +++ b/lib/pcg/include/pcg/file_format/v1/graphs/v1_labelled_dataflow_graph.struct.toml @@ -1,9 +1,9 @@ namespace = "FlexFlow" name = "V1LabelledDataflowGraph" features = [ - # "eq", + "eq", # "ord", - # "hash", + "hash", "json", # "rapidcheck", "fmt", @@ -20,6 +20,13 @@ includes = [ "pcg/file_format/v1/graphs/v1_graph_output.dtg.h", ] +src_includes = [ + "utils/fmt/unordered_map.h", + "utils/hash/unordered_map.h", + "utils/fmt/vector.h", + "utils/hash/vector.h", +] + [[fields]] name = "node_labels" type = "std::unordered_map" @@ -31,4 +38,3 @@ type = "std::unordered_map>" [[fields]] name = "graph" type = "::FlexFlow::V1DataflowGraph" - diff --git a/lib/pcg/include/pcg/file_format/v1/v1.h b/lib/pcg/include/pcg/file_format/v1/v1.h deleted file mode 100644 index e2557af4f5..0000000000 --- a/lib/pcg/include/pcg/file_format/v1/v1.h +++ /dev/null @@ -1,9 +0,0 @@ -#ifndef _FLEXFLOW_PCG_INCLUDE_PCG_FILE_FORMAT_V1_V1_H -#define _FLEXFLOW_PCG_INCLUDE_PCG_FILE_FORMAT_V1_V1_H - -#include "graphs.h" -#include "pcg/computation_graph.h" - -namespace FlexFlow {} - -#endif diff --git a/lib/pcg/include/pcg/file_format/v1/v1_computation_graph.h b/lib/pcg/include/pcg/file_format/v1/v1_computation_graph.h new file mode 100644 index 0000000000..5590d6999b --- /dev/null +++ b/lib/pcg/include/pcg/file_format/v1/v1_computation_graph.h @@ -0,0 +1,17 @@ +#ifndef _FLEXFLOW_LIB_PCG_INCLUDE_PCG_FILE_FORMAT_V1_V1_COMPUTATION_GRAPH_H +#define _FLEXFLOW_LIB_PCG_INCLUDE_PCG_FILE_FORMAT_V1_V1_COMPUTATION_GRAPH_H + +#include "pcg/computation_graph.dtg.h" +#include "pcg/file_format/v1/v1_computation_graph.dtg.h" +#include "pcg/layer_guid_t.dtg.h" + +namespace FlexFlow { + +V1ComputationGraph to_v1(ComputationGraph const &); + +std::pair> + to_v1_including_node_numbering(ComputationGraph const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/pcg/include/pcg/file_format/v1/v1_computation_graph.struct.toml b/lib/pcg/include/pcg/file_format/v1/v1_computation_graph.struct.toml new file mode 100644 index 0000000000..0d7135ec74 --- /dev/null +++ b/lib/pcg/include/pcg/file_format/v1/v1_computation_graph.struct.toml @@ -0,0 +1,18 @@ +namespace = "FlexFlow" +name = "V1ComputationGraph" +features = [ + "eq", + "hash", + "fmt", + "json", +] + +includes = [ + "pcg/layer_attrs.dtg.h", + "pcg/tensor_attrs.dtg.h", + "pcg/file_format/v1/graphs/v1_labelled_dataflow_graph.dtg.h", +] + +[[fields]] +name = "raw_graph" +type = "::FlexFlow::V1LabelledDataflowGraph<::FlexFlow::LayerAttrs, ::FlexFlow::TensorAttrs>" diff --git a/lib/pcg/include/pcg/file_format/v1/v1_parallel_computation_graph.h b/lib/pcg/include/pcg/file_format/v1/v1_parallel_computation_graph.h new file mode 100644 index 0000000000..aceb59f5af --- /dev/null +++ b/lib/pcg/include/pcg/file_format/v1/v1_parallel_computation_graph.h @@ -0,0 +1,13 @@ +#ifndef _FLEXFLOW_LIB_PCG_INCLUDE_PCG_FILE_FORMAT_V1_V1_PARALLEL_COMPUTATION_GRAPH_H +#define _FLEXFLOW_LIB_PCG_INCLUDE_PCG_FILE_FORMAT_V1_V1_PARALLEL_COMPUTATION_GRAPH_H + +#include "pcg/file_format/v1/v1_parallel_computation_graph.dtg.h" +#include "pcg/parallel_computation_graph/parallel_computation_graph.dtg.h" + +namespace FlexFlow { + +V1ParallelComputationGraph to_v1(ParallelComputationGraph const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/pcg/include/pcg/file_format/v1/v1_parallel_computation_graph.struct.toml b/lib/pcg/include/pcg/file_format/v1/v1_parallel_computation_graph.struct.toml new file mode 100644 index 0000000000..16be4a9561 --- /dev/null +++ b/lib/pcg/include/pcg/file_format/v1/v1_parallel_computation_graph.struct.toml @@ -0,0 +1,18 @@ +namespace = "FlexFlow" +name = "V1ParallelComputationGraph" +features = [ + "eq", + "hash", + "fmt", + "json", +] + +includes = [ + "pcg/parallel_computation_graph/parallel_layer_attrs.dtg.h", + "pcg/parallel_computation_graph/parallel_tensor_attrs.dtg.h", + "pcg/file_format/v1/graphs/v1_labelled_dataflow_graph.dtg.h", +] + +[[fields]] +name = "raw_graph" +type = "::FlexFlow::V1LabelledDataflowGraph<::FlexFlow::ParallelLayerAttrs, ::FlexFlow::ParallelTensorAttrs>" diff --git a/lib/pcg/include/pcg/initializers/constant_initializer_attrs.struct.toml b/lib/pcg/include/pcg/initializers/constant_initializer_attrs.struct.toml index 12917d0989..4e3c31bd36 100644 --- a/lib/pcg/include/pcg/initializers/constant_initializer_attrs.struct.toml +++ b/lib/pcg/include/pcg/initializers/constant_initializer_attrs.struct.toml @@ -10,12 +10,7 @@ features = [ ] includes = [ - "op-attrs/datatype.h", - "utils/json.h", -] - -src_includes = [ - "utils/fmt/variant.h", + "op-attrs/datatype_value.dtg.h", ] [[fields]] diff --git a/lib/pcg/include/pcg/layer_attrs.struct.toml b/lib/pcg/include/pcg/layer_attrs.struct.toml index d062f6cd78..8290795174 100644 --- a/lib/pcg/include/pcg/layer_attrs.struct.toml +++ b/lib/pcg/include/pcg/layer_attrs.struct.toml @@ -13,11 +13,11 @@ includes = [ "op-attrs/computation_graph_op_attrs.dtg.h", "utils/stack_string.h", "", - "utils/json.h" ] src_includes = [ "utils/fmt/optional.h", + "utils/json/optional.h", ] [[fields]] diff --git a/lib/pcg/include/pcg/parallel_computation_graph/parallel_layer_attrs.struct.toml b/lib/pcg/include/pcg/parallel_computation_graph/parallel_layer_attrs.struct.toml index 60cfc426cc..4d61f24d37 100644 --- a/lib/pcg/include/pcg/parallel_computation_graph/parallel_layer_attrs.struct.toml +++ b/lib/pcg/include/pcg/parallel_computation_graph/parallel_layer_attrs.struct.toml @@ -17,6 +17,8 @@ includes = [ src_includes = [ "utils/fmt/optional.h", + "utils/json/optional.h", + "utils/rapidcheck/optional.h", ] [[fields]] diff --git a/lib/pcg/include/pcg/parallel_computation_graph/parallel_tensor_attrs.struct.toml b/lib/pcg/include/pcg/parallel_computation_graph/parallel_tensor_attrs.struct.toml index d9e6cf113b..323932fec6 100644 --- a/lib/pcg/include/pcg/parallel_computation_graph/parallel_tensor_attrs.struct.toml +++ b/lib/pcg/include/pcg/parallel_computation_graph/parallel_tensor_attrs.struct.toml @@ -19,6 +19,8 @@ includes = [ src_includes = [ "utils/fmt/optional.h", + "utils/json/optional.h", + "utils/rapidcheck/optional.h", ] [[fields]] diff --git a/lib/pcg/include/pcg/tensor_attrs.struct.toml b/lib/pcg/include/pcg/tensor_attrs.struct.toml index c0b89cfc99..7f16e60914 100644 --- a/lib/pcg/include/pcg/tensor_attrs.struct.toml +++ b/lib/pcg/include/pcg/tensor_attrs.struct.toml @@ -19,6 +19,7 @@ includes = [ src_includes = [ "utils/fmt/optional.h", + "utils/json/optional.h", ] [[fields]] diff --git a/lib/pcg/src/file_format.cc b/lib/pcg/src/file_format.cc deleted file mode 100644 index bb01ac2dbf..0000000000 --- a/lib/pcg/src/file_format.cc +++ /dev/null @@ -1,14 +0,0 @@ -#include "pcg/file_format/v1/v1.h" - -namespace FlexFlow { - -/* void thing() { */ -/* static_assert(is_visitable::value, ""); */ - -/* json j; */ -/* auto g = j.get(); */ - -/* /1* IllBehaved v = j.get(); *1/ */ -/* } */ - -} // namespace FlexFlow diff --git a/lib/pcg/src/file_format/v1/graphs.cc b/lib/pcg/src/file_format/v1/graphs.cc deleted file mode 100644 index de8d5dddb4..0000000000 --- a/lib/pcg/src/file_format/v1/graphs.cc +++ /dev/null @@ -1,16 +0,0 @@ -#include "pcg/file_format/v1/graphs.h" -#include "pcg/file_format/v1/graphs/v1_labelled_dataflow_graph.h" -#include "utils/graph/algorithms.h" -#include "utils/integer_conversions.h" - -namespace FlexFlow { - -V1ComputationGraph to_v1(ComputationGraph const &g) { - return to_v1(g.raw_graph); -} - -V1ParallelComputationGraph to_v1(ParallelComputationGraph const &g) { - return to_v1(g.raw_graph); -} - -} // namespace FlexFlow diff --git a/lib/pcg/src/pcg/computation_graph.cc b/lib/pcg/src/pcg/computation_graph.cc index deaa440ef8..cf4b1496cf 100644 --- a/lib/pcg/src/pcg/computation_graph.cc +++ b/lib/pcg/src/pcg/computation_graph.cc @@ -1,11 +1,18 @@ #include "pcg/computation_graph.h" +#include "op-attrs/computation_graph_op_attrs.h" #include "utils/containers/get_only.h" #include "utils/containers/reversed.h" #include "utils/containers/transform.h" #include "utils/graph/dataflow_graph/algorithms.h" +#include "utils/graph/dataflow_graph/algorithms/get_subgraph_incoming_edges.h" +#include "utils/graph/dataflow_graph/algorithms/get_subgraph_outgoing_edges.h" +#include "utils/graph/digraph/algorithms/get_subgraph_successors.h" #include "utils/graph/digraph/algorithms/get_topological_ordering.h" #include "utils/graph/instances/unordered_set_labelled_open_dataflow_graph.h" +#include "utils/graph/labelled_dataflow_graph/algorithms/view_as_labelled_open_dataflow_graph.h" +#include "utils/graph/labelled_open_dataflow_graph/algorithms/as_dot.h" #include "utils/graph/node/algorithms.h" +#include "utils/record_formatter.h" namespace FlexFlow { @@ -20,6 +27,23 @@ std::unordered_set get_layers(ComputationGraph const &cg) { [&](Node const &n) { return layer_guid_t{n}; }); } +LayerAddedResult add_layer(ComputationGraph &computation_graph, + LayerAttrs const &attrs, + std::vector const &inputs, + std::vector const &outputs) { + std::vector raw_inputs = transform( + inputs, [](tensor_guid_t const &t) { return t.raw_graph_output; }); + + NodeAddedResult added = + computation_graph.raw_graph.add_node(attrs, raw_inputs, outputs); + + return LayerAddedResult{ + layer_guid_t{added.node}, + transform(added.outputs, + [](DataflowOutput const &o) { return tensor_guid_t{o}; }), + }; +} + TensorAttrs get_tensor_attrs(ComputationGraph const &cg, tensor_guid_t const &t) { return cg.raw_graph.at(t.raw_graph_output); @@ -39,8 +63,7 @@ std::vector topological_ordering(ComputationGraph const &cg) { std::vector reverse_topological_ordering(ComputationGraph const &cg) { - std::vector layers = - reversed>(get_topological_ordering(cg.raw_graph)); + std::vector layers = reversed(get_topological_ordering(cg.raw_graph)); return transform( layers, [&](Node const &e) -> layer_guid_t { return layer_guid_t{e}; }); } @@ -57,6 +80,47 @@ std::vector get_incoming_tensors(ComputationGraph const &cg, [](DataflowOutput const &o) { return tensor_guid_t{o}; }); } +std::unordered_set get_subgraph_incoming_edges( + ComputationGraph const &cg, + std::unordered_set const &subgraph_nodes) { + + std::unordered_set raw_subgraph_nodes = transform( + subgraph_nodes, [](layer_guid_t const &l) { return l.raw_node; }); + std::unordered_set raw_incoming_edges = + get_subgraph_incoming_edges(cg.raw_graph, raw_subgraph_nodes); + + return transform(raw_incoming_edges, [](DataflowEdge const &e) { + return ComputationGraphEdge{e}; + }); +} + +std::unordered_set get_subgraph_outgoing_edges( + ComputationGraph const &cg, + std::unordered_set const &subgraph_nodes) { + + std::unordered_set raw_subgraph_nodes = transform( + subgraph_nodes, [](layer_guid_t const &l) { return l.raw_node; }); + std::unordered_set raw_outgoing_edges = + get_subgraph_outgoing_edges(cg.raw_graph, raw_subgraph_nodes); + + return transform(raw_outgoing_edges, [](DataflowEdge const &e) { + return ComputationGraphEdge{e}; + }); +} + +std::unordered_set get_subgraph_successors( + ComputationGraph const &cg, + std::unordered_set const &subgraph_nodes) { + + std::unordered_set raw_subgraph_nodes = transform( + subgraph_nodes, [](layer_guid_t const &l) { return l.raw_node; }); + std::unordered_set raw_successors = + get_subgraph_successors(cg.raw_graph, raw_subgraph_nodes); + + return transform(raw_successors, + [](Node const &n) { return layer_guid_t{n}; }); +} + LayerAttrs get_layer_attrs(ComputationGraph const &cg, layer_guid_t const &n) { return cg.raw_graph.at(n.raw_node); } @@ -70,4 +134,40 @@ layer_guid_t get_layer_by_name(ComputationGraph const &cg, return get_only(found); } +std::string as_dot(ComputationGraph const &cg) { + std::function get_node_label = + [](LayerAttrs const &a) -> std::string { + RecordFormatter r = as_dot(a.attrs); + + if (a.name.has_value()) { + RecordFormatter rr; + rr << "Name" << a.name.value(); + r << rr; + } + + std::ostringstream oss; + oss << r; + return oss.str(); + }; + + std::function get_input_label = + [](TensorAttrs const &a) -> std::string { + RecordFormatter r; + + r << fmt::to_string(a.shape); + + std::ostringstream oss; + oss << r; + return oss.str(); + }; + + return as_dot(view_as_labelled_open_dataflow_graph(cg.raw_graph), + get_node_label, + get_input_label); +} + +void debug_print_dot(ComputationGraph const &cg) { + std::cout << as_dot(cg) << std::endl; +} + } // namespace FlexFlow diff --git a/lib/pcg/src/pcg/computation_graph/computation_graph_edge.cc b/lib/pcg/src/pcg/computation_graph/computation_graph_edge.cc new file mode 100644 index 0000000000..0efa0620c4 --- /dev/null +++ b/lib/pcg/src/pcg/computation_graph/computation_graph_edge.cc @@ -0,0 +1,15 @@ +#include "pcg/computation_graph/computation_graph_edge.h" + +namespace FlexFlow { + +layer_guid_t + get_computation_graph_edge_src_layer(ComputationGraphEdge const &e) { + return layer_guid_t{e.raw_edge.src.node}; +} + +layer_guid_t + get_computation_graph_edge_dst_layer(ComputationGraphEdge const &e) { + return layer_guid_t{e.raw_edge.dst.node}; +} + +} // namespace FlexFlow diff --git a/lib/pcg/src/pcg/computation_graph_builder.cc b/lib/pcg/src/pcg/computation_graph_builder.cc index 3f2feaf619..e0b6935a6d 100644 --- a/lib/pcg/src/pcg/computation_graph_builder.cc +++ b/lib/pcg/src/pcg/computation_graph_builder.cc @@ -15,6 +15,7 @@ #include "op-attrs/ops/linear.h" #include "op-attrs/ops/softmax.h" #include "op-attrs/ops/weight_attrs.dtg.h" +#include "op-attrs/tensor_dims.h" #include "pcg/computation_graph.h" #include "utils/containers/any_of.h" #include "utils/containers/concat_vectors.h" @@ -26,6 +27,16 @@ namespace FlexFlow { +static TensorAttrs make_weight_attrs( + TensorShape const &shape, + std::optional const &initializer_attrs) { + return TensorAttrs{shape, initializer_attrs, std::nullopt, CreateGrad::YES}; +} + +static TensorAttrs make_output_attrs(TensorShape const &shape) { + return TensorAttrs{shape, std::nullopt, std::nullopt, CreateGrad::YES}; +} + ComputationGraphBuilder::ComputationGraphBuilder() : computation_graph(make_empty_computation_graph()) {} @@ -33,13 +44,31 @@ TensorShape ComputationGraphBuilder::get_shape(tensor_guid_t const &t) const { return get_tensor_attrs(this->computation_graph, t).shape; } -tensor_guid_t ComputationGraphBuilder::create_tensor(TensorShape const &shape, - CreateGrad create_grad) { +tensor_guid_t ComputationGraphBuilder::create_input( + TensorShape const &shape, + CreateGrad create_grad, + std::optional const &maybe_name) { TensorAttrs tensor_attrs = TensorAttrs{shape, std::nullopt, std::nullopt, create_grad}; LayerAttrs layer_attrs = LayerAttrs{ ComputationGraphOpAttrs{InputAttrs{}}, - std::nullopt, + maybe_name, + }; + + return this->add_layer(layer_attrs, {}, {}, tensor_attrs); +} + +tensor_guid_t ComputationGraphBuilder::create_weight( + TensorShape const &shape, + CreateGrad create_grad, + std::optional const &initializer, + std::optional param_sync, + std::optional const &maybe_name) { + TensorAttrs tensor_attrs = + TensorAttrs{shape, initializer, param_sync, create_grad}; + LayerAttrs layer_attrs = LayerAttrs{ + ComputationGraphOpAttrs{InputAttrs{}}, + maybe_name, }; return this->add_layer(layer_attrs, {}, {}, tensor_attrs); @@ -98,9 +127,31 @@ std::vector ComputationGraphBuilder::add_layer( std::vector const &weights, std::vector const &outputs) { return this->add_layer( - layer, inputs, weights, transform(outputs, [](TensorShape const &s) { - return TensorAttrs{s, std::nullopt, std::nullopt, CreateGrad::YES}; - })); + layer, inputs, weights, transform(outputs, make_output_attrs)); +} + +tensor_guid_t ComputationGraphBuilder::add_layer( + LayerAttrs const &layer, + std::vector const &inputs, + std::vector const &weights, + TensorShape const &output_shape) { + + TensorAttrs output_attrs = make_output_attrs(output_shape); + LayerAddedResult added = + ::FlexFlow::add_layer(this->computation_graph, + layer, + concat_vectors(inputs, weights), + {output_attrs}); + return get_only(added.outputs); +} + +tensor_guid_t + ComputationGraphBuilder::add_layer(LayerAttrs const &layer, + std::vector const &inputs, + TensorShape const &output_shape) { + + std::vector weights = {}; + return this->add_layer(layer, inputs, weights, output_shape); } tensor_guid_t @@ -129,25 +180,28 @@ tensor_guid_t ComputationGraphBuilder::as_type(tensor_guid_t const &x, } } -tensor_guid_t - ComputationGraphBuilder::broadcast(tensor_guid_t const &input, - TensorShape const &target_shape, - std::string const &name) { +tensor_guid_t ComputationGraphBuilder::broadcast(tensor_guid_t const &input, + TensorDims const &target_dims, + std::string const &name) { TensorShape input_shape = this->get_shape(input); - if (!tensor_shape_is_broadcastable_to(input_shape, target_shape)) { + if (input_shape.dims == target_dims) { + return input; + } + + if (!tensor_dims_is_broadcastable_to(input_shape.dims, target_dims)) { throw mk_runtime_error(fmt::format( - "Cannot broadcast input tensor of shape {} to target shape {}", - input_shape, - target_shape)); + "Cannot broadcast input tensor of dims {} to target dims {}", + input_shape.dims, + target_dims)); } - BroadcastAttrs attrs = BroadcastAttrs{target_shape.dims}; + BroadcastAttrs attrs = BroadcastAttrs{target_dims}; LayerAttrs layer = LayerAttrs{ComputationGraphOpAttrs{attrs}, name}; TensorShape output_shape = throw_if_unexpected(get_output_shape(attrs, input_shape)); - return this->add_layer(layer, {input}, {}, output_shape); + return this->add_layer(layer, {input}, output_shape); } tensor_guid_t @@ -184,7 +238,7 @@ tensor_guid_t ComputationGraphBuilder::element_unary( TensorShape output_shape = throw_if_unexpected(get_output_shape(attrs, this->get_shape(input))); - return this->add_layer(layer, {input}, {}, output_shape); + return this->add_layer(layer, {input}, output_shape); } tensor_guid_t ComputationGraphBuilder::element_binary( @@ -194,18 +248,18 @@ tensor_guid_t ComputationGraphBuilder::element_binary( std::optional const &maybe_name) { std::string name = maybe_name.value_or(get_default_name(op_type)); - TensorShape compute_shape = this->get_broadcast_target_shape({lhs, rhs}); + TensorDims compute_dims = this->get_broadcast_target_dims({lhs, rhs}); DataType compute_type = std::max(this->get_shape(lhs).data_type, this->get_shape(rhs).data_type); tensor_guid_t lhs_input = this->as_type( this->broadcast( - lhs, compute_shape, fmt::format("{}_inputl_broadcast", name)), + lhs, compute_dims, fmt::format("{}_inputl_broadcast", name)), compute_type, name + "_inputl_cast"); tensor_guid_t rhs_input = this->as_type( this->broadcast( - rhs, compute_shape, fmt::format("{}_inputr_broadcast", name)), + rhs, compute_dims, fmt::format("{}_inputr_broadcast", name)), compute_type, name + "_inputr_cast"); @@ -217,7 +271,7 @@ tensor_guid_t ComputationGraphBuilder::element_binary( TensorShape output_shape = throw_if_unexpected(get_output_shape( attrs, this->get_shape(lhs_input), this->get_shape(rhs_input))); - return this->add_layer(layer, {lhs_input, rhs_input}, {}, output_shape); + return this->add_layer(layer, {lhs_input, rhs_input}, output_shape); } tensor_guid_t @@ -359,12 +413,6 @@ tensor_guid_t return this->element_unary(OperatorType::ELU, input, std::nullopt, name); } -static TensorAttrs make_weight_attrs( - TensorShape const &shape, - std::optional const &initializer_attrs) { - return TensorAttrs{shape, initializer_attrs, std::nullopt, CreateGrad::YES}; -} - tensor_guid_t ComputationGraphBuilder::conv2d( tensor_guid_t const &x, int outChannels, @@ -431,7 +479,7 @@ tensor_guid_t ComputationGraphBuilder::dropout( TensorShape output_shape = get_output_shape(attrs, this->get_shape(input)); - return this->add_layer(layer, {input}, {}, output_shape); + return this->add_layer(layer, {input}, output_shape); } tensor_guid_t ComputationGraphBuilder::embedding( @@ -483,7 +531,7 @@ tensor_guid_t ComputationGraphBuilder::gather( TensorShape output_shape = get_output_shape(attrs, this->get_shape(input), this->get_shape(index)); - return this->add_layer(layer, {input}, {}, output_shape); + return this->add_layer(layer, {input}, output_shape); } /* std::vector @@ -531,7 +579,7 @@ tensor_guid_t ComputationGraphBuilder::batch_norm( TensorShape output_shape = get_output_shape(attrs, this->get_shape(input)); - return this->add_layer(layer, {input}, {}, output_shape); + return this->add_layer(layer, {input}, output_shape); } tensor_guid_t ComputationGraphBuilder::multihead_attention( @@ -581,26 +629,26 @@ tensor_guid_t ComputationGraphBuilder::multihead_attention( output_shape); } -TensorShape ComputationGraphBuilder::get_broadcast_target_shape( +TensorDims ComputationGraphBuilder::get_broadcast_target_dims( std::vector const &inputs) { - std::vector input_shapes = transform( - inputs, [&](tensor_guid_t const &t) { return this->get_shape(t); }); + std::vector inputs_dims = transform( + inputs, [&](tensor_guid_t const &t) { return this->get_shape(t).dims; }); - return this->get_broadcast_target_shape(input_shapes); + return this->get_broadcast_target_dims(inputs_dims); } -TensorShape ComputationGraphBuilder::get_broadcast_target_shape( - std::vector const &input_shapes) { - std::optional maybe_result = - ::FlexFlow::get_broadcast_target_shape(unordered_set_of(input_shapes)); +TensorDims ComputationGraphBuilder::get_broadcast_target_dims( + std::vector const &inputs_dims) { + std::optional maybe_result = + ::FlexFlow::get_broadcast_target_dims(unordered_set_of(inputs_dims)); if (maybe_result.has_value()) { return maybe_result.value(); } else { throw mk_runtime_error(fmt::format( - "ComputationGraphBuilder::get_broadcast_target_shape failed to find " - "target tensor shape for input tensor shapes {}", - input_shapes)); + "ComputationGraphBuilder::get_broadcast_target_dims failed to find " + "target tensor dims for input tensor dims {}", + inputs_dims)); } } @@ -610,9 +658,11 @@ tensor_guid_t ComputationGraphBuilder::dense( std::optional activation, bool use_bias, DataType data_type, - std::optional const &kernel_initializer, + std::optional const &projection_initializer, std::optional const &bias_initializer, - std::optional const &maybe_name) { + std::optional const &maybe_name, + std::optional const &projection_name, + std::optional const &bias_name) { LinearAttrs attrs = LinearAttrs{outDim, use_bias, data_type, activation, std::nullopt}; @@ -623,15 +673,30 @@ tensor_guid_t ComputationGraphBuilder::dense( TensorShape output_shape = throw_if_unexpected(get_output_shape(attrs, this->get_shape(input))); - std::vector weights; - TensorShape kernel_shape = + std::vector weights; + + TensorShape projection_shape = throw_if_unexpected(get_kernel_shape(attrs, this->get_shape(input))); - weights.push_back(make_weight_attrs(kernel_shape, kernel_initializer)); + + tensor_guid_t projection_weights = + this->create_weight(projection_shape, + CreateGrad::YES, + projection_initializer, + /*sync_type=*/std::nullopt, + projection_name); + + weights.push_back(projection_weights); if (use_bias) { TensorShape bias_shape = throw_if_unexpected(get_bias_shape(attrs, this->get_shape(input))); - weights.push_back(make_weight_attrs(bias_shape, bias_initializer)); + + tensor_guid_t bias_weights = this->create_weight(bias_shape, + CreateGrad::YES, + bias_initializer, + /*sync_type=*/std::nullopt, + bias_name); + weights.push_back(bias_weights); } return this->add_layer(layer, {input}, weights, output_shape); @@ -677,13 +742,13 @@ tensor_guid_t ComputationGraphBuilder::layer_norm( TensorShape gamma_shape = throw_if_unexpected(get_gamma_weights_shape(attrs, input_shape)); InitializerAttrs gamma_initializer = - InitializerAttrs{ConstantInitializerAttrs{float{1}}}; + InitializerAttrs{ConstantInitializerAttrs{DataTypeValue{float{1}}}}; weights.push_back(make_weight_attrs(gamma_shape, gamma_initializer)); TensorShape beta_shape = throw_if_unexpected(get_beta_weights_shape(attrs, input_shape)); InitializerAttrs beta_initializer = - InitializerAttrs{ConstantInitializerAttrs{float{0}}}; + InitializerAttrs{ConstantInitializerAttrs{DataTypeValue{float{0}}}}; weights.push_back(make_weight_attrs(beta_shape, beta_initializer)); } @@ -716,7 +781,7 @@ tensor_guid_t ComputationGraphBuilder::softmax( TensorShape output_shape = throw_if_unexpected(get_output_shape(attrs, input_shape)); - return this->add_layer(layer, {input}, {}, output_shape); + return this->add_layer(layer, {input}, output_shape); } } // namespace FlexFlow diff --git a/lib/pcg/src/pcg/file_format/v1/v1_computation_graph.cc b/lib/pcg/src/pcg/file_format/v1/v1_computation_graph.cc new file mode 100644 index 0000000000..975e92dfb7 --- /dev/null +++ b/lib/pcg/src/pcg/file_format/v1/v1_computation_graph.cc @@ -0,0 +1,24 @@ +#include "pcg/file_format/v1/v1_computation_graph.h" +#include "pcg/file_format/v1/graphs/v1_labelled_dataflow_graph.h" + +namespace FlexFlow { + +V1ComputationGraph to_v1(ComputationGraph const &g) { + return V1ComputationGraph{ + to_v1(g.raw_graph), + }; +} + +std::pair> + to_v1_including_node_numbering(ComputationGraph const &cg) { + std::pair, bidict> + raw = + to_v1_including_node_numbering(cg.raw_graph); + V1ComputationGraph v1_cg = V1ComputationGraph{raw.first}; + bidict v1_node_ids = + map_values(raw.second, [](Node const &n) { return layer_guid_t{n}; }); + + return {v1_cg, v1_node_ids}; +} + +} // namespace FlexFlow diff --git a/lib/pcg/src/pcg/file_format/v1/v1_parallel_computation_graph.cc b/lib/pcg/src/pcg/file_format/v1/v1_parallel_computation_graph.cc new file mode 100644 index 0000000000..9da58fcf6e --- /dev/null +++ b/lib/pcg/src/pcg/file_format/v1/v1_parallel_computation_graph.cc @@ -0,0 +1,12 @@ +#include "pcg/file_format/v1/v1_parallel_computation_graph.h" +#include "pcg/file_format/v1/graphs/v1_labelled_dataflow_graph.h" + +namespace FlexFlow { + +V1ParallelComputationGraph to_v1(ParallelComputationGraph const &g) { + return V1ParallelComputationGraph{ + to_v1(g.raw_graph), + }; +} + +} // namespace FlexFlow diff --git a/lib/pcg/test/src/pcg/file_format/v1/v1_computation_graph.cc b/lib/pcg/test/src/pcg/file_format/v1/v1_computation_graph.cc new file mode 100644 index 0000000000..8336d81bb4 --- /dev/null +++ b/lib/pcg/test/src/pcg/file_format/v1/v1_computation_graph.cc @@ -0,0 +1,30 @@ +#include "pcg/file_format/v1/v1_computation_graph.h" +#include "pcg/computation_graph_builder.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("V1ComputationGraph") { + ComputationGraph cg = [] { + ComputationGraphBuilder b; + + TensorShape input_shape = TensorShape{ + TensorDims{FFOrdered{ + 12, + 16, + }}, + DataType::FLOAT, + }; + + tensor_guid_t input = b.create_input(input_shape, CreateGrad::YES); + tensor_guid_t mm_output = b.dense(input, 8); + tensor_guid_t relu_output = b.relu(mm_output); + + return b.computation_graph; + }(); + + V1ComputationGraph v1_cg = to_v1(cg); + nlohmann::json j = v1_cg; + } +} diff --git a/lib/pcg/test/src/pcg/file_format/v1/v1_parallel_computation_graph.cc b/lib/pcg/test/src/pcg/file_format/v1/v1_parallel_computation_graph.cc new file mode 100644 index 0000000000..8ce25c4bc5 --- /dev/null +++ b/lib/pcg/test/src/pcg/file_format/v1/v1_parallel_computation_graph.cc @@ -0,0 +1,36 @@ +#include "pcg/file_format/v1/v1_parallel_computation_graph.h" +#include "pcg/parallel_computation_graph/parallel_computation_graph_builder.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("V1ParallelComputationGraph") { + ParallelComputationGraph pcg = [] { + ParallelComputationGraphBuilder b; + + ParallelTensorShape input_shape = ParallelTensorShape{ + ParallelTensorDims{ + FFOrdered{ + ShardParallelDim{12, 2}, + ShardParallelDim{16, 1}, + }, + ReplicaParallelDimSet{ + SumDegree{1}, + DiscardCopyDegree{1}, + }, + }, + DataType::FLOAT, + }; + + parallel_tensor_guid_t input = b.create_input_tensor(input_shape); + parallel_tensor_guid_t mm_output = b.dense(input, 8); + parallel_tensor_guid_t relu_output = b.relu(mm_output); + + return b.pcg; + }(); + + V1ParallelComputationGraph v1_pcg = to_v1(pcg); + nlohmann::json j = v1_pcg; + } +} diff --git a/lib/pcg/test/src/pcg/initializers/uniform_initializer_attrs.cc b/lib/pcg/test/src/pcg/initializers/uniform_initializer_attrs.cc index 0b75e3ae1a..703c129da4 100644 --- a/lib/pcg/test/src/pcg/initializers/uniform_initializer_attrs.cc +++ b/lib/pcg/test/src/pcg/initializers/uniform_initializer_attrs.cc @@ -1,6 +1,8 @@ #include "pcg/initializers/uniform_initializer_attrs.h" -#include "test/utils/doctest.h" #include "test/utils/rapidcheck.h" +#include + +using namespace ::FlexFlow; TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("Arbitrary") { diff --git a/lib/pcg/test/src/pcg/parallel_computation_graph/parallel_computation_graph_builder.cc b/lib/pcg/test/src/pcg/parallel_computation_graph/parallel_computation_graph_builder.cc index 440f735e80..f46f267859 100644 --- a/lib/pcg/test/src/pcg/parallel_computation_graph/parallel_computation_graph_builder.cc +++ b/lib/pcg/test/src/pcg/parallel_computation_graph/parallel_computation_graph_builder.cc @@ -3,7 +3,6 @@ #include "pcg/parallel_computation_graph/parallel_computation_graph.h" #include "pcg/parallel_computation_graph/parallel_layer_attrs.h" #include "pcg/parallel_computation_graph/parallel_tensor_guid_t.h" -#include "test/utils/doctest.h" #include "utils/containers/count.h" #include "utils/containers/generate_map.h" #include "utils/containers/get_only.h" @@ -12,6 +11,9 @@ #include "utils/containers/values.h" #include "utils/containers/without_nullopts.h" #include "utils/hash/pair.h" +#include + +using namespace ::FlexFlow; TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("ParallelComputationGraphBuilder::add") { @@ -227,7 +229,7 @@ TEST_SUITE(FF_TEST_SUITE) { CHECK(num_replicate_attrs == 2); parallel_layer_guid_t conv_guid = get_only(without_nullopts(transform( - as_vector(items(layers)), + vector_of(items(layers)), [](std::pair const &kv) -> std::optional { if (get_op_type(kv.second) == OperatorType::CONV2D) { diff --git a/lib/pcg/test/src/test_computation_graph_builder.cc b/lib/pcg/test/src/test_computation_graph_builder.cc index 936c2de00d..ff169d8312 100644 --- a/lib/pcg/test/src/test_computation_graph_builder.cc +++ b/lib/pcg/test/src/test_computation_graph_builder.cc @@ -15,7 +15,7 @@ TEST_SUITE(FF_TEST_SUITE) { DataType::FLOAT, }; - tensor_guid_t input = b.create_tensor(input_shape, CreateGrad::YES); + tensor_guid_t input = b.create_input(input_shape, CreateGrad::YES); tensor_guid_t output = b.conv2d(input, /*outChannels=*/5, /*kernelH=*/3, diff --git a/lib/pcg/test/src/test_machine_view.cc b/lib/pcg/test/src/test_machine_view.cc index 70fe958d8c..25c6e21b87 100644 --- a/lib/pcg/test/src/test_machine_view.cc +++ b/lib/pcg/test/src/test_machine_view.cc @@ -1,7 +1,7 @@ #include "pcg/machine_view.h" #include "pcg/strided_rectangle.h" #include "pcg/strided_rectangle_side.h" -#include "test/utils/doctest.h" +#include using namespace ::FlexFlow; diff --git a/lib/pcg/test/src/test_strided_rectangle.cc b/lib/pcg/test/src/test_strided_rectangle.cc index 2fe3005b15..ac6af9fa19 100644 --- a/lib/pcg/test/src/test_strided_rectangle.cc +++ b/lib/pcg/test/src/test_strided_rectangle.cc @@ -1,6 +1,6 @@ #include "pcg/strided_rectangle.h" #include "pcg/strided_rectangle_side.h" -#include "test/utils/doctest.h" +#include using namespace ::FlexFlow; diff --git a/lib/runtime/src/accessor.cc b/lib/runtime/src/accessor.cc index 44ad8ab40d..84573fb4aa 100644 --- a/lib/runtime/src/accessor.cc +++ b/lib/runtime/src/accessor.cc @@ -129,7 +129,7 @@ struct GetTensorPointerWOFunctor { FieldID fid, Context ctx, Runtime *runtime) const { - return (void *)helperGetTensorPointerWO>( + return (void *)helperGetTensorPointerWO>( region, req, fid, ctx, runtime); } }; @@ -141,7 +141,7 @@ struct GetTensorPointerROFunctor { FieldID fid, Context ctx, Runtime *runtime) const { - return (void const *)helperGetTensorPointerRO>( + return (void const *)helperGetTensorPointerRO>( region, req, fid, ctx, runtime); } }; @@ -153,7 +153,7 @@ struct GetTensorPointerRWFUnctor { FieldID fid, Context ctx, Runtime *runtime) const { - return (void *)helperGetTensorPointerRW>( + return (void *)helperGetTensorPointerRW>( region, req, fid, ctx, runtime); } }; diff --git a/lib/substitutions/include/substitutions/operator_pattern/operator_attribute_value.variant.toml b/lib/substitutions/include/substitutions/operator_pattern/operator_attribute_value.variant.toml index 7df65ef361..ad36f1bc4b 100644 --- a/lib/substitutions/include/substitutions/operator_pattern/operator_attribute_value.variant.toml +++ b/lib/substitutions/include/substitutions/operator_pattern/operator_attribute_value.variant.toml @@ -25,6 +25,7 @@ includes = [ src_includes = [ "utils/fmt/optional.h", + "utils/json/optional.h", "utils/fmt/vector.h", "utils/hash/vector.h", ] diff --git a/lib/substitutions/include/substitutions/sub_parallel_computation_graph.h b/lib/substitutions/include/substitutions/sub_parallel_computation_graph.h index 00032045c0..2d76352ccf 100644 --- a/lib/substitutions/include/substitutions/sub_parallel_computation_graph.h +++ b/lib/substitutions/include/substitutions/sub_parallel_computation_graph.h @@ -45,10 +45,6 @@ std::unordered_set get_subgraph_outgoing_edges( SubParallelComputationGraph const &, std::unordered_set const &); -std::unordered_set get_subgraph_incoming_edges( - SubParallelComputationGraph const &, - std::unordered_set const &); - std::unordered_set get_parallel_tensor_uses(SubParallelComputationGraph const &, open_parallel_tensor_guid_t const &); diff --git a/lib/substitutions/src/substitutions/operator_pattern/get_attribute.cc b/lib/substitutions/src/substitutions/operator_pattern/get_attribute.cc index 26f8ff5062..a18737085a 100644 --- a/lib/substitutions/src/substitutions/operator_pattern/get_attribute.cc +++ b/lib/substitutions/src/substitutions/operator_pattern/get_attribute.cc @@ -1,6 +1,6 @@ #include "substitutions/operator_pattern/get_attribute.h" #include "op-attrs/get_op_type.h" -#include "utils/containers/as_vector.h" +#include "utils/containers/vector_of.h" namespace FlexFlow { @@ -364,7 +364,7 @@ std::optional get_attribute(TransposeAttrs const &p, case OperatorAttributeKey::OP_TYPE: return get_op_type(p); case OperatorAttributeKey::PERMUTATION: - return as_vector(p.perm); + return vector_of(p.perm); default: return std::nullopt; } diff --git a/lib/substitutions/src/substitutions/tensor_pattern/get_attribute.cc b/lib/substitutions/src/substitutions/tensor_pattern/get_attribute.cc index 05f21247c7..286bc69b84 100644 --- a/lib/substitutions/src/substitutions/tensor_pattern/get_attribute.cc +++ b/lib/substitutions/src/substitutions/tensor_pattern/get_attribute.cc @@ -1,7 +1,7 @@ #include "substitutions/tensor_pattern/get_attribute.h" #include "op-attrs/parallel_tensor_dims.h" -#include "utils/containers/as_vector.h" #include "utils/containers/transform.h" +#include "utils/containers/vector_of.h" #include "utils/integer_conversions.h" namespace FlexFlow { @@ -11,13 +11,13 @@ TensorAttributeValue get_attribute(ParallelTensorAttrs const &attrs, switch (key) { case TensorAttributeKey::DIM_SIZES: { std::vector sizes = - transform(as_vector(ff_ordered_shard_dims(attrs.shape.dims)), + transform(vector_of(ff_ordered_shard_dims(attrs.shape.dims)), [](ShardParallelDim const &d) { return d.size; }); return TensorAttributeValue{sizes}; } case TensorAttributeKey::DIM_DEGREES: { std::vector degrees = transform( - as_vector(ff_ordered_shard_dims(attrs.shape.dims)), + vector_of(ff_ordered_shard_dims(attrs.shape.dims)), [](ShardParallelDim const &d) { return size_t_from_int(d.degree); }); return TensorAttributeValue{degrees}; } diff --git a/lib/substitutions/test/src/substitutions/operator_pattern/get_attribute.cc b/lib/substitutions/test/src/substitutions/operator_pattern/get_attribute.cc index 70e960bc73..95b61e0ef4 100644 --- a/lib/substitutions/test/src/substitutions/operator_pattern/get_attribute.cc +++ b/lib/substitutions/test/src/substitutions/operator_pattern/get_attribute.cc @@ -1,4 +1,5 @@ #include "substitutions/operator_pattern/get_attribute.h" +#include "test/utils/doctest/fmt/optional.h" #include using namespace ::FlexFlow; diff --git a/lib/substitutions/test/src/substitutions/pcg_pattern.cc b/lib/substitutions/test/src/substitutions/pcg_pattern.cc index 6922798a97..4f56a76d0d 100644 --- a/lib/substitutions/test/src/substitutions/pcg_pattern.cc +++ b/lib/substitutions/test/src/substitutions/pcg_pattern.cc @@ -5,9 +5,9 @@ #include "substitutions/operator_pattern/operator_attribute_constraint.h" #include "substitutions/sub_parallel_computation_graph.h" #include "substitutions/tensor_pattern/tensor_attribute_pattern.h" -#include "test/utils/doctest.h" #include "utils/containers/get_only.h" #include "utils/graph/instances/unordered_set_labelled_open_dataflow_graph.h" +#include using namespace ::FlexFlow; diff --git a/lib/substitutions/test/src/substitutions/unlabelled/pattern_split.cc b/lib/substitutions/test/src/substitutions/unlabelled/pattern_split.cc index 6621145d39..e4d763d9c3 100644 --- a/lib/substitutions/test/src/substitutions/unlabelled/pattern_split.cc +++ b/lib/substitutions/test/src/substitutions/unlabelled/pattern_split.cc @@ -1,10 +1,10 @@ #include "substitutions/unlabelled/pattern_split.h" #include "substitutions/unlabelled/pattern_value.h" #include "substitutions/unlabelled/unlabelled_graph_pattern.h" -#include "test/utils/doctest.h" #include "utils/containers/get_only.h" #include "utils/graph/instances/unordered_set_dataflow_graph.h" #include "utils/graph/open_dataflow_graph/open_dataflow_graph.h" +#include using namespace ::FlexFlow; diff --git a/lib/substitutions/test/src/substitutions/unlabelled/unlabelled_graph_pattern.cc b/lib/substitutions/test/src/substitutions/unlabelled/unlabelled_graph_pattern.cc index 3475c10235..e0805dbfd4 100644 --- a/lib/substitutions/test/src/substitutions/unlabelled/unlabelled_graph_pattern.cc +++ b/lib/substitutions/test/src/substitutions/unlabelled/unlabelled_graph_pattern.cc @@ -1,8 +1,8 @@ #include "substitutions/unlabelled/unlabelled_graph_pattern.h" -#include "test/utils/doctest.h" #include "utils/containers/get_only.h" #include "utils/graph/instances/unordered_set_dataflow_graph.h" #include "utils/graph/open_dataflow_graph/open_dataflow_graph.h" +#include using namespace ::FlexFlow; diff --git a/lib/substitutions/test/src/test_pattern_matches.cc b/lib/substitutions/test/src/test_pattern_matches.cc index 9478195523..aeedd65f82 100644 --- a/lib/substitutions/test/src/test_pattern_matches.cc +++ b/lib/substitutions/test/src/test_pattern_matches.cc @@ -1,9 +1,6 @@ -#include "doctest/doctest.h" -#include "rapidcheck.h" #include "substitutions/unlabelled/find_pattern_matches.h" #include "substitutions/unlabelled/match_additional_criterion.h" #include "substitutions/unlabelled/pattern_matching.h" -#include "test/utils/all.h" #include "utils/containers/get_only.h" #include "utils/graph/instances/unordered_set_dataflow_graph.h" #include "utils/graph/node/algorithms.h" @@ -13,6 +10,7 @@ #include "utils/graph/open_dataflow_graph/algorithms/get_subgraph_inputs.h" #include "utils/graph/open_dataflow_graph/open_dataflow_graph.h" #include "utils/overload.h" +#include using namespace FlexFlow; diff --git a/lib/utils/CMakeLists.txt b/lib/utils/CMakeLists.txt index ae5e120fad..a0d77b9f76 100644 --- a/lib/utils/CMakeLists.txt +++ b/lib/utils/CMakeLists.txt @@ -13,7 +13,6 @@ ff_add_library( fmt json cuda - doctest ) add_subdirectory(ffi) diff --git a/lib/utils/include/utils/cli/cli_argument_key.variant.toml b/lib/utils/include/utils/cli/cli_argument_key.variant.toml new file mode 100644 index 0000000000..be118160ce --- /dev/null +++ b/lib/utils/include/utils/cli/cli_argument_key.variant.toml @@ -0,0 +1,19 @@ +namespace = "FlexFlow" +name = "CLIArgumentKey" +features = [ + "eq", + "hash", + "fmt", +] + +includes = [ + "", + "utils/cli/cli_positional_argument_key.dtg.h", + "utils/cli/cli_flag_key.dtg.h", +] + +[[values]] +type = "::FlexFlow::CLIPositionalArgumentKey" + +[[values]] +type = "::FlexFlow::CLIFlagKey" diff --git a/lib/utils/include/utils/cli/cli_flag_key.struct.toml b/lib/utils/include/utils/cli/cli_flag_key.struct.toml new file mode 100644 index 0000000000..790a752911 --- /dev/null +++ b/lib/utils/include/utils/cli/cli_flag_key.struct.toml @@ -0,0 +1,13 @@ +namespace = "FlexFlow" +name = "CLIFlagKey" +features = [ + "eq", + "hash", + "fmt", +] + +includes = [] + +[[fields]] +name = "raw_idx" +type = "int" diff --git a/lib/utils/include/utils/cli/cli_flag_spec.struct.toml b/lib/utils/include/utils/cli/cli_flag_spec.struct.toml new file mode 100644 index 0000000000..66a47de067 --- /dev/null +++ b/lib/utils/include/utils/cli/cli_flag_spec.struct.toml @@ -0,0 +1,28 @@ +namespace = "FlexFlow" +name = "CLIFlagSpec" +features = [ + "eq", + "hash", + "fmt", +] + +includes = [ + "", + "", +] + +src_includes = [ + "utils/fmt/optional.h", +] + +[[fields]] +name = "long_flag" +type = "std::string" + +[[fields]] +name = "short_flag" +type = "std::optional" + +[[fields]] +name = "description" +type = "std::optional" diff --git a/lib/utils/include/utils/cli/cli_get_help_message.h b/lib/utils/include/utils/cli/cli_get_help_message.h new file mode 100644 index 0000000000..d51579a8e2 --- /dev/null +++ b/lib/utils/include/utils/cli/cli_get_help_message.h @@ -0,0 +1,13 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CLI_CLI_GET_HELP_MESSAGE_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CLI_CLI_GET_HELP_MESSAGE_H + +#include "utils/cli/cli_spec.dtg.h" + +namespace FlexFlow { + +std::string cli_get_help_message(std::string const &program_name, + CLISpec const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/cli/cli_parse.h b/lib/utils/include/utils/cli/cli_parse.h new file mode 100644 index 0000000000..3c91a8423b --- /dev/null +++ b/lib/utils/include/utils/cli/cli_parse.h @@ -0,0 +1,19 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CLI_CLI_PARSE_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CLI_CLI_PARSE_H + +#include "utils/cli/cli_parse_result.dtg.h" +#include "utils/cli/cli_spec.dtg.h" +#include + +namespace FlexFlow { + +tl::expected cli_parse_flag(CLISpec const &cli, + std::string const &arg); +tl::expected + cli_parse(CLISpec const &, std::vector const &); +tl::expected + cli_parse(CLISpec const &, int argc, char const *const *argv); + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/cli/cli_parse_result.h b/lib/utils/include/utils/cli/cli_parse_result.h new file mode 100644 index 0000000000..155caac7ae --- /dev/null +++ b/lib/utils/include/utils/cli/cli_parse_result.h @@ -0,0 +1,14 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CLI_CLI_PARSE_RESULT_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CLI_CLI_PARSE_RESULT_H + +#include "utils/cli/cli_argument_key.dtg.h" +#include "utils/cli/cli_parse_result.dtg.h" + +namespace FlexFlow { + +bool cli_get_flag(CLIParseResult const &, CLIArgumentKey const &); +std::string cli_get_argument(CLIParseResult const &, CLIArgumentKey const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/cli/cli_parse_result.struct.toml b/lib/utils/include/utils/cli/cli_parse_result.struct.toml new file mode 100644 index 0000000000..b63da7be14 --- /dev/null +++ b/lib/utils/include/utils/cli/cli_parse_result.struct.toml @@ -0,0 +1,27 @@ +namespace = "FlexFlow" +name = "CLIParseResult" +features = [ + "eq", + "hash", + "fmt", +] + +includes = [ + "", + "", + "utils/cli/cli_flag_key.dtg.h", + "utils/cli/cli_positional_argument_key.dtg.h", +] + +src_includes = [ + "utils/fmt/unordered_map.h", + "utils/hash/unordered_map.h", +] + +[[fields]] +name = "flags" +type = "std::unordered_map<::FlexFlow::CLIFlagKey, bool>" + +[[fields]] +name = "positional_arguments" +type = "std::unordered_map<::FlexFlow::CLIPositionalArgumentKey, std::string>" diff --git a/lib/utils/include/utils/cli/cli_positional_argument_key.struct.toml b/lib/utils/include/utils/cli/cli_positional_argument_key.struct.toml new file mode 100644 index 0000000000..d571d0deb3 --- /dev/null +++ b/lib/utils/include/utils/cli/cli_positional_argument_key.struct.toml @@ -0,0 +1,13 @@ +namespace = "FlexFlow" +name = "CLIPositionalArgumentKey" +features = [ + "eq", + "hash", + "fmt", +] + +includes = [] + +[[fields]] +name = "raw_idx" +type = "int" diff --git a/lib/utils/include/utils/cli/cli_positional_argument_spec.struct.toml b/lib/utils/include/utils/cli/cli_positional_argument_spec.struct.toml new file mode 100644 index 0000000000..b1e74701ee --- /dev/null +++ b/lib/utils/include/utils/cli/cli_positional_argument_spec.struct.toml @@ -0,0 +1,31 @@ +namespace = "FlexFlow" +name = "CLIPositionalArgumentSpec" +features = [ + "eq", + "hash", + "fmt", +] + +includes = [ + "", + "", + "", +] + +src_includes = [ + "utils/fmt/optional.h", + "utils/fmt/vector.h", + "utils/hash/vector.h", +] + +[[fields]] +name = "name" +type = "std::string" + +[[fields]] +name = "choices" +type = "std::optional>" + +[[fields]] +name = "description" +type = "std::optional" diff --git a/lib/utils/include/utils/cli/cli_spec.h b/lib/utils/include/utils/cli/cli_spec.h new file mode 100644 index 0000000000..2c0df08c55 --- /dev/null +++ b/lib/utils/include/utils/cli/cli_spec.h @@ -0,0 +1,20 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CLI_CLI_SPEC_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CLI_CLI_SPEC_H + +#include "utils/cli/cli_argument_key.dtg.h" +#include "utils/cli/cli_flag_spec.dtg.h" +#include "utils/cli/cli_spec.dtg.h" +#include + +namespace FlexFlow { + +CLISpec empty_cli_spec(); +std::vector cli_get_flag_keys(CLISpec const &); +CLIArgumentKey cli_add_help_flag(CLISpec &); +CLIArgumentKey cli_add_flag(CLISpec &, CLIFlagSpec const &); +CLIArgumentKey cli_add_positional_argument(CLISpec &, + CLIPositionalArgumentSpec const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/cli/cli_spec.struct.toml b/lib/utils/include/utils/cli/cli_spec.struct.toml new file mode 100644 index 0000000000..9f64f62c15 --- /dev/null +++ b/lib/utils/include/utils/cli/cli_spec.struct.toml @@ -0,0 +1,29 @@ +namespace = "FlexFlow" +name = "CLISpec" +features = [ + "eq", + "hash", + "fmt", +] + +includes = [ + "", + "utils/cli/cli_flag_spec.dtg.h", + "utils/cli/cli_positional_argument_spec.dtg.h", + "", +] + +src_includes = [ + "utils/fmt/unordered_set.h", + "utils/hash/unordered_set.h", + "utils/fmt/vector.h", + "utils/hash/vector.h", +] + +[[fields]] +name = "flags" +type = "std::vector<::FlexFlow::CLIFlagSpec>" + +[[fields]] +name = "positional_arguments" +type = "std::vector<::FlexFlow::CLIPositionalArgumentSpec>" diff --git a/lib/utils/include/utils/containers.decl.h b/lib/utils/include/utils/containers.decl.h index 937ed51af2..20ab6ce440 100644 --- a/lib/utils/include/utils/containers.decl.h +++ b/lib/utils/include/utils/containers.decl.h @@ -74,9 +74,6 @@ bool are_all_same(C const &c); template std::function compare_by(F const &f); -template -typename C::value_type maximum(C const &v); - template T reversed(T const &t); diff --git a/lib/utils/include/utils/containers.h b/lib/utils/include/utils/containers.h index 7c0490fa2a..f60ef77cda 100644 --- a/lib/utils/include/utils/containers.h +++ b/lib/utils/include/utils/containers.h @@ -179,11 +179,6 @@ std::function compare_by(F const &f) { return [=](T const &lhs, T const &rhs) { return f(lhs) < f(rhs); }; } -template -typename C::value_type maximum(C const &v) { - return *std::max_element(v.begin(), v.end()); -} - template std::vector value_all(std::vector> const &v) { return transform(v, [](std::optional const &element) { diff --git a/lib/utils/include/utils/containers/enumerate_vector.h b/lib/utils/include/utils/containers/enumerate_vector.h index 11ee8d2352..700106ea3f 100644 --- a/lib/utils/include/utils/containers/enumerate_vector.h +++ b/lib/utils/include/utils/containers/enumerate_vector.h @@ -2,7 +2,6 @@ #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_ENUMERATE_VECTOR_H #include -#include #include namespace FlexFlow { diff --git a/lib/utils/include/utils/containers/foldl1.h b/lib/utils/include/utils/containers/foldl1.h new file mode 100644 index 0000000000..f542f8cf00 --- /dev/null +++ b/lib/utils/include/utils/containers/foldl1.h @@ -0,0 +1,29 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_FOLDL1_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_FOLDL1_H + +#include "utils/exception.h" +#include + +namespace FlexFlow { + +template +T foldl1(std::vector const &vec, F f) { + if (vec.empty()) { + throw mk_runtime_error(fmt::format( + "foldl1 expected non-empty vector, but receieved empty vector")); + } + + auto it = vec.cbegin(); + T result = *it; + it++; + + for (; it != vec.cend(); it++) { + result = f(result, *it); + } + + return result; +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/containers/foldr1.h b/lib/utils/include/utils/containers/foldr1.h new file mode 100644 index 0000000000..4a7e8e098c --- /dev/null +++ b/lib/utils/include/utils/containers/foldr1.h @@ -0,0 +1,28 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_FOLDR1_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_FOLDR1_H + +#include "utils/exception.h" +#include + +namespace FlexFlow { + +template +T foldr1(std::vector const &vec, F f) { + if (vec.empty()) { + throw mk_runtime_error(fmt::format( + "foldr1 expected non-empty vector, but receieved empty vector")); + } + + auto it = vec.crbegin(); + T result = *it; + it++; + for (; it != vec.crend(); it++) { + result = f(result, *it); + } + + return result; +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/containers/generate_map.h b/lib/utils/include/utils/containers/generate_map.h index 1afa534a19..53b2a590c5 100644 --- a/lib/utils/include/utils/containers/generate_map.h +++ b/lib/utils/include/utils/containers/generate_map.h @@ -1,8 +1,8 @@ #ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_GENERATE_MAP_H #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_GENERATE_MAP_H -#include "utils/containers/as_vector.h" #include "utils/containers/get_element_type.h" +#include "utils/containers/vector_of.h" #include "utils/containers/vector_transform.h" #include "utils/type_traits_core.h" #include @@ -17,7 +17,7 @@ std::unordered_map generate_map(C const &c, F const &f) { static_assert(is_hashable_v, "Key type should be hashable (but is not)"); auto transformed = - vector_transform(as_vector(c), [&](K const &k) -> std::pair { + vector_transform(vector_of(c), [&](K const &k) -> std::pair { return {k, f(k)}; }); return {transformed.cbegin(), transformed.cend()}; diff --git a/lib/utils/include/utils/containers/get_first.h b/lib/utils/include/utils/containers/get_first.h index ce2a483401..a616c44c20 100644 --- a/lib/utils/include/utils/containers/get_first.h +++ b/lib/utils/include/utils/containers/get_first.h @@ -1,6 +1,7 @@ #ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_GET_FIRST_H #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_GET_FIRST_H +#include #include namespace FlexFlow { @@ -10,6 +11,11 @@ T get_first(std::unordered_set const &s) { return *s.cbegin(); } +template +T get_first(std::set const &s) { + return *s.cbegin(); +} + } // namespace FlexFlow #endif diff --git a/lib/utils/include/utils/containers/maximum.h b/lib/utils/include/utils/containers/maximum.h new file mode 100644 index 0000000000..634bb61bc1 --- /dev/null +++ b/lib/utils/include/utils/containers/maximum.h @@ -0,0 +1,20 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_MAXIMUM_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_MAXIMUM_H + +#include +#include + +namespace FlexFlow { + +template +std::optional maximum(C const &v) { + if (v.empty()) { + return std::nullopt; + } + + return *std::max_element(std::cbegin(v), std::cend(v)); +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/containers/multiset_union.h b/lib/utils/include/utils/containers/multiset_union.h new file mode 100644 index 0000000000..6f2b2a7889 --- /dev/null +++ b/lib/utils/include/utils/containers/multiset_union.h @@ -0,0 +1,48 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_MULTISET_UNION_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_MULTISET_UNION_H + +#include +#include + +namespace FlexFlow { + +template +std::unordered_multiset + multiset_union(std::unordered_multiset const &lhs, + std::unordered_multiset const &rhs) { + std::unordered_multiset result = lhs; + + for (T const &t : rhs) { + result.insert(t); + } + + return result; +} + +template +std::multiset multiset_union(std::multiset const &lhs, + std::multiset const &rhs) { + std::multiset result = lhs; + + for (T const &t : rhs) { + result.insert(t); + } + + return result; +} + +template +std::unordered_multiset multiset_union(C const &c) { + std::unordered_multiset result; + for (auto const &s : c) { + for (T const &element : s) { + result.insert(element); + } + } + + return result; +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/containers/require_no_duplicates.h b/lib/utils/include/utils/containers/require_no_duplicates.h new file mode 100644 index 0000000000..0cbe361bdd --- /dev/null +++ b/lib/utils/include/utils/containers/require_no_duplicates.h @@ -0,0 +1,40 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_REQUIRE_NO_DUPLICATES_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_REQUIRE_NO_DUPLICATES_H + +#include "utils/exception.h" +#include "utils/fmt/multiset.h" +#include "utils/fmt/unordered_multiset.h" +#include +#include +#include + +namespace FlexFlow { + +template +std::unordered_set + require_no_duplicates(std::unordered_multiset const &s) { + std::unordered_set result{s.cbegin(), s.cend()}; + + if (result.size() != s.size()) { + throw mk_runtime_error(fmt::format( + "require_no_duplicates encountered duplicate in set {}", s)); + } + + return result; +} + +template +std::set require_no_duplicates(std::multiset const &s) { + std::set result{s.cbegin(), s.cend()}; + + if (result.size() != s.size()) { + throw mk_runtime_error(fmt::format( + "require_no_duplicates encountered duplicate in set {}", s)); + } + + return result; +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/containers/reversed.h b/lib/utils/include/utils/containers/reversed.h index 621eee9519..902b247469 100644 --- a/lib/utils/include/utils/containers/reversed.h +++ b/lib/utils/include/utils/containers/reversed.h @@ -1,15 +1,14 @@ #ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_REVERSED_H #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_REVERSED_H +#include + namespace FlexFlow { template -T reversed(T const &t) { - T r; - for (auto i = t.cend() - 1; i >= t.begin(); i--) { - r.push_back(*i); - } - return r; +std::vector reversed(std::vector const &t) { + std::vector result(std::crbegin(t), std::crend(t)); + return result; } } // namespace FlexFlow diff --git a/lib/utils/include/utils/containers/set_minus.h b/lib/utils/include/utils/containers/set_minus.h index 6efa2f0a84..fdd1f11995 100644 --- a/lib/utils/include/utils/containers/set_minus.h +++ b/lib/utils/include/utils/containers/set_minus.h @@ -1,6 +1,7 @@ #ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_SET_MINUS_H #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_SET_MINUS_H +#include #include namespace FlexFlow { @@ -15,6 +16,15 @@ std::unordered_set set_minus(std::unordered_set const &l, return result; } +template +std::set set_minus(std::set const &l, std::set const &r) { + std::set result = l; + for (T const &t : r) { + result.erase(t); + } + return result; +} + } // namespace FlexFlow #endif diff --git a/lib/utils/include/utils/containers/set_of.h b/lib/utils/include/utils/containers/set_of.h new file mode 100644 index 0000000000..14658209aa --- /dev/null +++ b/lib/utils/include/utils/containers/set_of.h @@ -0,0 +1,19 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_SET_OF_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_SET_OF_H + +#include + +namespace FlexFlow { + +template +std::set set_of(C const &c) { + std::set result; + for (T const &t : c) { + result.insert(t); + } + return result; +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/containers/to_uppercase.h b/lib/utils/include/utils/containers/to_uppercase.h new file mode 100644 index 0000000000..a2dc7786f9 --- /dev/null +++ b/lib/utils/include/utils/containers/to_uppercase.h @@ -0,0 +1,12 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_TO_UPPERCASE_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_TO_UPPERCASE_H + +#include + +namespace FlexFlow { + +std::string to_uppercase(std::string const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/containers/as_vector.h b/lib/utils/include/utils/containers/vector_of.h similarity index 54% rename from lib/utils/include/utils/containers/as_vector.h rename to lib/utils/include/utils/containers/vector_of.h index fafa1dc799..7fb903b4a8 100644 --- a/lib/utils/include/utils/containers/as_vector.h +++ b/lib/utils/include/utils/containers/vector_of.h @@ -1,12 +1,12 @@ -#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_AS_VECTOR_H -#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_AS_VECTOR_H +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_VECTOR_OF_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_VECTOR_OF_H #include namespace FlexFlow { template -std::vector as_vector(C const &c) { +std::vector vector_of(C const &c) { std::vector result(c.cbegin(), c.cend()); return result; } diff --git a/lib/utils/include/utils/fmt/expected.h b/lib/utils/include/utils/fmt/expected.h index 21a6d28ca2..4170882ae6 100644 --- a/lib/utils/include/utils/fmt/expected.h +++ b/lib/utils/include/utils/fmt/expected.h @@ -1,9 +1,8 @@ #ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FMT_EXPECTED_H #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FMT_EXPECTED_H -#include "fmt/format.h" #include "utils/check_fmtable.h" -#include +#include #include #include @@ -44,15 +43,4 @@ std::ostream &operator<<(std::ostream &s, tl::expected const &t) { } // namespace FlexFlow -namespace doctest { - -template -struct StringMaker> { - static String convert(tl::expected const &m) { - return toString(fmt::to_string(m)); - } -}; - -} // namespace doctest - #endif diff --git a/lib/utils/include/utils/fmt/map.h b/lib/utils/include/utils/fmt/map.h index 8e186928fd..46bf9ca8fa 100644 --- a/lib/utils/include/utils/fmt/map.h +++ b/lib/utils/include/utils/fmt/map.h @@ -5,7 +5,6 @@ #include "utils/containers/sorted.h" #include "utils/fmt/pair.h" #include "utils/join_strings.h" -#include #include #include @@ -48,15 +47,4 @@ std::ostream &operator<<(std::ostream &s, std::map const &m) { } // namespace FlexFlow -namespace doctest { - -template -struct StringMaker> { - static String convert(std::map const &m) { - return toString(fmt::to_string(m)); - } -}; - -} // namespace doctest - #endif diff --git a/lib/utils/include/utils/fmt/multiset.h b/lib/utils/include/utils/fmt/multiset.h index cff150dc29..616b784aac 100644 --- a/lib/utils/include/utils/fmt/multiset.h +++ b/lib/utils/include/utils/fmt/multiset.h @@ -3,7 +3,6 @@ #include "utils/check_fmtable.h" #include "utils/join_strings.h" -#include #include #include @@ -42,15 +41,4 @@ std::ostream &operator<<(std::ostream &s, std::multiset const &x) { } // namespace FlexFlow -namespace doctest { - -template -struct StringMaker> { - static String convert(std::multiset const &m) { - return toString(fmt::to_string(m)); - } -}; - -} // namespace doctest - #endif diff --git a/lib/utils/include/utils/fmt/optional.h b/lib/utils/include/utils/fmt/optional.h index 45eebc2c58..2364e49568 100644 --- a/lib/utils/include/utils/fmt/optional.h +++ b/lib/utils/include/utils/fmt/optional.h @@ -2,7 +2,6 @@ #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FMT_OPTIONAL_H #include "utils/check_fmtable.h" -#include #include #include @@ -43,15 +42,4 @@ std::ostream &operator<<(std::ostream &s, std::optional const &t) { } // namespace FlexFlow -namespace doctest { - -template -struct StringMaker> { - static String convert(std::optional const &m) { - return toString(fmt::to_string(m)); - } -}; - -} // namespace doctest - #endif diff --git a/lib/utils/include/utils/fmt/pair.h b/lib/utils/include/utils/fmt/pair.h index 6f7e6f6b52..ab5ddd4e28 100644 --- a/lib/utils/include/utils/fmt/pair.h +++ b/lib/utils/include/utils/fmt/pair.h @@ -2,7 +2,6 @@ #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FMT_PAIR_H #include "utils/check_fmtable.h" -#include #include #include @@ -40,15 +39,4 @@ std::ostream &operator<<(std::ostream &s, std::pair const &m) { } // namespace FlexFlow -namespace doctest { - -template -struct StringMaker> { - static String convert(std::pair const &m) { - return toString(fmt::to_string(m)); - } -}; - -} // namespace doctest - #endif diff --git a/lib/utils/include/utils/fmt/set.h b/lib/utils/include/utils/fmt/set.h index 1f8012f240..a183d37542 100644 --- a/lib/utils/include/utils/fmt/set.h +++ b/lib/utils/include/utils/fmt/set.h @@ -4,7 +4,6 @@ #include "utils/check_fmtable.h" #include "utils/containers/sorted.h" #include "utils/join_strings.h" -#include #include #include #include @@ -43,15 +42,4 @@ std::ostream &operator<<(std::ostream &s, std::set const &x) { } // namespace FlexFlow -namespace doctest { - -template -struct StringMaker> { - static String convert(std::set const &m) { - return toString(fmt::to_string(m)); - } -}; - -} // namespace doctest - #endif diff --git a/lib/utils/include/utils/fmt/unordered_map.h b/lib/utils/include/utils/fmt/unordered_map.h index 75bbb4cb8a..876a032fe6 100644 --- a/lib/utils/include/utils/fmt/unordered_map.h +++ b/lib/utils/include/utils/fmt/unordered_map.h @@ -5,7 +5,6 @@ #include "utils/fmt/pair.h" #include "utils/join_strings.h" #include -#include #include #include #include @@ -48,15 +47,4 @@ std::ostream &operator<<(std::ostream &s, std::unordered_map const &m) { } // namespace FlexFlow -namespace doctest { - -template -struct StringMaker> { - static String convert(std::unordered_map const &m) { - return toString(fmt::to_string(m)); - } -}; - -} // namespace doctest - #endif diff --git a/lib/utils/include/utils/fmt/unordered_multiset.h b/lib/utils/include/utils/fmt/unordered_multiset.h index 41abbc925e..deb03a04d4 100644 --- a/lib/utils/include/utils/fmt/unordered_multiset.h +++ b/lib/utils/include/utils/fmt/unordered_multiset.h @@ -3,7 +3,6 @@ #include "utils/check_fmtable.h" #include "utils/join_strings.h" -#include #include #include @@ -42,15 +41,4 @@ std::ostream &operator<<(std::ostream &s, std::unordered_multiset const &x) { } // namespace FlexFlow -namespace doctest { - -template -struct StringMaker> { - static String convert(std::unordered_multiset const &m) { - return toString(fmt::to_string(m)); - } -}; - -} // namespace doctest - #endif diff --git a/lib/utils/include/utils/fmt/unordered_set.h b/lib/utils/include/utils/fmt/unordered_set.h index 646ef0c7c5..257545af1b 100644 --- a/lib/utils/include/utils/fmt/unordered_set.h +++ b/lib/utils/include/utils/fmt/unordered_set.h @@ -4,7 +4,6 @@ #include "utils/check_fmtable.h" #include "utils/join_strings.h" #include "utils/type_traits_core.h" -#include #include #include @@ -43,15 +42,4 @@ std::ostream &operator<<(std::ostream &s, std::unordered_set const &x) { } // namespace FlexFlow -namespace doctest { - -template -struct StringMaker> { - static String convert(std::unordered_set const &m) { - return toString(fmt::to_string(m)); - } -}; - -} // namespace doctest - #endif diff --git a/lib/utils/include/utils/fmt/variant.h b/lib/utils/include/utils/fmt/variant.h index 867577f72a..06a56417c3 100644 --- a/lib/utils/include/utils/fmt/variant.h +++ b/lib/utils/include/utils/fmt/variant.h @@ -1,7 +1,6 @@ #ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FMT_VARIANT_H #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FMT_VARIANT_H -#include #include #include @@ -33,15 +32,4 @@ std::ostream &operator<<(std::ostream &s, std::variant const &v) { } // namespace FlexFlow -namespace doctest { - -template -struct StringMaker> { - static String convert(std::variant const &m) { - return toString(fmt::to_string(m)); - } -}; - -} // namespace doctest - #endif diff --git a/lib/utils/include/utils/fmt/vector.h b/lib/utils/include/utils/fmt/vector.h index 96526175a8..5d9ca0aeae 100644 --- a/lib/utils/include/utils/fmt/vector.h +++ b/lib/utils/include/utils/fmt/vector.h @@ -3,7 +3,6 @@ #include "utils/check_fmtable.h" #include "utils/join_strings.h" -#include #include #include @@ -41,15 +40,4 @@ std::ostream &operator<<(std::ostream &s, std::vector const &v) { } // namespace FlexFlow -namespace doctest { - -template -struct StringMaker> { - static String convert(std::vector const &m) { - return toString(fmt::to_string(m)); - } -}; - -} // namespace doctest - #endif diff --git a/lib/utils/include/utils/graph/dataflow_graph/algorithms/get_subgraph_incoming_edges.h b/lib/utils/include/utils/graph/dataflow_graph/algorithms/get_subgraph_incoming_edges.h new file mode 100644 index 0000000000..2ed0bc02be --- /dev/null +++ b/lib/utils/include/utils/graph/dataflow_graph/algorithms/get_subgraph_incoming_edges.h @@ -0,0 +1,14 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_DATAFLOW_GRAPH_ALGORITHMS_GET_SUBGRAPH_INCOMING_EDGES_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_DATAFLOW_GRAPH_ALGORITHMS_GET_SUBGRAPH_INCOMING_EDGES_H + +#include "utils/graph/dataflow_graph/dataflow_graph_view.h" + +namespace FlexFlow { + +std::unordered_set + get_subgraph_incoming_edges(DataflowGraphView const &, + std::unordered_set const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/digraph/algorithms/complete_bipartite_composite/get_cbc_decomposition.h b/lib/utils/include/utils/graph/digraph/algorithms/complete_bipartite_composite/get_cbc_decomposition.h index fc372f68aa..afc9c47c1c 100644 --- a/lib/utils/include/utils/graph/digraph/algorithms/complete_bipartite_composite/get_cbc_decomposition.h +++ b/lib/utils/include/utils/graph/digraph/algorithms/complete_bipartite_composite/get_cbc_decomposition.h @@ -6,6 +6,9 @@ namespace FlexFlow { +std::optional + get_cbc_decomposition_with_edge_order_internal( + DiGraphView const &, std::vector const &); std::optional get_cbc_decomposition(DiGraphView const &); diff --git a/lib/utils/include/utils/graph/digraph/algorithms/complete_bipartite_composite/is_complete_bipartite_digraph.h b/lib/utils/include/utils/graph/digraph/algorithms/complete_bipartite_composite/is_complete_bipartite_digraph.h new file mode 100644 index 0000000000..3066886e37 --- /dev/null +++ b/lib/utils/include/utils/graph/digraph/algorithms/complete_bipartite_composite/is_complete_bipartite_digraph.h @@ -0,0 +1,14 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_DIGRAPH_ALGORITHMS_COMPLETE_BIPARTITE_COMPOSITE_IS_COMPLETE_BIPARTITE_DIGRAPH_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_DIGRAPH_ALGORITHMS_COMPLETE_BIPARTITE_COMPOSITE_IS_COMPLETE_BIPARTITE_DIGRAPH_H + +#include "utils/graph/digraph/digraph_view.h" + +namespace FlexFlow { + +bool is_complete_bipartite_digraph(DiGraphView const &); +bool is_complete_bipartite_digraph(DiGraphView const &, + std::unordered_set const &srcs); + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/digraph/algorithms/digraph_as_dot.h b/lib/utils/include/utils/graph/digraph/algorithms/digraph_as_dot.h new file mode 100644 index 0000000000..ee533a1180 --- /dev/null +++ b/lib/utils/include/utils/graph/digraph/algorithms/digraph_as_dot.h @@ -0,0 +1,14 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_DIGRAPH_ALGORITHMS_DIGRAPH_AS_DOT_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_DIGRAPH_ALGORITHMS_DIGRAPH_AS_DOT_H + +#include "utils/graph/digraph/digraph_view.h" + +namespace FlexFlow { + +std::string digraph_as_dot( + DiGraphView const &, + std::function const &get_node_label); + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/digraph/algorithms/digraph_has_edge.h b/lib/utils/include/utils/graph/digraph/algorithms/digraph_has_edge.h new file mode 100644 index 0000000000..87d0d3143a --- /dev/null +++ b/lib/utils/include/utils/graph/digraph/algorithms/digraph_has_edge.h @@ -0,0 +1,12 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_DIGRAPH_ALGORITHMS_HAS_EDGE_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_DIGRAPH_ALGORITHMS_HAS_EDGE_H + +#include "utils/graph/digraph/digraph_view.h" + +namespace FlexFlow { + +bool digraph_has_edge(DiGraphView const &, DirectedEdge const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/digraph/algorithms/get_subgraph_outgoing_edges.h b/lib/utils/include/utils/graph/digraph/algorithms/get_subgraph_outgoing_edges.h new file mode 100644 index 0000000000..6d98c5c20d --- /dev/null +++ b/lib/utils/include/utils/graph/digraph/algorithms/get_subgraph_outgoing_edges.h @@ -0,0 +1,14 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_DIGRAPH_ALGORITHMS_GET_SUBGRAPH_OUTGOING_EDGES_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_DIGRAPH_ALGORITHMS_GET_SUBGRAPH_OUTGOING_EDGES_H + +#include "utils/graph/digraph/digraph_view.h" + +namespace FlexFlow { + +std::unordered_set + get_subgraph_outgoing_edges(DiGraphView const &, + std::unordered_set const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/digraph/algorithms/get_subgraph_successors.h b/lib/utils/include/utils/graph/digraph/algorithms/get_subgraph_successors.h new file mode 100644 index 0000000000..2c48d327c4 --- /dev/null +++ b/lib/utils/include/utils/graph/digraph/algorithms/get_subgraph_successors.h @@ -0,0 +1,14 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_DIGRAPH_ALGORITHMS_GET_SUBGRAPH_SUCCESSORS_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_DIGRAPH_ALGORITHMS_GET_SUBGRAPH_SUCCESSORS_H + +#include "utils/graph/digraph/digraph_view.h" + +namespace FlexFlow { + +std::unordered_set + get_subgraph_successors(DiGraphView const &, + std::unordered_set const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/digraph/algorithms/transitive_closure.h b/lib/utils/include/utils/graph/digraph/algorithms/transitive_closure.h new file mode 100644 index 0000000000..c9751124c8 --- /dev/null +++ b/lib/utils/include/utils/graph/digraph/algorithms/transitive_closure.h @@ -0,0 +1,12 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_DIGRAPH_ALGORITHMS_TRANSITIVE_CLOSURE_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_DIGRAPH_ALGORITHMS_TRANSITIVE_CLOSURE_H + +#include "utils/graph/digraph/digraph_view.h" + +namespace FlexFlow { + +DiGraphView transitive_closure(DiGraphView const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/instances/unordered_set_undirected_graph.h b/lib/utils/include/utils/graph/instances/unordered_set_undirected_graph.h new file mode 100644 index 0000000000..db2526f973 --- /dev/null +++ b/lib/utils/include/utils/graph/instances/unordered_set_undirected_graph.h @@ -0,0 +1,37 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_INSTANCES_UNORDERED_SET_UNDIRECTED_GRAPH_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_INSTANCES_UNORDERED_SET_UNDIRECTED_GRAPH_H + +#include "utils/graph/node/node_source.h" +#include "utils/graph/undirected/i_undirected_graph.h" + +namespace FlexFlow { + +struct UnorderedSetUndirectedGraph final : public IUndirectedGraph { +public: + UnorderedSetUndirectedGraph(); + + Node add_node() override; + void add_node_unsafe(Node const &) override; + void remove_node_unsafe(Node const &) override; + void add_edge(UndirectedEdge const &) override; + void remove_edge(UndirectedEdge const &) override; + + std::unordered_set query_nodes(NodeQuery const &) const override; + std::unordered_set + query_edges(UndirectedEdgeQuery const &) const override; + + UnorderedSetUndirectedGraph *clone() const override; + +private: + UnorderedSetUndirectedGraph(NodeSource const &, + std::unordered_set const &, + std::unordered_set const &); + + NodeSource node_source; + std::unordered_set nodes; + std::unordered_set edges; +}; + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/labelled_open_dataflow_graph/algorithms/find_isomorphism.h b/lib/utils/include/utils/graph/labelled_open_dataflow_graph/algorithms/find_isomorphism.h index a1d6e9e37a..8306dad1ec 100644 --- a/lib/utils/include/utils/graph/labelled_open_dataflow_graph/algorithms/find_isomorphism.h +++ b/lib/utils/include/utils/graph/labelled_open_dataflow_graph/algorithms/find_isomorphism.h @@ -1,7 +1,6 @@ #ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_LABELLED_OPEN_DATAFLOW_GRAPH_ALGORITHMS_FIND_ISOMORPHISM_H #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_LABELLED_OPEN_DATAFLOW_GRAPH_ALGORITHMS_FIND_ISOMORPHISM_H -#include "utils/containers/as_vector.h" #include "utils/containers/get_all_permutations.h" #include "utils/containers/zip.h" #include "utils/graph/labelled_open_dataflow_graph/algorithms/is_isomorphic_under.h" diff --git a/lib/utils/include/utils/graph/node/node.struct.toml b/lib/utils/include/utils/graph/node/node.struct.toml index 0b6f348ddf..d5c22e5d3d 100644 --- a/lib/utils/include/utils/graph/node/node.struct.toml +++ b/lib/utils/include/utils/graph/node/node.struct.toml @@ -5,6 +5,7 @@ features = [ "ord", "hash", "fmt", + "json", ] includes = [ diff --git a/lib/utils/include/utils/graph/serial_parallel/get_serial_parallel_decomposition.h b/lib/utils/include/utils/graph/serial_parallel/get_serial_parallel_decomposition.h deleted file mode 100644 index be6b9ce12c..0000000000 --- a/lib/utils/include/utils/graph/serial_parallel/get_serial_parallel_decomposition.h +++ /dev/null @@ -1,17 +0,0 @@ -#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIAL_PARALLEL_GET_SERIAL_PARALLEL_DECOMPOSITION_H -#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIAL_PARALLEL_GET_SERIAL_PARALLEL_DECOMPOSITION_H - -#include "utils/graph/digraph/digraph.h" -#include "utils/graph/serial_parallel/serial_parallel_decomposition.dtg.h" -#include "utils/optional.h" -#include -#include - -namespace FlexFlow { - -std::optional - get_serial_parallel_decomposition(DiGraphView const &); - -} // namespace FlexFlow - -#endif diff --git a/lib/utils/include/utils/graph/serial_parallel/intermediate_sp_decomposition_tree.h b/lib/utils/include/utils/graph/serial_parallel/intermediate_sp_decomposition_tree.h deleted file mode 100644 index 6285d7ae1f..0000000000 --- a/lib/utils/include/utils/graph/serial_parallel/intermediate_sp_decomposition_tree.h +++ /dev/null @@ -1,13 +0,0 @@ -#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIAL_PARALLEL_INTERMEDIATE_SP_DECOMPOSITION_TREE_H -#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIAL_PARALLEL_INTERMEDIATE_SP_DECOMPOSITION_TREE_H - -#include "utils/graph/serial_parallel/intermediate_sp_decomposition_tree.dtg.h" - -namespace FlexFlow { - -std::variant - flatten_ast(std::variant const &ast); - -} // namespace FlexFlow - -#endif diff --git a/lib/utils/include/utils/graph/serial_parallel/serial_parallel_decomposition.h b/lib/utils/include/utils/graph/serial_parallel/serial_parallel_decomposition.h deleted file mode 100644 index 7d8efc96f2..0000000000 --- a/lib/utils/include/utils/graph/serial_parallel/serial_parallel_decomposition.h +++ /dev/null @@ -1,22 +0,0 @@ -#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIAL_PARALLEL_SERIAL_PARALLEL_DECOMPOSITION_H -#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIAL_PARALLEL_SERIAL_PARALLEL_DECOMPOSITION_H - -#include "utils/graph/serial_parallel/intermediate_sp_decomposition_tree.dtg.h" -#include "utils/graph/serial_parallel/serial_parallel_decomposition.dtg.h" -#include - -namespace FlexFlow { - -std::variant internal_to_final_ast( - std::variant const &ast); -SerialParallelDecomposition - to_final_ast(std::variant const &); - -std::unordered_set get_nodes(SerialParallelDecomposition const &sp); -std::unordered_set get_nodes(SerialSplit const &); -std::unordered_set get_nodes(ParallelSplit const &); -std::unordered_set get_nodes(Node const &); - -} // namespace FlexFlow - -#endif diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.h b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.h new file mode 100644 index 0000000000..b1607e7a76 --- /dev/null +++ b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.h @@ -0,0 +1,23 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_BINARY_SP_DECOMPOSITION_TREE_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_BINARY_SP_DECOMPOSITION_TREE_H + +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.dtg.h" +#include + +namespace FlexFlow { + +BinarySPDecompositionTree make_series_split(BinarySPDecompositionTree const &, + BinarySPDecompositionTree const &); +BinarySPDecompositionTree + make_parallel_split(BinarySPDecompositionTree const &, + BinarySPDecompositionTree const &); +BinarySPDecompositionTree make_leaf_node(Node const &); + +bool is_binary_sp_tree_left_associative(BinarySPDecompositionTree const &); +bool is_binary_sp_tree_right_associative(BinarySPDecompositionTree const &); + +std::unordered_multiset get_leaves(BinarySPDecompositionTree const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.struct.toml b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.struct.toml new file mode 100644 index 0000000000..1241311150 --- /dev/null +++ b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.struct.toml @@ -0,0 +1,22 @@ +namespace = "FlexFlow" +name = "BinarySPDecompositionTree" +features = [ + "eq", + "ord", + "hash", + "fmt", +] + +includes = [ + "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree.h", + "utils/graph/node/node.dtg.h", +] + +src_includes = [ + "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/hash.h", + "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/fmt.h", +] + +[[fields]] +name = "raw_tree" +type = "::FlexFlow::GenericBinarySPDecompositionTree<::FlexFlow::Node>" diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/fmt.h b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/fmt.h new file mode 100644 index 0000000000..42d71ce54e --- /dev/null +++ b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/fmt.h @@ -0,0 +1,63 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_DECOMPOSITION_TREE_FMT_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_DECOMPOSITION_TREE_FMT_H + +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_left_child.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_right_child.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/visit.h" +#include + +namespace FlexFlow { + +template +std::string format_as(GenericBinarySeriesSplit const &s) { + return fmt::format("", + get_left_child(s), + get_right_child(s)); +} + +template +std::ostream &operator<<(std::ostream &s, + GenericBinarySeriesSplit const &x) { + return (s << fmt::to_string(x)); +} + +template +std::string format_as(GenericBinaryParallelSplit const &s) { + return fmt::format("", + get_left_child(s), + get_right_child(s)); +} + +template +std::ostream &operator<<(std::ostream &s, + GenericBinaryParallelSplit const &x) { + return (s << fmt::to_string(x)); +} + +template +std::string format_as(GenericBinarySPDecompositionTree const &tt) { + return visit( + tt, + overload{ + [](GenericBinarySeriesSplit const &s) { + return fmt::format("", s); + }, + [](GenericBinaryParallelSplit const &s) { + return fmt::format("", s); + }, + [](T const &t) { + return fmt::format("", t); + }, + }); +} + +template +std::ostream &operator<<(std::ostream &s, + GenericBinarySPDecompositionTree const &t) { + return (s << fmt::to_string(t)); +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree.h b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree.h new file mode 100644 index 0000000000..74f5ba5d8a --- /dev/null +++ b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree.h @@ -0,0 +1,155 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_DECOMPOSITION_TREE_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_DECOMPOSITION_TREE_H + +#include +#include +#include + +namespace FlexFlow { + +template +struct GenericBinarySPDecompositionTree; + +template +struct GenericBinarySeriesSplit { +public: + GenericBinarySeriesSplit() = delete; + explicit GenericBinarySeriesSplit( + GenericBinarySPDecompositionTree const &lhs, + GenericBinarySPDecompositionTree const &rhs) + : left_child_ptr( + std::make_shared>(lhs)), + right_child_ptr( + std::make_shared>(rhs)) {} + + GenericBinarySeriesSplit(GenericBinarySeriesSplit const &) = default; + + bool operator==(GenericBinarySeriesSplit const &other) const { + return this->tie() == other.tie(); + } + + bool operator!=(GenericBinarySeriesSplit const &other) const { + return this->tie() != other.tie(); + } + + bool operator<(GenericBinarySeriesSplit const &other) const { + return this->tie() < other.tie(); + } + +public: + std::shared_ptr> left_child_ptr; + std::shared_ptr> right_child_ptr; + +private: + std::tuple const &, + GenericBinarySPDecompositionTree const &> + tie() const { + return std::tie(*this->left_child_ptr, *this->right_child_ptr); + } + + friend std::hash; +}; + +template +struct GenericBinaryParallelSplit { +public: + GenericBinaryParallelSplit() = delete; + explicit GenericBinaryParallelSplit( + GenericBinarySPDecompositionTree const &lhs, + GenericBinarySPDecompositionTree const &rhs) + : left_child_ptr( + std::make_shared>(lhs)), + right_child_ptr( + std::make_shared>(rhs)) {} + + GenericBinaryParallelSplit(GenericBinaryParallelSplit const &) = default; + + bool operator==(GenericBinaryParallelSplit const &other) const { + return this->tie() == other.tie(); + } + + bool operator!=(GenericBinaryParallelSplit const &other) const { + return this->tie() != other.tie(); + } + + bool operator<(GenericBinaryParallelSplit const &other) const { + return this->tie() < other.tie(); + } + +public: + std::shared_ptr> left_child_ptr; + std::shared_ptr> right_child_ptr; + +private: + std::tuple const &, + GenericBinarySPDecompositionTree const &> + tie() const { + return std::tie(*this->left_child_ptr, *this->right_child_ptr); + } + + friend std::hash; +}; + +template +struct GenericBinarySPDecompositionTree { +public: + GenericBinarySPDecompositionTree() = delete; + explicit GenericBinarySPDecompositionTree( + GenericBinarySeriesSplit const &s) + : root{s} {} + + explicit GenericBinarySPDecompositionTree( + GenericBinaryParallelSplit const &s) + : root{s} {} + + explicit GenericBinarySPDecompositionTree(T const &t) : root{t} {} + + GenericBinarySPDecompositionTree(GenericBinarySPDecompositionTree const &) = + default; + + bool operator==(GenericBinarySPDecompositionTree const &other) const { + return this->tie() == other.tie(); + } + + bool operator!=(GenericBinarySPDecompositionTree const &other) const { + return this->tie() != other.tie(); + } + + bool operator<(GenericBinarySPDecompositionTree const &other) const { + return this->tie() < other.tie(); + } + +public: + std::variant, GenericBinaryParallelSplit, T> + root; + +private: + std::tuple tie() const { + return std::tie(this->root); + } + + friend std::hash; +}; + +} // namespace FlexFlow + +// namespace rc { +// +// template <> +// struct Arbitrary<::FlexFlow::BinarySeriesSplit> { +// static Gen<::FlexFlow::BinarySeriesSplit> arbitrary(); +// }; +// +// template <> +// struct Arbitrary<::FlexFlow::GenericBinaryParallelSplit> { +// static Gen<::FlexFlow::GenericBinaryParallelSplit> arbitrary(); +// }; +// +// template <> +// struct Arbitrary<::FlexFlow::GenericBinarySPDecompositionTree> { +// static Gen<::FlexFlow::GenericBinarySPDecompositionTree> arbitrary(); +// }; +// +// } // namespace rc + +#endif diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get.h b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get.h new file mode 100644 index 0000000000..c6c1186d3d --- /dev/null +++ b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get.h @@ -0,0 +1,15 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_DECOMPOSITION_TREE_GET_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_DECOMPOSITION_TREE_GET_H + +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree.h" + +namespace FlexFlow { + +template +TT const &get(GenericBinarySPDecompositionTree const &t) { + return std::get(t.root); +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_leaves.h b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_leaves.h new file mode 100644 index 0000000000..51e1e20bac --- /dev/null +++ b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_leaves.h @@ -0,0 +1,40 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_DECOMPOSITION_TREE_GET_LEAVES_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_DECOMPOSITION_TREE_GET_LEAVES_H + +#include "utils/containers/multiset_union.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_left_child.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_right_child.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/visit.h" +#include "utils/overload.h" +#include + +namespace FlexFlow { + +template +std::unordered_multiset + get_leaves(GenericBinarySPDecompositionTree const &tt) { + return visit>( + tt, + overload{ + [](T const &t) { return std::unordered_multiset{t}; }, + [](GenericBinarySeriesSplit const &s) { return get_leaves(s); }, + [](GenericBinaryParallelSplit const &p) { return get_leaves(p); }, + }); +} + +template +std::unordered_multiset get_leaves(GenericBinarySeriesSplit const &s) { + return multiset_union(get_leaves(get_left_child(s)), + get_leaves(get_right_child(s))); +} + +template +std::unordered_multiset get_leaves(GenericBinaryParallelSplit const &p) { + return multiset_union(get_leaves(get_left_child(p)), + get_leaves(get_right_child(p))); +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_left_child.h b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_left_child.h new file mode 100644 index 0000000000..46a460b64e --- /dev/null +++ b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_left_child.h @@ -0,0 +1,44 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_DECOMPOSITION_TREE_GET_LEFT_CHILD_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_DECOMPOSITION_TREE_GET_LEFT_CHILD_H + +#include "utils/exception.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/visit.h" +#include "utils/overload.h" + +namespace FlexFlow { + +template +GenericBinarySPDecompositionTree + get_left_child(GenericBinarySeriesSplit const &s) { + return *s.left_child_ptr; +} + +template +GenericBinarySPDecompositionTree + get_left_child(GenericBinaryParallelSplit const &p) { + return *p.left_child_ptr; +} + +template +GenericBinarySPDecompositionTree + get_left_child(GenericBinarySPDecompositionTree const &tt) { + return visit>( + tt, + overload{ + [](GenericBinarySeriesSplit const &s) { + return get_left_child(s); + }, + [](GenericBinaryParallelSplit const &p) { + return get_left_child(p); + }, + [](T const &t) -> GenericBinarySPDecompositionTree { + throw mk_runtime_error( + "get_left_child incorrectly called on leaf node"); + }, + }); +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_node_type.h b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_node_type.h new file mode 100644 index 0000000000..883acda480 --- /dev/null +++ b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_node_type.h @@ -0,0 +1,29 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_DECOMPOSITION_TREE_GET_NODE_TYPE_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_DECOMPOSITION_TREE_GET_NODE_TYPE_H + +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/visit.h" +#include "utils/graph/series_parallel/sp_decomposition_tree_node_type.dtg.h" +#include "utils/overload.h" + +namespace FlexFlow { + +template +SPDecompositionTreeNodeType + get_node_type(GenericBinarySPDecompositionTree const &tt) { + return visit( + tt, + overload{ + [](GenericBinarySeriesSplit const &) { + return SPDecompositionTreeNodeType::SERIES; + }, + [](GenericBinaryParallelSplit const &) { + return SPDecompositionTreeNodeType::PARALLEL; + }, + [](T const &) { return SPDecompositionTreeNodeType::NODE; }, + }); +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_num_tree_nodes.h b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_num_tree_nodes.h new file mode 100644 index 0000000000..7c6d28d7b4 --- /dev/null +++ b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_num_tree_nodes.h @@ -0,0 +1,40 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GET_NUM_TREE_NODES_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GET_NUM_TREE_NODES_H + +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_left_child.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_right_child.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/visit.h" +#include "utils/overload.h" + +namespace FlexFlow { + +template +int get_num_tree_nodes(GenericBinarySPDecompositionTree const &tt) { + return visit(tt, + overload{ + [](T const &t) { return 1; }, + [](GenericBinarySeriesSplit const &s) { + return get_num_tree_nodes(s); + }, + [](GenericBinaryParallelSplit const &p) { + return get_num_tree_nodes(p); + }, + }); +} + +template +int get_num_tree_nodes(GenericBinarySeriesSplit const &s) { + return 1 + get_num_tree_nodes(get_left_child(s)) + + get_num_tree_nodes(get_right_child(s)); +} + +template +int get_num_tree_nodes(GenericBinaryParallelSplit const &p) { + return 1 + get_num_tree_nodes(get_left_child(p)) + + get_num_tree_nodes(get_right_child(p)); +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_right_child.h b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_right_child.h new file mode 100644 index 0000000000..f0bfba43a2 --- /dev/null +++ b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_right_child.h @@ -0,0 +1,44 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_DECOMPOSITION_TREE_GET_RIGHT_CHILD_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_DECOMPOSITION_TREE_GET_RIGHT_CHILD_H + +#include "utils/exception.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/visit.h" +#include "utils/overload.h" + +namespace FlexFlow { + +template +GenericBinarySPDecompositionTree + get_right_child(GenericBinarySeriesSplit const &s) { + return *s.right_child_ptr; +} + +template +GenericBinarySPDecompositionTree + get_right_child(GenericBinaryParallelSplit const &p) { + return *p.right_child_ptr; +} + +template +GenericBinarySPDecompositionTree + get_right_child(GenericBinarySPDecompositionTree const &tt) { + return visit>( + tt, + overload{ + [](GenericBinarySeriesSplit const &s) { + return get_right_child(s); + }, + [](GenericBinaryParallelSplit const &p) { + return get_right_child(p); + }, + [](T const &t) -> GenericBinarySPDecompositionTree { + throw mk_runtime_error( + "get_right_child incorrectly called on leaf node"); + }, + }); +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/hash.h b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/hash.h new file mode 100644 index 0000000000..983dc4a572 --- /dev/null +++ b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/hash.h @@ -0,0 +1,34 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_DECOMPOSITION_TREE_HASH_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_DECOMPOSITION_TREE_HASH_H + +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree.h" +#include "utils/hash-utils.h" +#include "utils/hash/tuple.h" + +namespace std { + +template +struct hash<::FlexFlow::GenericBinarySeriesSplit> { + size_t operator()(::FlexFlow::GenericBinarySeriesSplit const &s) const { + return get_std_hash(s.tie()); + } +}; + +template +struct hash<::FlexFlow::GenericBinaryParallelSplit> { + size_t operator()(::FlexFlow::GenericBinaryParallelSplit const &s) const { + return get_std_hash(s.tie()); + } +}; + +template +struct hash<::FlexFlow::GenericBinarySPDecompositionTree> { + size_t operator()( + ::FlexFlow::GenericBinarySPDecompositionTree const &s) const { + return get_std_hash(s.tie()); + } +}; + +} // namespace std + +#endif diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is.h b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is.h new file mode 100644 index 0000000000..8086f38244 --- /dev/null +++ b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is.h @@ -0,0 +1,25 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_DECOMPOSITION_TREE_IS_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_DECOMPOSITION_TREE_IS_H + +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree.h" + +namespace FlexFlow { + +template +bool is_series_split(GenericBinarySPDecompositionTree const &t) { + return std::holds_alternative>(t.root); +} + +template +bool is_parallel_split(GenericBinarySPDecompositionTree const &t) { + return std::holds_alternative>(t.root); +} + +template +bool is_leaf(GenericBinarySPDecompositionTree const &t) { + return std::holds_alternative(t.root); +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_left_associative.h b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_left_associative.h new file mode 100644 index 0000000000..3ffa63753a --- /dev/null +++ b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_left_associative.h @@ -0,0 +1,34 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_DECOMPOSITION_TREE_IS_BINARY_SP_TREE_LEFT_ASSOCIATIVE_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_DECOMPOSITION_TREE_IS_BINARY_SP_TREE_LEFT_ASSOCIATIVE_H + +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_left_child.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_right_child.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/visit.h" + +namespace FlexFlow { + +template +bool is_binary_sp_tree_left_associative( + GenericBinarySPDecompositionTree const &tt) { + return visit( + tt, + overload{ + [](T const &) { return true; }, + [](GenericBinarySeriesSplit const &s) { + return !is_series_split(get_right_child(s)) && + is_binary_sp_tree_left_associative(get_left_child(s)) && + is_binary_sp_tree_left_associative(get_right_child(s)); + }, + [](GenericBinaryParallelSplit const &p) { + return !is_parallel_split(get_right_child(p)) && + is_binary_sp_tree_left_associative(get_left_child(p)) && + is_binary_sp_tree_left_associative(get_right_child(p)); + }, + }); +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_right_associative.h b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_right_associative.h new file mode 100644 index 0000000000..d88459b432 --- /dev/null +++ b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_right_associative.h @@ -0,0 +1,34 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_DECOMPOSITION_TREE_IS_BINARY_SP_TREE_RIGHT_ASSOCIATIVE_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_DECOMPOSITION_TREE_IS_BINARY_SP_TREE_RIGHT_ASSOCIATIVE_H + +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_left_child.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_right_child.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/visit.h" + +namespace FlexFlow { + +template +bool is_binary_sp_tree_right_associative( + GenericBinarySPDecompositionTree const &tt) { + return visit( + tt, + overload{ + [](T const &t) { return true; }, + [](GenericBinarySeriesSplit const &s) { + return !is_series_split(get_left_child(s)) && + is_binary_sp_tree_right_associative(get_left_child(s)) && + is_binary_sp_tree_right_associative(get_right_child(s)); + }, + [](GenericBinaryParallelSplit const &p) { + return !is_parallel_split(get_left_child(p)) && + is_binary_sp_tree_right_associative(get_left_child(p)) && + is_binary_sp_tree_right_associative(get_right_child(p)); + }, + }); +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/json.h b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/json.h new file mode 100644 index 0000000000..4f1f8266e1 --- /dev/null +++ b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/json.h @@ -0,0 +1,103 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_DECOMPOSITION_TREE_JSON_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_DECOMPOSITION_TREE_JSON_H + +#include "utils/exception.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_left_child.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_right_child.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/visit.h" +#include + +namespace nlohmann { + +template +struct adl_serializer<::FlexFlow::GenericBinarySeriesSplit> { + static ::FlexFlow::GenericBinarySeriesSplit from_json(json const &j) { + return ::FlexFlow::GenericBinarySeriesSplit{ + j.at("left_child") + .template get<::FlexFlow::GenericBinarySPDecompositionTree>(), + j.at("right_child") + .template get<::FlexFlow::GenericBinarySPDecompositionTree>(), + }; + } + + static void to_json(json &j, + ::FlexFlow::GenericBinarySeriesSplit const &v) { + j["__type"] = "GenericBinarySeriesSplit"; + j["left_child"] = get_left_child(v); + j["right_child"] = get_right_child(v); + } +}; + +template +struct adl_serializer<::FlexFlow::GenericBinaryParallelSplit> { + static ::FlexFlow::GenericBinaryParallelSplit from_json(json const &j) { + return ::FlexFlow::GenericBinaryParallelSplit{ + j.at("left_child") + .template get<::FlexFlow::GenericBinarySPDecompositionTree>(), + j.at("right_child") + .template get<::FlexFlow::GenericBinarySPDecompositionTree>(), + }; + } + + static void to_json(json &j, + ::FlexFlow::GenericBinaryParallelSplit const &v) { + j["__type"] = "GenericBinaryParallelSplit"; + j["left_child"] = get_left_child(v); + j["right_child"] = get_right_child(v); + } +}; + +template +struct adl_serializer<::FlexFlow::GenericBinarySPDecompositionTree> { + static ::FlexFlow::GenericBinarySPDecompositionTree + from_json(json const &j) { + std::string key = j.at("type").get(); + + if (key == "series") { + return ::FlexFlow::GenericBinarySPDecompositionTree{ + j.at("value").get<::FlexFlow::GenericBinarySeriesSplit>(), + }; + } else if (key == "parallel") { + return ::FlexFlow::GenericBinarySPDecompositionTree{ + j.at("value").get<::FlexFlow::GenericBinaryParallelSplit>(), + }; + } else if (key == "leaf") { + return ::FlexFlow::GenericBinarySPDecompositionTree{ + j.at("value").get(), + }; + } else { + throw ::FlexFlow::mk_runtime_error( + fmt::format("Unknown json type key: {}", key)); + } + } + + static void + to_json(json &j, + ::FlexFlow::GenericBinarySPDecompositionTree const &v) { + j["__type"] = "GenericBinarySPDecompositionTree"; + ::FlexFlow::visit( + v, + ::FlexFlow::overload{ + [&](::FlexFlow::GenericBinarySeriesSplit const &s) { + j["type"] = "series"; + j["value"] = s; + return std::monostate{}; + }, + [&](::FlexFlow::GenericBinaryParallelSplit const &p) { + j["type"] = "parallel"; + j["value"] = p; + return std::monostate{}; + }, + [&](T const &t) { + j["type"] = "leaf"; + j["value"] = t; + return std::monostate{}; + }, + }); + } +}; + +} // namespace nlohmann + +#endif diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/make.h b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/make.h new file mode 100644 index 0000000000..f55b71146a --- /dev/null +++ b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/make.h @@ -0,0 +1,39 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_DECOMPOSITION_TREE_MAKE_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_DECOMPOSITION_TREE_MAKE_H + +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree.h" + +namespace FlexFlow { + +template +GenericBinarySPDecompositionTree make_generic_binary_series_split( + GenericBinarySPDecompositionTree const &lhs, + GenericBinarySPDecompositionTree const &rhs) { + return GenericBinarySPDecompositionTree{ + GenericBinarySeriesSplit{ + lhs, + rhs, + }, + }; +} + +template +GenericBinarySPDecompositionTree make_generic_binary_parallel_split( + GenericBinarySPDecompositionTree const &lhs, + GenericBinarySPDecompositionTree const &rhs) { + return GenericBinarySPDecompositionTree{ + GenericBinaryParallelSplit{ + lhs, + rhs, + }, + }; +} + +template +GenericBinarySPDecompositionTree make_generic_binary_sp_leaf(T const &t) { + return GenericBinarySPDecompositionTree{t}; +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/require.h b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/require.h new file mode 100644 index 0000000000..a8de1ee8f8 --- /dev/null +++ b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/require.h @@ -0,0 +1,28 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_DECOMPOSITION_TREE_REQUIRE_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_DECOMPOSITION_TREE_REQUIRE_H + +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get.h" + +namespace FlexFlow { + +template +GenericBinarySeriesSplit const & + require_series(GenericBinarySPDecompositionTree const &t) { + return get>(t); +} + +template +GenericBinaryParallelSplit const & + require_parallel(GenericBinarySPDecompositionTree const &t) { + return get>(t); +} + +template +T const &require_node(GenericBinarySPDecompositionTree const &t) { + return get(t); +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/transform.h b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/transform.h new file mode 100644 index 0000000000..4d7fa05960 --- /dev/null +++ b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/transform.h @@ -0,0 +1,43 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_DECOMPOSITION_TREE_TRANSFORM_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_DECOMPOSITION_TREE_TRANSFORM_H + +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_left_child.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_right_child.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/visit.h" + +namespace FlexFlow { + +template > +GenericBinarySPDecompositionTree + transform(GenericBinarySPDecompositionTree const &tt, F f) { + return visit>( + tt, + overload{ + [&](GenericBinarySeriesSplit const &s) { + return GenericBinarySPDecompositionTree{ + GenericBinarySeriesSplit{ + transform(get_left_child(s), f), + transform(get_right_child(s), f), + }, + }; + }, + [&](GenericBinaryParallelSplit const &s) { + return GenericBinarySPDecompositionTree{ + GenericBinaryParallelSplit{ + transform(get_left_child(s), f), + transform(get_right_child(s), f), + }, + }; + }, + [&](T const &t) { + return GenericBinarySPDecompositionTree{ + f(t), + }; + }, + }); +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/visit.h b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/visit.h new file mode 100644 index 0000000000..0d9503e59f --- /dev/null +++ b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/visit.h @@ -0,0 +1,37 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_DECOMPOSITION_TREE_VISIT_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_DECOMPOSITION_TREE_VISIT_H + +#include "utils/exception.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree.h" + +namespace FlexFlow { + +template +Result visit(GenericBinarySPDecompositionTree const &tt, F f) { + if (std::holds_alternative>(tt.root)) { + return f(std::get>(tt.root)); + } else if (std::holds_alternative>(tt.root)) { + return f(std::get>(tt.root)); + } else if (std::holds_alternative(tt.root)) { + return f(std::get(tt.root)); + } else { + throw mk_runtime_error( + "Unexpected case in visit(GenericBinarySPDecompositionTree)"); + } + + // return std::visit(tt.root, overload { + // [&](GenericBinarySeriesSplit const &s) -> Result { + // return f(s); + // }, + // [&](GenericBinaryParallelSplit const &p) -> Result { + // return f(p); + // }, + // [&](T const &t) -> Result { + // return f(t); + // }, + // }); +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/left_associative_binary_sp_tree_from_nary.h b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/left_associative_binary_sp_tree_from_nary.h new file mode 100644 index 0000000000..183ece3a89 --- /dev/null +++ b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/left_associative_binary_sp_tree_from_nary.h @@ -0,0 +1,14 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_LEFT_ASSOCIATIVE_BINARY_SP_TREE_FROM_NARY_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_LEFT_ASSOCIATIVE_BINARY_SP_TREE_FROM_NARY_H + +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.dtg.h" +#include "utils/graph/series_parallel/series_parallel_decomposition.dtg.h" + +namespace FlexFlow { + +BinarySPDecompositionTree left_associative_binary_sp_tree_from_nary( + SeriesParallelDecomposition const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/nary_sp_tree_from_binary.h b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/nary_sp_tree_from_binary.h new file mode 100644 index 0000000000..f5174aee56 --- /dev/null +++ b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/nary_sp_tree_from_binary.h @@ -0,0 +1,14 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_NARY_SP_TREE_FROM_BINARY_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_NARY_SP_TREE_FROM_BINARY_H + +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.dtg.h" +#include "utils/graph/series_parallel/series_parallel_decomposition.dtg.h" + +namespace FlexFlow { + +SeriesParallelDecomposition + nary_sp_tree_from_binary(BinarySPDecompositionTree const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/right_associative_binary_sp_tree_from_nary.h b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/right_associative_binary_sp_tree_from_nary.h new file mode 100644 index 0000000000..e01ec0bdde --- /dev/null +++ b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/right_associative_binary_sp_tree_from_nary.h @@ -0,0 +1,14 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_RIGHT_ASSOCIATIVE_BINARY_SP_TREE_FROM_NARY_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_RIGHT_ASSOCIATIVE_BINARY_SP_TREE_FROM_NARY_H + +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.dtg.h" +#include "utils/graph/series_parallel/series_parallel_decomposition.dtg.h" + +namespace FlexFlow { + +BinarySPDecompositionTree right_associative_binary_sp_tree_from_nary( + SeriesParallelDecomposition const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/series_parallel/get_series_parallel_decomposition.h b/lib/utils/include/utils/graph/series_parallel/get_series_parallel_decomposition.h new file mode 100644 index 0000000000..f2a006d899 --- /dev/null +++ b/lib/utils/include/utils/graph/series_parallel/get_series_parallel_decomposition.h @@ -0,0 +1,17 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_GET_SERIES_PARALLEL_DECOMPOSITION_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_GET_SERIES_PARALLEL_DECOMPOSITION_H + +#include "utils/graph/digraph/digraph.h" +#include "utils/graph/series_parallel/series_parallel_decomposition.dtg.h" +#include "utils/optional.h" +#include +#include + +namespace FlexFlow { + +std::optional + get_series_parallel_decomposition(DiGraphView const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/serial_parallel/graph_generation.h b/lib/utils/include/utils/graph/series_parallel/graph_generation.h similarity index 56% rename from lib/utils/include/utils/graph/serial_parallel/graph_generation.h rename to lib/utils/include/utils/graph/series_parallel/graph_generation.h index fac9c98db2..f18fd63d24 100644 --- a/lib/utils/include/utils/graph/serial_parallel/graph_generation.h +++ b/lib/utils/include/utils/graph/series_parallel/graph_generation.h @@ -1,23 +1,23 @@ -#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIAL_PARALLEL_GRAPH_GENERATION_H -#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIAL_PARALLEL_GRAPH_GENERATION_H +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_GRAPH_GENERATION_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_GRAPH_GENERATION_H #include "utils/graph/dataflow_graph/dataflow_graph.h" -#include "utils/graph/serial_parallel/serial_parallel_decomposition.dtg.h" +#include "utils/graph/series_parallel/series_parallel_decomposition.dtg.h" namespace FlexFlow { void parallel_extend_unsafe(DataflowGraph &g, DataflowGraphView const &ext); -void serial_extend(DataflowGraph &g, DataflowGraphView const &ext); +void series_extend(DataflowGraph &g, DataflowGraphView const &ext); -DataflowGraph serial_composition(DataflowGraphView const &g1, +DataflowGraph series_composition(DataflowGraphView const &g1, DataflowGraphView const &g2); DataflowGraph parallel_composition(DataflowGraphView const &g1, DataflowGraphView const &g2); DataflowGraph dataflow_graph_from_sp_decomposition( - SerialParallelDecomposition const &sp_decomposition); + SeriesParallelDecomposition const &sp_decomposition); } // namespace FlexFlow diff --git a/lib/utils/include/utils/graph/series_parallel/intermediate_sp_decomposition_tree.h b/lib/utils/include/utils/graph/series_parallel/intermediate_sp_decomposition_tree.h new file mode 100644 index 0000000000..1283a6df3a --- /dev/null +++ b/lib/utils/include/utils/graph/series_parallel/intermediate_sp_decomposition_tree.h @@ -0,0 +1,17 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_INTERMEDIATE_SP_DECOMPOSITION_TREE_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_INTERMEDIATE_SP_DECOMPOSITION_TREE_H + +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.dtg.h" +#include "utils/graph/series_parallel/intermediate_sp_decomposition_tree.dtg.h" + +namespace FlexFlow { + +std::variant + flatten_ast(std::variant const &ast); + +std::variant + from_binary_sp_tree(BinarySPDecompositionTree const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/serial_parallel/intermediate_sp_decomposition_tree.struct.toml b/lib/utils/include/utils/graph/series_parallel/intermediate_sp_decomposition_tree.struct.toml similarity index 90% rename from lib/utils/include/utils/graph/serial_parallel/intermediate_sp_decomposition_tree.struct.toml rename to lib/utils/include/utils/graph/series_parallel/intermediate_sp_decomposition_tree.struct.toml index 08f03ed12a..e7666fcd3f 100644 --- a/lib/utils/include/utils/graph/serial_parallel/intermediate_sp_decomposition_tree.struct.toml +++ b/lib/utils/include/utils/graph/series_parallel/intermediate_sp_decomposition_tree.struct.toml @@ -8,7 +8,7 @@ features = [ ] includes = [ - "utils/graph/serial_parallel/split_type.dtg.h", + "utils/graph/series_parallel/split_type.dtg.h", "", "", "utils/graph/node/node.dtg.h", diff --git a/lib/utils/include/utils/graph/serial_parallel/parallel_reduction.h b/lib/utils/include/utils/graph/series_parallel/parallel_reduction.h similarity index 70% rename from lib/utils/include/utils/graph/serial_parallel/parallel_reduction.h rename to lib/utils/include/utils/graph/series_parallel/parallel_reduction.h index 71cc5e3998..3fc1347ee5 100644 --- a/lib/utils/include/utils/graph/serial_parallel/parallel_reduction.h +++ b/lib/utils/include/utils/graph/series_parallel/parallel_reduction.h @@ -1,8 +1,8 @@ -#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIAL_PARALLEL_PARALLEL_REDUCTION_H -#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIAL_PARALLEL_PARALLEL_REDUCTION_H +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_PARALLEL_REDUCTION_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_PARALLEL_REDUCTION_H #include "utils/graph/multidigraph/multidigraph.h" -#include "utils/graph/serial_parallel/parallel_reduction.dtg.h" +#include "utils/graph/series_parallel/parallel_reduction.dtg.h" #include namespace FlexFlow { diff --git a/lib/utils/include/utils/graph/serial_parallel/parallel_reduction.struct.toml b/lib/utils/include/utils/graph/series_parallel/parallel_reduction.struct.toml similarity index 100% rename from lib/utils/include/utils/graph/serial_parallel/parallel_reduction.struct.toml rename to lib/utils/include/utils/graph/series_parallel/parallel_reduction.struct.toml diff --git a/lib/utils/include/utils/graph/series_parallel/series_parallel_decomposition.h b/lib/utils/include/utils/graph/series_parallel/series_parallel_decomposition.h new file mode 100644 index 0000000000..52d2cb7236 --- /dev/null +++ b/lib/utils/include/utils/graph/series_parallel/series_parallel_decomposition.h @@ -0,0 +1,22 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_SERIES_PARALLEL_DECOMPOSITION_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_SERIES_PARALLEL_DECOMPOSITION_H + +#include "utils/graph/series_parallel/intermediate_sp_decomposition_tree.dtg.h" +#include "utils/graph/series_parallel/series_parallel_decomposition.dtg.h" +#include + +namespace FlexFlow { + +std::variant internal_to_final_ast( + std::variant const &ast); +SeriesParallelDecomposition + to_final_ast(std::variant const &); + +std::unordered_multiset get_nodes(SeriesParallelDecomposition const &sp); +std::unordered_multiset get_nodes(SeriesSplit const &); +std::unordered_multiset get_nodes(ParallelSplit const &); +std::unordered_multiset get_nodes(Node const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/serial_parallel/serial_parallel_decomposition.variant.toml b/lib/utils/include/utils/graph/series_parallel/series_parallel_decomposition.variant.toml similarity index 62% rename from lib/utils/include/utils/graph/serial_parallel/serial_parallel_decomposition.variant.toml rename to lib/utils/include/utils/graph/series_parallel/series_parallel_decomposition.variant.toml index f816abfbb4..921499ebd1 100644 --- a/lib/utils/include/utils/graph/serial_parallel/serial_parallel_decomposition.variant.toml +++ b/lib/utils/include/utils/graph/series_parallel/series_parallel_decomposition.variant.toml @@ -1,5 +1,5 @@ namespace = "FlexFlow" -name = "SerialParallelDecomposition" +name = "SeriesParallelDecomposition" features = [ "eq", "hash", @@ -7,12 +7,12 @@ features = [ ] includes = [ - "utils/graph/serial_parallel/serial_parallel_splits.h", + "utils/graph/series_parallel/series_parallel_splits.h", "utils/graph/node/node.dtg.h", ] [[values]] -type = "::FlexFlow::SerialSplit" +type = "::FlexFlow::SeriesSplit" [[values]] type = "::FlexFlow::ParallelSplit" diff --git a/lib/utils/include/utils/graph/serial_parallel/serial_parallel_splits.h b/lib/utils/include/utils/graph/series_parallel/series_parallel_splits.h similarity index 59% rename from lib/utils/include/utils/graph/serial_parallel/serial_parallel_splits.h rename to lib/utils/include/utils/graph/series_parallel/series_parallel_splits.h index 081137e513..18434d2b67 100644 --- a/lib/utils/include/utils/graph/serial_parallel/serial_parallel_splits.h +++ b/lib/utils/include/utils/graph/series_parallel/series_parallel_splits.h @@ -1,5 +1,5 @@ -#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIAL_PARALLEL_FLATTENED_DECOMPOSITION_TREE_H -#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIAL_PARALLEL_FLATTENED_DECOMPOSITION_TREE_H +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_FLATTENED_DECOMPOSITION_TREE_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_FLATTENED_DECOMPOSITION_TREE_H #include "utils/graph/node/node.dtg.h" #include @@ -7,18 +7,18 @@ namespace FlexFlow { -struct SerialSplit; +struct SeriesSplit; struct ParallelSplit; -struct SerialSplit { +struct SeriesSplit { public: - SerialSplit() = delete; - explicit SerialSplit(std::vector> const &); - explicit SerialSplit( + SeriesSplit() = delete; + explicit SeriesSplit(std::vector> const &); + explicit SeriesSplit( std::initializer_list> const &); - bool operator==(SerialSplit const &) const; - bool operator!=(SerialSplit const &) const; + bool operator==(SeriesSplit const &) const; + bool operator!=(SeriesSplit const &) const; public: std::vector> children; @@ -28,16 +28,16 @@ struct SerialSplit { Tie tie() const; }; -std::string format_as(SerialSplit const &); -std::ostream &operator<<(std::ostream &, SerialSplit const &); +std::string format_as(SeriesSplit const &); +std::ostream &operator<<(std::ostream &, SeriesSplit const &); } // namespace FlexFlow namespace std { template <> -struct hash<::FlexFlow::SerialSplit> { - size_t operator()(::FlexFlow::SerialSplit const &) const; +struct hash<::FlexFlow::SeriesSplit> { + size_t operator()(::FlexFlow::SeriesSplit const &) const; }; } // namespace std @@ -48,15 +48,15 @@ struct ParallelSplit { public: ParallelSplit() = delete; explicit ParallelSplit( - std::unordered_set> const &); + std::unordered_multiset> const &); explicit ParallelSplit( - std::initializer_list> const &); + std::initializer_list> const &); bool operator==(ParallelSplit const &) const; bool operator!=(ParallelSplit const &) const; public: - std::unordered_set> children; + std::unordered_multiset> children; private: using Tie = std::tuple; diff --git a/lib/utils/include/utils/graph/serial_parallel/series_reduction.h b/lib/utils/include/utils/graph/series_parallel/series_reduction.h similarity index 77% rename from lib/utils/include/utils/graph/serial_parallel/series_reduction.h rename to lib/utils/include/utils/graph/series_parallel/series_reduction.h index c9bae58546..a7d53fecfc 100644 --- a/lib/utils/include/utils/graph/serial_parallel/series_reduction.h +++ b/lib/utils/include/utils/graph/series_parallel/series_reduction.h @@ -1,9 +1,9 @@ -#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIAL_PARALLEL_SERIES_REDUCTION_H -#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIAL_PARALLEL_SERIES_REDUCTION_H +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_SERIES_REDUCTION_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_SERIES_REDUCTION_H #include "utils/graph/multidigraph/multidiedge.dtg.h" #include "utils/graph/multidigraph/multidigraph.h" -#include "utils/graph/serial_parallel/series_reduction.dtg.h" +#include "utils/graph/series_parallel/series_reduction.dtg.h" namespace FlexFlow { diff --git a/lib/utils/include/utils/graph/serial_parallel/series_reduction.struct.toml b/lib/utils/include/utils/graph/series_parallel/series_reduction.struct.toml similarity index 100% rename from lib/utils/include/utils/graph/serial_parallel/series_reduction.struct.toml rename to lib/utils/include/utils/graph/series_parallel/series_reduction.struct.toml diff --git a/lib/utils/include/utils/graph/serial_parallel/sink_settings.enum.toml b/lib/utils/include/utils/graph/series_parallel/sink_settings.enum.toml similarity index 100% rename from lib/utils/include/utils/graph/serial_parallel/sink_settings.enum.toml rename to lib/utils/include/utils/graph/series_parallel/sink_settings.enum.toml diff --git a/lib/utils/include/utils/graph/serial_parallel/source_settings.enum.toml b/lib/utils/include/utils/graph/series_parallel/source_settings.enum.toml similarity index 100% rename from lib/utils/include/utils/graph/serial_parallel/source_settings.enum.toml rename to lib/utils/include/utils/graph/series_parallel/source_settings.enum.toml diff --git a/lib/utils/include/utils/graph/series_parallel/sp_decomposition_tree_node_type.enum.toml b/lib/utils/include/utils/graph/series_parallel/sp_decomposition_tree_node_type.enum.toml new file mode 100644 index 0000000000..2050800cbd --- /dev/null +++ b/lib/utils/include/utils/graph/series_parallel/sp_decomposition_tree_node_type.enum.toml @@ -0,0 +1,17 @@ +namespace = "FlexFlow" +name = "SPDecompositionTreeNodeType" +features = [ + "hash", + "fmt", + "rapidcheck", + "json", +] + +[[values]] +name = "SERIES" + +[[values]] +name = "PARALLEL" + +[[values]] +name = "NODE" diff --git a/lib/utils/include/utils/graph/serial_parallel/split_type.enum.toml b/lib/utils/include/utils/graph/series_parallel/split_type.enum.toml similarity index 90% rename from lib/utils/include/utils/graph/serial_parallel/split_type.enum.toml rename to lib/utils/include/utils/graph/series_parallel/split_type.enum.toml index 96d85f0e12..c1a1cb5978 100644 --- a/lib/utils/include/utils/graph/serial_parallel/split_type.enum.toml +++ b/lib/utils/include/utils/graph/series_parallel/split_type.enum.toml @@ -8,7 +8,7 @@ features = [ ] [[values]] -name = "SERIAL" +name = "SERIES" [[values]] name = "PARALLEL" diff --git a/lib/utils/include/utils/graph/undirected/algorithms/get_edges.h b/lib/utils/include/utils/graph/undirected/algorithms/get_edges.h new file mode 100644 index 0000000000..3e951b1db1 --- /dev/null +++ b/lib/utils/include/utils/graph/undirected/algorithms/get_edges.h @@ -0,0 +1,12 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_UNDIRECTED_ALGORITHMS_GET_EDGES_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_UNDIRECTED_ALGORITHMS_GET_EDGES_H + +#include "utils/graph/undirected/undirected_graph_view.h" + +namespace FlexFlow { + +std::unordered_set get_edges(UndirectedGraphView const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/undirected/algorithms/get_neighboring_nodes.h b/lib/utils/include/utils/graph/undirected/algorithms/get_neighboring_nodes.h new file mode 100644 index 0000000000..bc605360d2 --- /dev/null +++ b/lib/utils/include/utils/graph/undirected/algorithms/get_neighboring_nodes.h @@ -0,0 +1,13 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_UNDIRECTED_ALGORITHMS_GET_NEIGHBORING_NODES_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_UNDIRECTED_ALGORITHMS_GET_NEIGHBORING_NODES_H + +#include "utils/graph/undirected/undirected_graph_view.h" + +namespace FlexFlow { + +std::unordered_set get_neighboring_nodes(UndirectedGraphView const &, + Node const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/undirected/i_undirected_graph.h b/lib/utils/include/utils/graph/undirected/i_undirected_graph.h index 1662ec6d8c..4761275031 100644 --- a/lib/utils/include/utils/graph/undirected/i_undirected_graph.h +++ b/lib/utils/include/utils/graph/undirected/i_undirected_graph.h @@ -15,7 +15,7 @@ struct IUndirectedGraph : public IUndirectedGraphView { virtual std::unordered_set query_nodes(NodeQuery const &query) const = 0; - virtual IUndirectedGraph *clone() const override = 0; + virtual IUndirectedGraph *clone() const = 0; }; } // namespace FlexFlow diff --git a/lib/utils/include/utils/graph/undirected/undirected_edge_query.h b/lib/utils/include/utils/graph/undirected/undirected_edge_query.h index 9aa0f189ec..65939acc87 100644 --- a/lib/utils/include/utils/graph/undirected/undirected_edge_query.h +++ b/lib/utils/include/utils/graph/undirected/undirected_edge_query.h @@ -1,11 +1,13 @@ #ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_UNDIRECTED_UNDIRECTED_EDGE_QUERY_H #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_UNDIRECTED_UNDIRECTED_EDGE_QUERY_H +#include "utils/graph/undirected/undirected_edge.h" #include "utils/graph/undirected/undirected_edge_query.dtg.h" namespace FlexFlow { UndirectedEdgeQuery undirected_edge_query_all(); +bool matches_edge(UndirectedEdgeQuery const &, UndirectedEdge const &); UndirectedEdgeQuery query_intersection(UndirectedEdgeQuery const &, UndirectedEdgeQuery const &); diff --git a/lib/utils/include/utils/hash/multiset.h b/lib/utils/include/utils/hash/multiset.h new file mode 100644 index 0000000000..4695b89165 --- /dev/null +++ b/lib/utils/include/utils/hash/multiset.h @@ -0,0 +1,20 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_HASH_MULTISET_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_HASH_MULTISET_H + +#include "utils/hash-utils.h" +#include + +namespace std { + +template +struct hash> { + size_t operator()(std::multiset const &s) const { + size_t result = 0; + ::FlexFlow::unordered_container_hash(result, s); + return result; + } +}; + +} // namespace std + +#endif diff --git a/lib/utils/include/utils/hash/unordered_multiset.h b/lib/utils/include/utils/hash/unordered_multiset.h new file mode 100644 index 0000000000..b19c76bfef --- /dev/null +++ b/lib/utils/include/utils/hash/unordered_multiset.h @@ -0,0 +1,20 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_HASH_UNORDERED_MULTISET_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_HASH_UNORDERED_MULTISET_H + +#include "utils/hash-utils.h" +#include + +namespace std { + +template +struct hash> { + size_t operator()(std::unordered_multiset const &s) const { + size_t result = 0; + ::FlexFlow::unordered_container_hash(result, s); + return result; + } +}; + +} // namespace std + +#endif diff --git a/lib/utils/include/utils/json/check_is_jsonable.h b/lib/utils/include/utils/json/check_is_jsonable.h new file mode 100644 index 0000000000..41a64a1b83 --- /dev/null +++ b/lib/utils/include/utils/json/check_is_jsonable.h @@ -0,0 +1,17 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_JSON_CHECK_IS_JSONABLE_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_JSON_CHECK_IS_JSONABLE_H + +#include "utils/json/is_json_deserializable.h" +#include "utils/json/is_json_serializable.h" + +namespace FlexFlow { + +#define CHECK_IS_JSONABLE(TYPENAME) \ + static_assert(is_json_serializable::value, \ + #TYPENAME " should be json serializeable"); \ + static_assert(is_json_deserializable::value, \ + #TYPENAME " should be json deserializeable") + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/json/is_json_deserializable.h b/lib/utils/include/utils/json/is_json_deserializable.h new file mode 100644 index 0000000000..9e6625428b --- /dev/null +++ b/lib/utils/include/utils/json/is_json_deserializable.h @@ -0,0 +1,25 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_JSON_IS_JSON_DESERIALIZABLE_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_JSON_IS_JSON_DESERIALIZABLE_H + +#include "utils/type_traits_core.h" +#include +#include + +namespace FlexFlow { + +template +struct is_json_deserializable : std::false_type {}; + +template +struct is_json_deserializable< + T, + void_t().get())>> + : std::true_type {}; + +template +inline constexpr bool is_json_deserializable_v = + is_json_deserializable::value; + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/json/is_json_serializable.h b/lib/utils/include/utils/json/is_json_serializable.h new file mode 100644 index 0000000000..926a8037d4 --- /dev/null +++ b/lib/utils/include/utils/json/is_json_serializable.h @@ -0,0 +1,24 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_JSON_IS_JSON_SERIALIZABLE_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_JSON_IS_JSON_SERIALIZABLE_H + +#include "utils/type_traits_core.h" +#include +#include + +namespace FlexFlow { + +template +struct is_json_serializable : std::false_type {}; + +template +struct is_json_serializable< + T, + void_t() = std::declval())>> + : std::true_type {}; + +template +inline constexpr bool is_json_serializable_v = is_json_serializable::value; + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/json/is_jsonable.h b/lib/utils/include/utils/json/is_jsonable.h new file mode 100644 index 0000000000..2c8c103650 --- /dev/null +++ b/lib/utils/include/utils/json/is_jsonable.h @@ -0,0 +1,18 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_JSON_IS_JSONABLE_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_JSON_IS_JSONABLE_H + +#include "utils/json/is_json_deserializable.h" +#include "utils/json/is_json_serializable.h" + +namespace FlexFlow { + +template +struct is_jsonable + : std::conjunction, is_json_deserializable> {}; + +template +inline constexpr bool is_jsonable_v = is_jsonable::value; + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/json/optional.h b/lib/utils/include/utils/json/optional.h new file mode 100644 index 0000000000..c88dd24a15 --- /dev/null +++ b/lib/utils/include/utils/json/optional.h @@ -0,0 +1,33 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_JSON_OPTIONAL_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_JSON_OPTIONAL_H + +#include "utils/json/is_jsonable.h" +#include +#include + +namespace nlohmann { + +template +struct adl_serializer< + std::optional, + typename std::enable_if<::FlexFlow::is_jsonable::value>::type> { + static void to_json(json &j, std::optional const &t) { + if (t.has_value()) { + j = t.value(); + } else { + j = nullptr; + } + } + + static void from_json(json const &j, std::optional &t) { + if (j == nullptr) { + t = std::nullopt; + } else { + t = j.get(); + } + } +}; + +} // namespace nlohmann + +#endif diff --git a/lib/utils/include/utils/json/variant.h b/lib/utils/include/utils/json/variant.h new file mode 100644 index 0000000000..fe2c3f3b6c --- /dev/null +++ b/lib/utils/include/utils/json/variant.h @@ -0,0 +1,89 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_JSON_VARIANT_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_JSON_VARIANT_H + +#include "utils/json/is_jsonable.h" +#include + +namespace FlexFlow { + +struct VariantToJsonFunctor { + VariantToJsonFunctor(nlohmann::json &j) : j(j) {} + + nlohmann::json &j; + + template + void operator()(T const &t) { + static_assert(is_jsonable::value, ""); + + j = t; + } +}; + +template +void variant_to_json(json &j, std::variant const &v) { + json jval; + visit(::FlexFlow::VariantToJsonFunctor{jval}, v); + j["value"] = jval; + j["index"] = v.index(); +} + +template +std::optional variant_from_json_impl(json const &j) { + using Type = typename std::variant_alternative::type; + + if (j.at("index").get() == Idx) { + return j.at("value").get(); + } + return std::nullopt; +} + +template +std::optional variant_from_json_impl(json const &j, + std::index_sequence) { + // If there were no errors when parsing, all but one element of the array + // will be nullopt. This is because each call to variant_from_json_impl will + // have a unique index and exactly one of them will match the index in the + // json object. + std::array, sizeof...(Is)> results{ + variant_from_json_impl(j)...}; + for (std::optional &maybe : results) { + if (maybe) { + return maybe.value(); + } + } + return std::nullopt; +} + +template +std::variant variant_from_json(json const &j) { + using Variant = std::variant; + std::optional result = variant_from_json_impl( + j, std::make_index_sequence()); + if (!result.has_value()) { + throw ::FlexFlow::mk_runtime_error("Invalid type {} found in json", + j.at("index").get()); + } + return result.value(); +} + +} // namespace FlexFlow + +namespace nlohmann { + +template +struct adl_serializer, + typename std::enable_if<::FlexFlow::elements_satisfy< + ::FlexFlow::is_json_serializable, + std::variant>::value>::type> { + static void to_json(json &j, std::variant const &v) { + return ::FlexFlow::variant_to_json(j, v); + } + + static std::variant from_json(json const &j) { + return ::FlexFlow::variant_from_json(j); + } +}; + +} // namespace nlohmann + +#endif diff --git a/lib/utils/include/utils/json.h b/lib/utils/include/utils/json/visitable.h similarity index 52% rename from lib/utils/include/utils/json.h rename to lib/utils/include/utils/json/visitable.h index f56917e329..abc20065de 100644 --- a/lib/utils/include/utils/json.h +++ b/lib/utils/include/utils/json/visitable.h @@ -1,6 +1,9 @@ #ifndef _FLEXFLOW_UTILS_INCLUDE_UTILS_JSON_H #define _FLEXFLOW_UTILS_INCLUDE_UTILS_JSON_H +#include "utils/json/is_json_deserializable.h" +#include "utils/json/is_json_serializable.h" +#include "utils/json/is_jsonable.h" #include "utils/json_core.h" #include "utils/optional.h" #include "utils/sequence.h" @@ -10,33 +13,6 @@ namespace FlexFlow { -template -struct is_json_serializable : std::false_type {}; - -template -struct is_json_serializable< - T, - void_t() = std::declval())>> - : std::true_type {}; - -template -struct is_json_deserializable : std::false_type {}; - -template -struct is_json_deserializable().get())>> - : std::true_type {}; - -template -struct is_jsonable - : conjunction, is_json_deserializable> {}; - -#define CHECK_IS_JSONABLE(TYPENAME) \ - static_assert(is_json_serializable::value, \ - #TYPENAME " should be json serializeable"); \ - static_assert(is_json_deserializable::value, \ - #TYPENAME " should be json deserializeable") - struct json_serialization_visitor { json_serialization_visitor() = delete; json_serialization_visitor(json &j) : j(j) {} @@ -134,66 +110,6 @@ T moveonly_visit_json_deserialize(json const &j) { return visitable_from_tuple(tuple_from_json(j)); } -struct VariantToJsonFunctor { - VariantToJsonFunctor(json &j) : j(j) {} - - json &j; - - template - void operator()(T const &t) { - static_assert(is_jsonable::value, ""); - - j = t; - } -}; - -template -void variant_to_json(json &j, std::variant const &v) { - json jval; - visit(::FlexFlow::VariantToJsonFunctor{jval}, v); - j["value"] = jval; - j["index"] = v.index(); -} - -template -std::optional variant_from_json_impl(json const &j) { - using Type = typename std::variant_alternative::type; - - if (j.at("index").get() == Idx) { - return j.at("value").get(); - } - return std::nullopt; -} - -template -std::optional variant_from_json_impl(json const &j, - std::index_sequence) { - // If there were no errors when parsing, all but one element of the array - // will be nullopt. This is because each call to variant_from_json_impl will - // have a unique index and exactly one of them will match the index in the - // json object. - std::array, sizeof...(Is)> results{ - variant_from_json_impl(j)...}; - for (std::optional &maybe : results) { - if (maybe) { - return maybe.value(); - } - } - return std::nullopt; -} - -template -std::variant variant_from_json(json const &j) { - using Variant = std::variant; - std::optional result = variant_from_json_impl( - j, std::make_index_sequence()); - if (!result.has_value()) { - throw ::FlexFlow::mk_runtime_error("Invalid type {} found in json", - j.at("index").get()); - } - return result.value(); -} - } // namespace FlexFlow namespace nlohmann { @@ -231,41 +147,6 @@ struct adl_serializer< } }; -template -struct adl_serializer< - std::optional, - typename std::enable_if<::FlexFlow::is_jsonable::value>::type> { - static void to_json(json &j, std::optional const &t) { - if (t.has_value()) { - to_json(j, t.value()); - } else { - j = nullptr; - } - } - - static void from_json(json const &j, std::optional &t) { - if (j == nullptr) { - t = std::nullopt; - } else { - t = j.get(); - } - } -}; - -template -struct adl_serializer, - typename std::enable_if<::FlexFlow::elements_satisfy< - ::FlexFlow::is_json_serializable, - std::variant>::value>::type> { - static void to_json(json &j, std::variant const &v) { - return ::FlexFlow::variant_to_json(j, v); - } - - static std::variant from_json(json const &j) { - return ::FlexFlow::variant_from_json(j); - } -}; - } // namespace nlohmann #endif diff --git a/lib/utils/include/utils/optional.h b/lib/utils/include/utils/optional.h index 3448ec4e0e..3ec165d595 100644 --- a/lib/utils/include/utils/optional.h +++ b/lib/utils/include/utils/optional.h @@ -7,6 +7,15 @@ namespace FlexFlow { +template +T or_else(std::optional const &o, F &&f) { + if (o.has_value()) { + return o.value(); + } else { + return f(); + } +} + template T const &unwrap(std::optional const &o, F const &f) { if (o.has_value()) { @@ -25,18 +34,4 @@ T const &assert_unwrap(std::optional const &o) { } // namespace FlexFlow -namespace rc { - -template -struct Arbitrary> { - static Gen> arbitrary() { - return gen::map( - gen::maybe(std::move(gen::arbitrary())), [](Maybe &&m) { - return m ? std::optional(std::move(*m)) : std::optional(); - }); - } -}; - -} // namespace rc - #endif diff --git a/lib/utils/include/utils/rapidcheck/optional.h b/lib/utils/include/utils/rapidcheck/optional.h new file mode 100644 index 0000000000..edb28fdb81 --- /dev/null +++ b/lib/utils/include/utils/rapidcheck/optional.h @@ -0,0 +1,21 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_RAPIDCHECK_OPTIONAL_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_RAPIDCHECK_OPTIONAL_H + +#include +#include + +namespace rc { + +template +struct Arbitrary> { + static Gen> arbitrary() { + return gen::map( + gen::maybe(std::move(gen::arbitrary())), [](Maybe &&m) { + return m ? std::optional(std::move(*m)) : std::optional(); + }); + } +}; + +} // namespace rc + +#endif diff --git a/lib/utils/include/utils/required.h b/lib/utils/include/utils/required.h index 9cdd7918dd..d16b67ba86 100644 --- a/lib/utils/include/utils/required.h +++ b/lib/utils/include/utils/required.h @@ -1,9 +1,13 @@ #ifndef _FLEXFLOW_UTILS_INCLUDE_UTILS_REQUIRED_H #define _FLEXFLOW_UTILS_INCLUDE_UTILS_REQUIRED_H -#include "utils/json.h" +#include "utils/fmt/vector.h" +#include "utils/json/is_json_deserializable.h" +#include "utils/json/is_json_serializable.h" +#include "utils/json/is_jsonable.h" #include "utils/required_core.h" #include "utils/type_traits.h" +#include namespace FlexFlow { @@ -14,11 +18,11 @@ static_assert(is_list_initializable, int>::value, ""); namespace nlohmann { template struct adl_serializer<::FlexFlow::req> { - static ::FlexFlow::req from_json(json const &j) { + static ::FlexFlow::req from_json(nlohmann::json const &j) { return {j.template get()}; } - static void to_json(json &j, ::FlexFlow::req const &t) { + static void to_json(nlohmann::json &j, ::FlexFlow::req const &t) { j = static_cast(t); } }; diff --git a/lib/utils/include/utils/stack_string.h b/lib/utils/include/utils/stack_string.h index 19743b8301..7a936ebd7b 100644 --- a/lib/utils/include/utils/stack_string.h +++ b/lib/utils/include/utils/stack_string.h @@ -4,9 +4,9 @@ #include "fmt/core.h" #include "stack_vector.h" #include "utils/fmt.h" -#include "utils/json.h" #include "utils/type_traits.h" #include +#include #include #include @@ -70,13 +70,13 @@ template using stack_string = stack_basic_string; template -void to_json(json &j, stack_string const &v) { +void to_json(nlohmann::json &j, stack_string const &v) { std::string as_string = v; j = as_string; } template -void from_json(json const &j, stack_string &v) { +void from_json(nlohmann::json const &j, stack_string &v) { std::string as_string; j.get_to(as_string); v = stack_string{as_string}; diff --git a/lib/utils/include/utils/stack_vector.h b/lib/utils/include/utils/stack_vector.h index 1d654e3415..7a7bce7afc 100644 --- a/lib/utils/include/utils/stack_vector.h +++ b/lib/utils/include/utils/stack_vector.h @@ -3,12 +3,12 @@ #include "utils/hash-utils.h" #include "utils/join_strings.h" -#include "utils/json.h" #include "utils/test_types.h" #include "utils/type_traits.h" #include #include #include +#include #include #include #include @@ -326,13 +326,13 @@ std::ostream &operator<<(std::ostream &s, stack_vector const &v) { } template -void to_json(json &j, stack_vector const &v) { +void to_json(nlohmann::json &j, stack_vector const &v) { std::vector as_vec(v.begin(), v.end()); j = as_vec; } template -void from_json(json const &j, stack_vector &v) { +void from_json(nlohmann::json const &j, stack_vector &v) { std::vector as_vec; j.get_to(as_vec); v = stack_vector{as_vec.begin(), as_vec.end()}; diff --git a/lib/utils/src/utils/cli/cli_get_help_message.cc b/lib/utils/src/utils/cli/cli_get_help_message.cc new file mode 100644 index 0000000000..03c53c9356 --- /dev/null +++ b/lib/utils/src/utils/cli/cli_get_help_message.cc @@ -0,0 +1,101 @@ +#include "utils/cli/cli_get_help_message.h" +#include "utils/containers/concat_vectors.h" +#include "utils/containers/maximum.h" +#include "utils/containers/transform.h" +#include "utils/integer_conversions.h" +#include "utils/join_strings.h" +#include + +namespace FlexFlow { + +std::string cli_get_help_message(std::string const &program_name, + CLISpec const &cli) { + auto render_pos_arg = [](CLIPositionalArgumentSpec const &pos_arg_spec) { + if (pos_arg_spec.choices.has_value()) { + return "{" + join_strings(pos_arg_spec.choices.value(), ",") + "}"; + } else { + return pos_arg_spec.name; + } + }; + + auto render_flag_option_column_key = [](CLIFlagSpec const &flag_spec) { + std::ostringstream oss; + if (flag_spec.short_flag.has_value()) { + oss << "-" << flag_spec.short_flag.value() << ", "; + } + oss << "--" << flag_spec.long_flag; + return oss.str(); + }; + + std::ostringstream oss; + + oss << "usage: " << program_name; + for (CLIFlagSpec const &flag_spec : cli.flags) { + if (flag_spec.short_flag.has_value()) { + oss << " [-" << flag_spec.short_flag.value() << "]"; + } else { + oss << " [--" << flag_spec.long_flag << "]"; + } + } + for (CLIPositionalArgumentSpec const &pos_arg_spec : + cli.positional_arguments) { + oss << " " << render_pos_arg(pos_arg_spec); + } + + oss << std::endl; + + std::vector all_arg_columns = concat_vectors(std::vector{ + transform(cli.positional_arguments, render_pos_arg), + transform(cli.flags, render_flag_option_column_key), + }); + std::vector all_arg_column_widths = + transform(all_arg_columns, [](std::string const &s) { return s.size(); }); + + if (!all_arg_columns.empty()) { + int max_column_width = + std::min(int_from_size_t(maximum(all_arg_column_widths).value()), 20); + + auto render_column = [&](std::string const &key, + std::optional const &description) { + if (description.has_value()) { + if (key.size() > max_column_width) { + return " " + key + "\n" + std::string(24, ' ') + description.value(); + } else { + } + return fmt::format( + " {:<{}} {}", key, max_column_width, description.value()); + } else { + return fmt::format(" {}", key); + } + }; + + if (!cli.positional_arguments.empty()) { + oss << std::endl; + oss << "positional arguments:" << std::endl; + + if (!cli.positional_arguments.empty()) { + for (CLIPositionalArgumentSpec const &pos_arg_spec : + cli.positional_arguments) { + oss << render_column(render_pos_arg(pos_arg_spec), + pos_arg_spec.description) + << std::endl; + } + } + } + + if (!cli.flags.empty()) { + oss << std::endl; + oss << "options:" << std::endl; + + for (CLIFlagSpec const &flag_spec : cli.flags) { + oss << render_column(render_flag_option_column_key(flag_spec), + flag_spec.description) + << std::endl; + } + } + } + + return oss.str(); +} + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/cli/cli_parse.cc b/lib/utils/src/utils/cli/cli_parse.cc new file mode 100644 index 0000000000..07982c0c2d --- /dev/null +++ b/lib/utils/src/utils/cli/cli_parse.cc @@ -0,0 +1,96 @@ +#include "utils/cli/cli_parse.h" +#include "utils/cli/cli_spec.h" +#include "utils/containers/contains.h" +#include "utils/containers/enumerate.h" +#include "utils/containers/generate_map.h" + +namespace FlexFlow { + +tl::expected cli_parse_flag(CLISpec const &cli, + std::string const &arg) { + for (auto const &[idx, flag_spec] : enumerate(cli.flags)) { + CLIFlagKey key = CLIFlagKey{idx}; + if (("--" + flag_spec.long_flag) == arg) { + return key; + } + + if (flag_spec.short_flag.has_value()) { + if ((std::string{"-"} + flag_spec.short_flag.value()) == arg) { + return key; + } + } + } + + return tl::unexpected(fmt::format("Encountered unknown flag {}", arg)); +} + +tl::expected + cli_parse(CLISpec const &cli, std::vector const &args) { + CLIParseResult result = CLIParseResult{ + generate_map(cli_get_flag_keys(cli), + [](CLIFlagKey const &) { return false; }), + {}, + }; + + int consumed_positional_args = 0; + auto parse_positional_arg = + [&](std::string const &arg) -> std::optional { + if (consumed_positional_args >= cli.positional_arguments.size()) { + return fmt::format("Too many positional arguments: expected {}", + cli.positional_arguments.size()); + } + + CLIPositionalArgumentSpec arg_spec = + cli.positional_arguments.at(consumed_positional_args); + + if (arg_spec.choices.has_value() && + !contains(arg_spec.choices.value(), arg)) { + return fmt::format( + "Invalid option for positional argument \"{}\": \"{}\"", + arg_spec.name, + arg); + } + + result.positional_arguments.insert( + {CLIPositionalArgumentKey{consumed_positional_args}, arg}); + consumed_positional_args++; + + return std::nullopt; + }; + + for (int i = 1; i < args.size(); i++) { + std::string arg = args.at(i); + + if (!arg.empty() && arg.at(0) == '-') { + tl::expected parsed_flag = + cli_parse_flag(cli, arg); + + if (parsed_flag.has_value()) { + result.flags.at(parsed_flag.value()) = true; + } + } else { + std::optional maybe_err_msg = parse_positional_arg(arg); + if (maybe_err_msg.has_value()) { + return tl::unexpected(maybe_err_msg.value()); + } + } + } + + if (consumed_positional_args != cli.positional_arguments.size()) { + return tl::unexpected( + fmt::format("Not enough positional arguments: found {}, expected {}", + consumed_positional_args, + cli.positional_arguments.size())); + } + + return result; +} + +tl::expected + cli_parse(CLISpec const &cli, int argc, char const *const *argv) { + std::vector args = {argv, argv + argc}; + + return cli_parse(cli, args); +} + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/cli/cli_parse_result.cc b/lib/utils/src/utils/cli/cli_parse_result.cc new file mode 100644 index 0000000000..6682a7a6eb --- /dev/null +++ b/lib/utils/src/utils/cli/cli_parse_result.cc @@ -0,0 +1,14 @@ +#include "utils/cli/cli_parse_result.h" + +namespace FlexFlow { + +bool cli_get_flag(CLIParseResult const &result, CLIArgumentKey const &key) { + return result.flags.at(key.get()); +} + +std::string cli_get_argument(CLIParseResult const &result, + CLIArgumentKey const &key) { + return result.positional_arguments.at(key.get()); +} + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/cli/cli_spec.cc b/lib/utils/src/utils/cli/cli_spec.cc new file mode 100644 index 0000000000..ca51cfe57f --- /dev/null +++ b/lib/utils/src/utils/cli/cli_spec.cc @@ -0,0 +1,37 @@ +#include "utils/cli/cli_spec.h" +#include "utils/containers/count.h" +#include "utils/containers/transform.h" +#include "utils/integer_conversions.h" + +namespace FlexFlow { + +CLISpec empty_cli_spec() { + return CLISpec{{}, {}}; +} + +std::vector cli_get_flag_keys(CLISpec const &cli) { + return transform(count(cli.flags.size()), + [](int idx) { return CLIFlagKey{idx}; }); +} + +CLIArgumentKey cli_add_help_flag(CLISpec &cli) { + CLIFlagSpec help_flag = + CLIFlagSpec{"help", 'h', "show this help message and exit"}; + return cli_add_flag(cli, help_flag); +} + +CLIArgumentKey cli_add_flag(CLISpec &cli, CLIFlagSpec const &flag_spec) { + cli.flags.push_back(flag_spec); + + return CLIArgumentKey{CLIFlagKey{int_from_size_t(cli.flags.size()) - 1}}; +} + +CLIArgumentKey + cli_add_positional_argument(CLISpec &cli, + CLIPositionalArgumentSpec const &arg) { + cli.positional_arguments.push_back(arg); + return CLIArgumentKey{CLIPositionalArgumentKey{ + int_from_size_t(cli.positional_arguments.size()) - 1}}; +} + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/containers/as_vector.cc b/lib/utils/src/utils/containers/as_vector.cc deleted file mode 100644 index 9c7b63ca58..0000000000 --- a/lib/utils/src/utils/containers/as_vector.cc +++ /dev/null @@ -1 +0,0 @@ -#include "utils/containers/as_vector.h" diff --git a/lib/utils/src/utils/containers/enumerate_vector.cc b/lib/utils/src/utils/containers/enumerate_vector.cc new file mode 100644 index 0000000000..d4fd131af2 --- /dev/null +++ b/lib/utils/src/utils/containers/enumerate_vector.cc @@ -0,0 +1 @@ +#include "utils/containers/enumerate_vector.h" diff --git a/lib/utils/src/utils/containers/foldl1.cc b/lib/utils/src/utils/containers/foldl1.cc new file mode 100644 index 0000000000..c6cdd0eec9 --- /dev/null +++ b/lib/utils/src/utils/containers/foldl1.cc @@ -0,0 +1 @@ +#include "utils/containers/foldl1.h" diff --git a/lib/utils/src/utils/containers/foldr1.cc b/lib/utils/src/utils/containers/foldr1.cc new file mode 100644 index 0000000000..9d00d81565 --- /dev/null +++ b/lib/utils/src/utils/containers/foldr1.cc @@ -0,0 +1 @@ +#include "utils/containers/foldr1.h" diff --git a/lib/utils/src/utils/containers/get_element_counts.cc b/lib/utils/src/utils/containers/get_element_counts.cc index 9840ed34d8..ac8e289523 100644 --- a/lib/utils/src/utils/containers/get_element_counts.cc +++ b/lib/utils/src/utils/containers/get_element_counts.cc @@ -1,10 +1,10 @@ #include "utils/containers/get_element_counts.h" -#include "utils/containers/as_vector.h" +#include "utils/containers/vector_of.h" namespace FlexFlow { std::unordered_map get_element_counts(std::string const &s) { - return get_element_counts(as_vector(s)); + return get_element_counts(vector_of(s)); } } // namespace FlexFlow diff --git a/lib/utils/src/utils/containers/maximum.cc b/lib/utils/src/utils/containers/maximum.cc new file mode 100644 index 0000000000..51d92cf951 --- /dev/null +++ b/lib/utils/src/utils/containers/maximum.cc @@ -0,0 +1 @@ +#include "utils/containers/maximum.h" diff --git a/lib/utils/src/utils/containers/multiset_union.cc b/lib/utils/src/utils/containers/multiset_union.cc new file mode 100644 index 0000000000..a053d05fa6 --- /dev/null +++ b/lib/utils/src/utils/containers/multiset_union.cc @@ -0,0 +1 @@ +#include "utils/containers/multiset_union.h" diff --git a/lib/utils/src/utils/containers/require_no_duplicates.cc b/lib/utils/src/utils/containers/require_no_duplicates.cc new file mode 100644 index 0000000000..b1d21ad832 --- /dev/null +++ b/lib/utils/src/utils/containers/require_no_duplicates.cc @@ -0,0 +1 @@ +#include "utils/containers/require_no_duplicates.h" diff --git a/lib/utils/src/utils/containers/set_of.cc b/lib/utils/src/utils/containers/set_of.cc new file mode 100644 index 0000000000..3a12ee539d --- /dev/null +++ b/lib/utils/src/utils/containers/set_of.cc @@ -0,0 +1 @@ +#include "utils/containers/set_of.h" diff --git a/lib/utils/src/utils/containers/to_uppercase.cc b/lib/utils/src/utils/containers/to_uppercase.cc new file mode 100644 index 0000000000..6c02b5a109 --- /dev/null +++ b/lib/utils/src/utils/containers/to_uppercase.cc @@ -0,0 +1,10 @@ +#include "utils/containers/to_uppercase.h" +#include "utils/containers/transform.h" + +namespace FlexFlow { + +std::string to_uppercase(std::string const &s) { + return transform(s, [](char c) -> char { return std::toupper(c); }); +} + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/containers/vector_of.cc b/lib/utils/src/utils/containers/vector_of.cc new file mode 100644 index 0000000000..b997076511 --- /dev/null +++ b/lib/utils/src/utils/containers/vector_of.cc @@ -0,0 +1 @@ +#include "utils/containers/vector_of.h" diff --git a/lib/utils/src/utils/graph/algorithms.cc b/lib/utils/src/utils/graph/algorithms.cc index 323f444a22..6ed41daf43 100644 --- a/lib/utils/src/utils/graph/algorithms.cc +++ b/lib/utils/src/utils/graph/algorithms.cc @@ -219,10 +219,6 @@ std::unordered_set get_endpoints(UndirectedEdge const &e) { // return g.query_edges(MultiDiEdgeQuery::all()); // } -std::unordered_set get_edges(UndirectedGraphView const &g) { - return g.query_edges(undirected_edge_query_all()); -} - // std::unordered_set get_edges(OpenMultiDiGraphView const &g) // { // return g.query_edges(OpenMultiDiEdgeQuery::all()); diff --git a/lib/utils/src/utils/graph/dataflow_graph/algorithms/get_subgraph_incoming_edges.cc b/lib/utils/src/utils/graph/dataflow_graph/algorithms/get_subgraph_incoming_edges.cc new file mode 100644 index 0000000000..d17a84dd12 --- /dev/null +++ b/lib/utils/src/utils/graph/dataflow_graph/algorithms/get_subgraph_incoming_edges.cc @@ -0,0 +1,24 @@ +#include "utils/graph/dataflow_graph/algorithms/get_subgraph_incoming_edges.h" +#include "utils/containers/set_minus.h" +#include "utils/graph/node/algorithms.h" + +namespace FlexFlow { + +std::unordered_set + get_subgraph_incoming_edges(DataflowGraphView const &g, + std::unordered_set const &ns) { + + std::unordered_set all_nodes = get_nodes(g); + query_set src_query = query_set{set_minus(all_nodes, ns)}; + + DataflowEdgeQuery query = DataflowEdgeQuery{ + src_query, + query_set::matchall(), + query_set{ns}, + query_set::matchall(), + }; + + return g.query_edges(query); +} + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/digraph/algorithms/complete_bipartite_composite/get_cbc_decomposition.cc b/lib/utils/src/utils/graph/digraph/algorithms/complete_bipartite_composite/get_cbc_decomposition.cc index 011d8b3ed9..8afe7da926 100644 --- a/lib/utils/src/utils/graph/digraph/algorithms/complete_bipartite_composite/get_cbc_decomposition.cc +++ b/lib/utils/src/utils/graph/digraph/algorithms/complete_bipartite_composite/get_cbc_decomposition.cc @@ -3,8 +3,13 @@ #include "utils/containers/extend.h" #include "utils/containers/get_first.h" #include "utils/containers/set_minus.h" +#include "utils/containers/set_of.h" #include "utils/containers/values.h" +#include "utils/containers/vector_of.h" +#include "utils/fmt/set.h" +#include "utils/graph/algorithms.h" #include "utils/graph/digraph/algorithms.h" +#include "utils/graph/digraph/algorithms/complete_bipartite_composite/is_complete_bipartite_digraph.h" #include "utils/graph/digraph/algorithms/get_incoming_edges.h" #include "utils/graph/digraph/algorithms/get_outgoing_edges.h" #include "utils/graph/digraph/algorithms/get_predecessors.h" @@ -12,23 +17,35 @@ #include "utils/graph/digraph/algorithms/get_weakly_connected_components.h" #include "utils/graph/node/algorithms.h" #include "utils/hash/unordered_set.h" +#include namespace FlexFlow { std::optional - get_cbc_decomposition(DiGraphView const &g) { + get_cbc_decomposition_with_edge_order_internal( + DiGraphView const &g, std::vector const &edge_order) { // implementation of the algorithm from https://doi.org/10.1145/800135.804393 // top left of page 8, second paragraph + std::queue edges_to_process; + for (DirectedEdge const &e : edge_order) { + edges_to_process.push(e); + } + std::unordered_set already_in_a_head = {}; std::unordered_set already_in_a_tail = {}; - std::unordered_set edges_to_process = get_edges(g); + + std::unordered_set already_processed = {}; CompleteBipartiteCompositeDecomposition result = CompleteBipartiteCompositeDecomposition{{}}; while (!edges_to_process.empty()) { - DirectedEdge e = get_first(edges_to_process); + DirectedEdge e = edges_to_process.front(); + edges_to_process.pop(); + if (contains(already_processed, e)) { + continue; + } std::unordered_set head = get_predecessors(g, e.dst); std::unordered_set tail = get_successors(g, e.src); @@ -39,6 +56,12 @@ std::optional std::unordered_set from_head_to_tail = g.query_edges(DirectedEdgeQuery{head, tail}); + + DiGraphView subgraph = get_subgraph(g, set_union(head, tail)); + if (!is_complete_bipartite_digraph(subgraph, head)) { + return std::nullopt; + } + if (set_union(values(get_outgoing_edges(g, head))) != from_head_to_tail) { return std::nullopt; } @@ -47,7 +70,7 @@ std::optional } result.subgraphs.insert(BipartiteComponent{head, tail}); - edges_to_process = set_minus(edges_to_process, from_head_to_tail); + already_processed = set_union(already_processed, from_head_to_tail); extend(already_in_a_head, head); extend(already_in_a_tail, tail); } @@ -58,4 +81,10 @@ std::optional return result; } +std::optional + get_cbc_decomposition(DiGraphView const &g) { + std::vector edge_order = vector_of(get_edges(g)); + return get_cbc_decomposition_with_edge_order_internal(g, edge_order); +} + } // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/digraph/algorithms/complete_bipartite_composite/is_complete_bipartite_digraph.cc b/lib/utils/src/utils/graph/digraph/algorithms/complete_bipartite_composite/is_complete_bipartite_digraph.cc new file mode 100644 index 0000000000..2eab8371b2 --- /dev/null +++ b/lib/utils/src/utils/graph/digraph/algorithms/complete_bipartite_composite/is_complete_bipartite_digraph.cc @@ -0,0 +1,29 @@ +#include "utils/graph/digraph/algorithms/complete_bipartite_composite/is_complete_bipartite_digraph.h" +#include "utils/containers/get_first.h" +#include "utils/containers/set_minus.h" +#include "utils/graph/digraph/algorithms.h" +#include "utils/graph/node/algorithms.h" + +namespace FlexFlow { + +bool is_complete_bipartite_digraph(DiGraphView const &g) { + return is_complete_bipartite_digraph(g, get_sources(g)); +} + +bool is_complete_bipartite_digraph(DiGraphView const &g, + std::unordered_set const &srcs) { + std::unordered_set sinks = set_minus(get_nodes(g), srcs); + + std::unordered_set edges = get_edges(g); + + std::unordered_set expected_edges; + for (Node const &src : srcs) { + for (Node const &sink : sinks) { + expected_edges.insert(DirectedEdge{src, sink}); + } + } + + return edges == expected_edges; +} + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/digraph/algorithms/digraph_as_dot.cc b/lib/utils/src/utils/graph/digraph/algorithms/digraph_as_dot.cc new file mode 100644 index 0000000000..ad7830cc76 --- /dev/null +++ b/lib/utils/src/utils/graph/digraph/algorithms/digraph_as_dot.cc @@ -0,0 +1,32 @@ +#include "utils/graph/digraph/algorithms/digraph_as_dot.h" +#include "utils/dot_file.h" +#include "utils/graph/digraph/algorithms.h" +#include "utils/graph/node/algorithms.h" + +namespace FlexFlow { + +std::string digraph_as_dot( + DiGraphView const &g, + std::function const &get_node_label) { + std::ostringstream oss; + DotFile dot = DotFile{oss}; + + auto get_node_name = [](Node const &n) { + return fmt::format("n{}", n.raw_uid); + }; + + for (Node const &n : get_nodes(g)) { + RecordFormatter rec; + rec << get_node_label(n); + dot.add_record_node(get_node_name(n), rec); + } + + for (DirectedEdge const &e : get_edges(g)) { + dot.add_edge(get_node_name(e.src), get_node_name(e.dst)); + } + + dot.close(); + return oss.str(); +} + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/digraph/algorithms/digraph_has_edge.cc b/lib/utils/src/utils/graph/digraph/algorithms/digraph_has_edge.cc new file mode 100644 index 0000000000..5c790abb8c --- /dev/null +++ b/lib/utils/src/utils/graph/digraph/algorithms/digraph_has_edge.cc @@ -0,0 +1,13 @@ +#include "utils/graph/digraph/algorithms/digraph_has_edge.h" + +namespace FlexFlow { + +bool digraph_has_edge(DiGraphView const &g, DirectedEdge const &e) { + return !g.query_edges(DirectedEdgeQuery{ + query_set{e.src}, + query_set{e.dst}, + }) + .empty(); +} + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/digraph/algorithms/get_imm_dominators_map.cc b/lib/utils/src/utils/graph/digraph/algorithms/get_imm_dominators_map.cc index 2e570cbdf9..34cc7fcc6f 100644 --- a/lib/utils/src/utils/graph/digraph/algorithms/get_imm_dominators_map.cc +++ b/lib/utils/src/utils/graph/digraph/algorithms/get_imm_dominators_map.cc @@ -1,5 +1,4 @@ #include "utils/graph/digraph/algorithms/get_imm_dominators_map.h" -#include "utils/containers/as_vector.h" #include "utils/containers/concat_vectors.h" #include "utils/containers/filter_values.h" #include "utils/containers/generate_map.h" @@ -7,6 +6,7 @@ #include "utils/containers/get_only.h" #include "utils/containers/keys.h" #include "utils/containers/transform.h" +#include "utils/containers/vector_of.h" #include "utils/graph/digraph/algorithms/get_dominators_map.h" #include "utils/graph/node/algorithms.h" @@ -22,8 +22,8 @@ std::unordered_map> std::unordered_set n_dominators = node_to_its_dominators.at(n); n_dominators.erase(n); std::vector recursive_dominator_list = concat_vectors( - transform(as_vector(n_dominators), [&](Node const &dominator) { - return as_vector(node_to_its_dominators.at(dominator)); + transform(vector_of(n_dominators), [&](Node const &dominator) { + return vector_of(node_to_its_dominators.at(dominator)); })); std::unordered_map dominator_counts = get_element_counts(recursive_dominator_list); diff --git a/lib/utils/src/utils/graph/digraph/algorithms/get_subgraph_outgoing_edges.cc b/lib/utils/src/utils/graph/digraph/algorithms/get_subgraph_outgoing_edges.cc new file mode 100644 index 0000000000..f19deb3046 --- /dev/null +++ b/lib/utils/src/utils/graph/digraph/algorithms/get_subgraph_outgoing_edges.cc @@ -0,0 +1,16 @@ +#include "utils/graph/digraph/algorithms/get_subgraph_outgoing_edges.h" +#include "utils/containers/set_minus.h" +#include "utils/graph/node/algorithms.h" + +namespace FlexFlow { + +std::unordered_set get_subgraph_outgoing_edges( + DiGraphView const &g, std::unordered_set const &subgraph_nodes) { + std::unordered_set external_nodes = + set_minus(get_nodes(g), subgraph_nodes); + DirectedEdgeQuery query = DirectedEdgeQuery{query_set{subgraph_nodes}, + query_set{external_nodes}}; + return g.query_edges(query); +} + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/digraph/algorithms/get_subgraph_successors.cc b/lib/utils/src/utils/graph/digraph/algorithms/get_subgraph_successors.cc new file mode 100644 index 0000000000..e860fb11b1 --- /dev/null +++ b/lib/utils/src/utils/graph/digraph/algorithms/get_subgraph_successors.cc @@ -0,0 +1,16 @@ +#include "utils/graph/digraph/algorithms/get_subgraph_successors.h" +#include "utils/graph/digraph/algorithms/get_subgraph_outgoing_edges.h" + +namespace FlexFlow { + +std::unordered_set + get_subgraph_successors(DiGraphView const &g, + std::unordered_set const &subgraph_nodes) { + std::unordered_set successors = + transform(get_subgraph_outgoing_edges(g, subgraph_nodes), + [](DirectedEdge const &e) { return e.dst; }); + + return successors; +} + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/digraph/algorithms/transitive_closure.cc b/lib/utils/src/utils/graph/digraph/algorithms/transitive_closure.cc new file mode 100644 index 0000000000..3efea1c138 --- /dev/null +++ b/lib/utils/src/utils/graph/digraph/algorithms/transitive_closure.cc @@ -0,0 +1,51 @@ +#include "utils/graph/digraph/algorithms/transitive_closure.h" +#include "utils/bidict/algorithms/bidict_from_enumerating.h" +#include "utils/containers/vector_of.h" +#include "utils/graph/digraph/algorithms.h" +#include "utils/graph/digraph/algorithms/digraph_has_edge.h" +#include "utils/graph/digraph/algorithms/materialize_digraph_view.h" +#include "utils/graph/instances/adjacency_digraph.h" +#include "utils/graph/node/algorithms.h" + +namespace FlexFlow { + +DiGraphView transitive_closure(DiGraphView const &g) { + // Logic dropped down to raw adjacency matrix for performance. + // The version going through the full graph abstraction was + // incredibly slow (> minutes) for even moderately sized graphs + // (i.e., 200 nodes) without optimization enabled. + + bidict nodes = bidict_from_enumerating(get_nodes(g)); + std::unordered_set edges = get_edges(g); + + int num_nodes = nodes.size(); + + std::vector edge_matrix(num_nodes * num_nodes, false); + + auto has_edge = [&](int src_idx, + int dst_idx) -> std::vector::reference { + return edge_matrix[src_idx * num_nodes + dst_idx]; + }; + + for (DirectedEdge const &e : get_edges(g)) { + has_edge(nodes.at_r(e.src), nodes.at_r(e.dst)) = true; + } + + DiGraph result = materialize_digraph_view(g); + for (int k = 0; k < num_nodes; k++) { + for (int i = 0; i < num_nodes; i++) { + if (has_edge(i, k)) { + for (int j = 0; j < num_nodes; j++) { + if (has_edge(k, j)) { + has_edge(i, j) = true; + result.add_edge(DirectedEdge{nodes.at_l(i), nodes.at_l(j)}); + } + } + } + } + } + + return result; +} + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/digraph/algorithms/transitive_reduction.cc b/lib/utils/src/utils/graph/digraph/algorithms/transitive_reduction.cc index 10ffe4fc33..97a2439263 100644 --- a/lib/utils/src/utils/graph/digraph/algorithms/transitive_reduction.cc +++ b/lib/utils/src/utils/graph/digraph/algorithms/transitive_reduction.cc @@ -1,7 +1,12 @@ #include "utils/graph/digraph/algorithms/transitive_reduction.h" +#include "utils/bidict/algorithms/bidict_from_enumerating.h" +#include "utils/containers/is_subseteq_of.h" +#include "utils/containers/vector_of.h" #include "utils/graph/digraph/algorithms.h" #include "utils/graph/digraph/algorithms/materialize_digraph_view.h" +#include "utils/graph/digraph/algorithms/transitive_closure.h" #include "utils/graph/instances/adjacency_digraph.h" +#include "utils/graph/node/algorithms.h" namespace FlexFlow { @@ -24,29 +29,60 @@ DirectedEdgeMaskView *DirectedEdgeMaskView::clone() const { } DiGraphView transitive_reduction(DiGraphView const &g) { - std::unordered_set edge_mask = get_edges(g); + // Logic dropped down to raw adjacency matrix for performance. + // The version going through the full graph abstraction was + // incredibly slow (> minutes) for even moderately sized graphs + // (i.e., 200 nodes) without optimization enabled. + // + // transitive_closure inlined to avoid any drifts in node numbering + // between transitive_closure and transitive_reduction + + bidict nodes = bidict_from_enumerating(get_nodes(g)); + int num_nodes = nodes.size(); + + std::vector edge_matrix(num_nodes * num_nodes, false); + + auto has_edge = [&](int src_idx, + int dst_idx) -> std::vector::reference { + return edge_matrix[src_idx * num_nodes + dst_idx]; + }; + + for (DirectedEdge const &e : get_edges(g)) { + has_edge(nodes.at_r(e.src), nodes.at_r(e.dst)) = true; + } - while (true) { - std::unordered_set new_edge_mask = edge_mask; - for (DirectedEdge const &e1 : edge_mask) { - for (DirectedEdge const &e2 : edge_mask) { - if (e1.dst == e2.src && e1 != e2) { - DirectedEdge trans_edge = DirectedEdge{e1.src, e2.dst}; - if (contains(new_edge_mask, trans_edge)) { - new_edge_mask.erase(trans_edge); + // compute transitive closure + // see https://cs.winona.edu/lin/cs440/ch08-2.pdf slide 8-8 + for (int k = 0; k < num_nodes; k++) { + for (int i = 0; i < num_nodes; i++) { + if (has_edge(i, k)) { + for (int j = 0; j < num_nodes; j++) { + if (has_edge(k, j)) { + has_edge(i, j) = true; } } } } + } - if (new_edge_mask == edge_mask) { - break; - } else { - edge_mask = new_edge_mask; + DiGraph result = materialize_digraph_view(g); + // compute transitive reduction + // see https://stackoverflow.com/a/6702198 + std::unordered_set edge_mask = get_edges(g); + for (int j = 0; j < num_nodes; j++) { + for (int i = 0; i < num_nodes; i++) { + if (has_edge(i, j)) { + for (int k = 0; k < num_nodes; k++) { + if (has_edge(j, k)) { + has_edge(i, k) = false; + result.remove_edge(DirectedEdge{nodes.at_l(i), nodes.at_l(k)}); + } + } + } } } - return DiGraphView::create(g, edge_mask); + return result; } } // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/instances/adjacency_digraph.cc b/lib/utils/src/utils/graph/instances/adjacency_digraph.cc index 34a8eff503..68ef12c49e 100644 --- a/lib/utils/src/utils/graph/instances/adjacency_digraph.cc +++ b/lib/utils/src/utils/graph/instances/adjacency_digraph.cc @@ -38,11 +38,7 @@ void AdjacencyDiGraph::add_edge(DirectedEdge const &e) { } void AdjacencyDiGraph::remove_edge(DirectedEdge const &e) { - std::unordered_set &m = this->adjacency.at(e.src); - auto iter = m.find(e.dst); - if (iter != m.end()) { - m.erase(iter); - } + this->adjacency.at(e.src).erase(e.dst); } std::unordered_set diff --git a/lib/utils/src/utils/graph/instances/unordered_set_undirected_graph.cc b/lib/utils/src/utils/graph/instances/unordered_set_undirected_graph.cc new file mode 100644 index 0000000000..6f6722f635 --- /dev/null +++ b/lib/utils/src/utils/graph/instances/unordered_set_undirected_graph.cc @@ -0,0 +1,58 @@ +#include "utils/graph/instances/unordered_set_undirected_graph.h" +#include "utils/graph/node/node_query.h" +#include "utils/graph/undirected/undirected_edge_query.h" + +namespace FlexFlow { + +UnorderedSetUndirectedGraph::UnorderedSetUndirectedGraph() {} + +UnorderedSetUndirectedGraph::UnorderedSetUndirectedGraph( + NodeSource const &node_source, + std::unordered_set const &nodes, + std::unordered_set const &edges) + : node_source(node_source), nodes(nodes), edges(edges) {} + +Node UnorderedSetUndirectedGraph::add_node() { + Node new_node = this->node_source.new_node(); + this->nodes.insert(new_node); + return new_node; +} + +void UnorderedSetUndirectedGraph::add_node_unsafe(Node const &n) { + this->nodes.insert(n); +} + +void UnorderedSetUndirectedGraph::remove_node_unsafe(Node const &n) { + this->nodes.erase(n); +} + +void UnorderedSetUndirectedGraph::add_edge(UndirectedEdge const &e) { + assert(contains(this->nodes, e.bigger)); + assert(contains(this->nodes, e.smaller)); + this->edges.insert(e); +} + +void UnorderedSetUndirectedGraph::remove_edge(UndirectedEdge const &e) { + this->edges.erase(e); +} + +std::unordered_set + UnorderedSetUndirectedGraph::query_nodes(NodeQuery const &q) const { + return apply_node_query(q, this->nodes); +} + +std::unordered_set UnorderedSetUndirectedGraph::query_edges( + UndirectedEdgeQuery const &q) const { + return filter(this->edges, + [&](UndirectedEdge const &e) { return matches_edge(q, e); }); +} + +UnorderedSetUndirectedGraph *UnorderedSetUndirectedGraph::clone() const { + return new UnorderedSetUndirectedGraph{ + this->node_source, + this->nodes, + this->edges, + }; +} + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/multidigraph/algorithms/get_edge_counts.cc b/lib/utils/src/utils/graph/multidigraph/algorithms/get_edge_counts.cc index 47096d492c..53497a715d 100644 --- a/lib/utils/src/utils/graph/multidigraph/algorithms/get_edge_counts.cc +++ b/lib/utils/src/utils/graph/multidigraph/algorithms/get_edge_counts.cc @@ -1,7 +1,7 @@ #include "utils/graph/multidigraph/algorithms/get_edge_counts.h" -#include "utils/containers/as_vector.h" #include "utils/containers/get_element_counts.h" #include "utils/containers/transform.h" +#include "utils/containers/vector_of.h" #include "utils/graph/multidigraph/algorithms/get_directed_edge.h" #include "utils/graph/multidigraph/algorithms/get_edges.h" @@ -10,7 +10,7 @@ namespace FlexFlow { std::unordered_map get_edge_counts(MultiDiGraphView const &g) { return get_element_counts( - transform(as_vector(get_edges(g)), + transform(vector_of(get_edges(g)), [&](MultiDiEdge const &e) { return get_directed_edge(g, e); })); } diff --git a/lib/utils/src/utils/graph/open_dataflow_graph/algorithms/find_isomorphisms.cc b/lib/utils/src/utils/graph/open_dataflow_graph/algorithms/find_isomorphisms.cc index d95a9b9565..1dd5353301 100644 --- a/lib/utils/src/utils/graph/open_dataflow_graph/algorithms/find_isomorphisms.cc +++ b/lib/utils/src/utils/graph/open_dataflow_graph/algorithms/find_isomorphisms.cc @@ -2,12 +2,12 @@ #include "utils/bidict/algorithms/bidict_from_keys_and_values.h" #include "utils/bidict/algorithms/left_entries.h" #include "utils/bidict/algorithms/right_entries.h" -#include "utils/containers/as_vector.h" #include "utils/containers/get_all_permutations.h" #include "utils/containers/get_first.h" #include "utils/containers/is_subseteq_of.h" #include "utils/containers/keys.h" #include "utils/containers/values.h" +#include "utils/containers/vector_of.h" #include "utils/containers/zip.h" #include "utils/graph/digraph/algorithms.h" #include "utils/graph/node/algorithms.h" @@ -201,7 +201,7 @@ std::unordered_set OpenDataflowGraphView const &dst) { std::unordered_set result; - std::vector src_sink_nodes = as_vector(get_sinks(src)); + std::vector src_sink_nodes = vector_of(get_sinks(src)); std::unordered_set dst_sink_nodes = get_sinks(dst); if (src_sink_nodes.size() != dst_sink_nodes.size()) { @@ -209,7 +209,7 @@ std::unordered_set } std::vector src_unused_graph_inputs = - as_vector(get_unused_open_dataflow_graph_inputs(src)); + vector_of(get_unused_open_dataflow_graph_inputs(src)); std::unordered_set dst_unused_graph_inputs = get_unused_open_dataflow_graph_inputs(dst); diff --git a/lib/utils/src/utils/graph/serial_parallel/intermediate_sp_decomposition_tree.cc b/lib/utils/src/utils/graph/serial_parallel/intermediate_sp_decomposition_tree.cc deleted file mode 100644 index 6384bd9159..0000000000 --- a/lib/utils/src/utils/graph/serial_parallel/intermediate_sp_decomposition_tree.cc +++ /dev/null @@ -1,48 +0,0 @@ -#include "utils/graph/serial_parallel/intermediate_sp_decomposition_tree.h" -#include "utils/containers/extend.h" - -namespace FlexFlow { - -struct FlattenAST { - void add_flattened_child_to_parent( - IntermediateSpDecompositionTree &parent, - std::variant const &child) { - if (std::holds_alternative(child)) { - parent.children.push_back(child); - return; - } - - IntermediateSpDecompositionTree child_node = - std::get(child); - - if (parent.type == child_node.type) { - extend(parent.children, child_node.children); - } else { - parent.children.push_back(child); - } - } - - std::variant - operator()(IntermediateSpDecompositionTree const &ast_node) { - IntermediateSpDecompositionTree result(ast_node.type, {}); - for (std::variant const &child : - ast_node.children) { - std::variant flattened_child = - flatten_ast(child); - add_flattened_child_to_parent(result, flattened_child); - } - return result; - } - - std::variant - operator()(Node const &ast_node) { - return ast_node; - } -}; - -std::variant flatten_ast( - std::variant const &ast) { - return std::visit(FlattenAST{}, ast); -} - -} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.cc b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.cc new file mode 100644 index 0000000000..18d1f922c6 --- /dev/null +++ b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.cc @@ -0,0 +1,43 @@ +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_leaves.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_left_associative.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_right_associative.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/make.h" + +namespace FlexFlow { + +BinarySPDecompositionTree + make_series_split(BinarySPDecompositionTree const &lhs, + BinarySPDecompositionTree const &rhs) { + return BinarySPDecompositionTree{ + make_generic_binary_series_split(lhs.raw_tree, rhs.raw_tree), + }; +} + +BinarySPDecompositionTree + make_parallel_split(BinarySPDecompositionTree const &lhs, + BinarySPDecompositionTree const &rhs) { + return BinarySPDecompositionTree{ + make_generic_binary_parallel_split(lhs.raw_tree, rhs.raw_tree), + }; +} + +BinarySPDecompositionTree make_leaf_node(Node const &n) { + return BinarySPDecompositionTree{ + make_generic_binary_sp_leaf(n), + }; +} + +bool is_binary_sp_tree_left_associative(BinarySPDecompositionTree const &tt) { + return is_binary_sp_tree_left_associative(tt.raw_tree); +} + +bool is_binary_sp_tree_right_associative(BinarySPDecompositionTree const &tt) { + return is_binary_sp_tree_right_associative(tt.raw_tree); +} + +std::unordered_multiset get_leaves(BinarySPDecompositionTree const &tt) { + return get_leaves(tt.raw_tree); +} + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/fmt.cc b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/fmt.cc new file mode 100644 index 0000000000..4cd7206408 --- /dev/null +++ b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/fmt.cc @@ -0,0 +1 @@ +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/fmt.h" diff --git a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree.cc b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree.cc new file mode 100644 index 0000000000..3a4dbad8ec --- /dev/null +++ b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree.cc @@ -0,0 +1 @@ +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree.h" diff --git a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get.cc b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get.cc new file mode 100644 index 0000000000..4ee18af5be --- /dev/null +++ b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get.cc @@ -0,0 +1 @@ +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get.h" diff --git a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_leaves.cc b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_leaves.cc new file mode 100644 index 0000000000..71b67acc54 --- /dev/null +++ b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_leaves.cc @@ -0,0 +1 @@ +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_leaves.h" diff --git a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_left_child.cc b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_left_child.cc new file mode 100644 index 0000000000..227e5bd79c --- /dev/null +++ b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_left_child.cc @@ -0,0 +1 @@ +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_left_child.h" diff --git a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_node_type.cc b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_node_type.cc new file mode 100644 index 0000000000..1618128226 --- /dev/null +++ b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_node_type.cc @@ -0,0 +1 @@ +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_node_type.h" diff --git a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_num_tree_nodes.cc b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_num_tree_nodes.cc new file mode 100644 index 0000000000..05ec6b5925 --- /dev/null +++ b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_num_tree_nodes.cc @@ -0,0 +1 @@ +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_num_tree_nodes.h" diff --git a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_right_child.cc b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_right_child.cc new file mode 100644 index 0000000000..f168ba1e2f --- /dev/null +++ b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_right_child.cc @@ -0,0 +1 @@ +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_right_child.h" diff --git a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/hash.cc b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/hash.cc new file mode 100644 index 0000000000..75c472c435 --- /dev/null +++ b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/hash.cc @@ -0,0 +1 @@ +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/hash.h" diff --git a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is.cc b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is.cc new file mode 100644 index 0000000000..3da024743c --- /dev/null +++ b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is.cc @@ -0,0 +1 @@ +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is.h" diff --git a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_left_associative.cc b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_left_associative.cc new file mode 100644 index 0000000000..8fe9397003 --- /dev/null +++ b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_left_associative.cc @@ -0,0 +1 @@ +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_left_associative.h" diff --git a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_right_associative.cc b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_right_associative.cc new file mode 100644 index 0000000000..d202f55964 --- /dev/null +++ b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_right_associative.cc @@ -0,0 +1 @@ +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_right_associative.h" diff --git a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/json.cc b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/json.cc new file mode 100644 index 0000000000..b569ff9265 --- /dev/null +++ b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/json.cc @@ -0,0 +1 @@ +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/json.h" diff --git a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/make.cc b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/make.cc new file mode 100644 index 0000000000..fb1532b3ef --- /dev/null +++ b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/make.cc @@ -0,0 +1 @@ +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/make.h" diff --git a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/require.cc b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/require.cc new file mode 100644 index 0000000000..3fee45fcf5 --- /dev/null +++ b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/require.cc @@ -0,0 +1 @@ +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/require.h" diff --git a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/transform.cc b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/transform.cc new file mode 100644 index 0000000000..cabd66cff7 --- /dev/null +++ b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/transform.cc @@ -0,0 +1 @@ +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/transform.h" diff --git a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/visit.cc b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/visit.cc new file mode 100644 index 0000000000..25409333f2 --- /dev/null +++ b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/visit.cc @@ -0,0 +1 @@ +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/visit.h" diff --git a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/left_associative_binary_sp_tree_from_nary.cc b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/left_associative_binary_sp_tree_from_nary.cc new file mode 100644 index 0000000000..02e541b7e4 --- /dev/null +++ b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/left_associative_binary_sp_tree_from_nary.cc @@ -0,0 +1,75 @@ +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/left_associative_binary_sp_tree_from_nary.h" +#include "utils/containers/foldl1.h" +#include "utils/containers/transform.h" +#include "utils/containers/vector_of.h" +#include "utils/overload.h" + +namespace FlexFlow { + +BinarySPDecompositionTree left_associative_binary_sp_tree_from_nary( + SeriesParallelDecomposition const &nary) { + std::function( + std::variant const &)> + from_series_child; + std::function( + std::variant const &)> + from_parallel_child; + + auto from_node = [](Node const &n) -> GenericBinarySPDecompositionTree { + return GenericBinarySPDecompositionTree{n}; + }; + + auto from_series = + [&](SeriesSplit const &s) -> GenericBinarySPDecompositionTree { + std::vector> children = + transform(s.children, from_series_child); + return foldl1(children, + [](GenericBinarySPDecompositionTree const &accum, + GenericBinarySPDecompositionTree const &x) { + return GenericBinarySPDecompositionTree{ + GenericBinarySeriesSplit{accum, x}, + }; + }); + }; + + auto from_parallel = + [&](ParallelSplit const &s) -> GenericBinarySPDecompositionTree { + std::vector> children = + transform(vector_of(s.children), from_parallel_child); + return foldl1(children, + [](GenericBinarySPDecompositionTree const &accum, + GenericBinarySPDecompositionTree const &x) { + return GenericBinarySPDecompositionTree{ + GenericBinaryParallelSplit{accum, x}}; + }); + }; + + from_parallel_child = [&](std::variant const &v) + -> GenericBinarySPDecompositionTree { + return std::visit(overload{ + [&](Node const &n) { return from_node(n); }, + [&](SeriesSplit const &s) { return from_series(s); }, + }, + v); + }; + + from_series_child = [&](std::variant const &v) + -> GenericBinarySPDecompositionTree { + return std::visit( + overload{ + [&](Node const &n) { return from_node(n); }, + [&](ParallelSplit const &p) { return from_parallel(p); }, + }, + v); + }; + + return BinarySPDecompositionTree{ + nary.visit>(overload{ + [&](Node const &n) { return from_node(n); }, + [&](SeriesSplit const &s) { return from_series(s); }, + [&](ParallelSplit const &p) { return from_parallel(p); }, + }), + }; +} + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/nary_sp_tree_from_binary.cc b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/nary_sp_tree_from_binary.cc new file mode 100644 index 0000000000..3b8affd16d --- /dev/null +++ b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/nary_sp_tree_from_binary.cc @@ -0,0 +1,12 @@ +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/nary_sp_tree_from_binary.h" +#include "utils/graph/series_parallel/intermediate_sp_decomposition_tree.h" +#include "utils/graph/series_parallel/series_parallel_decomposition.h" + +namespace FlexFlow { + +SeriesParallelDecomposition + nary_sp_tree_from_binary(BinarySPDecompositionTree const &binary) { + return to_final_ast(from_binary_sp_tree(binary)); +} + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/right_associative_binary_sp_tree_from_nary.cc b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/right_associative_binary_sp_tree_from_nary.cc new file mode 100644 index 0000000000..673a4118a6 --- /dev/null +++ b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/right_associative_binary_sp_tree_from_nary.cc @@ -0,0 +1,72 @@ +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/right_associative_binary_sp_tree_from_nary.h" +#include "utils/containers/foldr1.h" +#include "utils/containers/transform.h" +#include "utils/containers/vector_of.h" +#include "utils/overload.h" + +namespace FlexFlow { + +BinarySPDecompositionTree right_associative_binary_sp_tree_from_nary( + SeriesParallelDecomposition const &nary) { + std::function( + std::variant const &)> + from_series_child; + std::function( + std::variant const &)> + from_parallel_child; + + auto from_node = [](Node const &n) { + return GenericBinarySPDecompositionTree{n}; + }; + + auto from_series = [&](SeriesSplit const &s) { + std::vector> children = + transform(s.children, from_series_child); + return foldr1(children, + [](GenericBinarySPDecompositionTree const &accum, + GenericBinarySPDecompositionTree const &x) { + return GenericBinarySPDecompositionTree{ + GenericBinarySeriesSplit{x, accum}}; + }); + }; + + auto from_parallel = [&](ParallelSplit const &s) { + std::vector> children = + transform(vector_of(s.children), from_parallel_child); + return foldr1(children, + [](GenericBinarySPDecompositionTree const &accum, + GenericBinarySPDecompositionTree const &x) { + return GenericBinarySPDecompositionTree{ + GenericBinaryParallelSplit{x, accum}}; + }); + }; + + from_parallel_child = [&](std::variant const &v) + -> GenericBinarySPDecompositionTree { + return std::visit(overload{ + [&](Node const &n) { return from_node(n); }, + [&](SeriesSplit const &s) { return from_series(s); }, + }, + v); + }; + + from_series_child = [&](std::variant const &v) + -> GenericBinarySPDecompositionTree { + return std::visit( + overload{ + [&](Node const &n) { return from_node(n); }, + [&](ParallelSplit const &p) { return from_parallel(p); }, + }, + v); + }; + + return BinarySPDecompositionTree{ + nary.visit>(overload{ + [&](Node const &n) { return from_node(n); }, + [&](SeriesSplit const &s) { return from_series(s); }, + [&](ParallelSplit const &p) { return from_parallel(p); }, + }), + }; +} + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/serial_parallel/get_serial_parallel_decomposition.cc b/lib/utils/src/utils/graph/series_parallel/get_series_parallel_decomposition.cc similarity index 62% rename from lib/utils/src/utils/graph/serial_parallel/get_serial_parallel_decomposition.cc rename to lib/utils/src/utils/graph/series_parallel/get_series_parallel_decomposition.cc index 632f5245db..ab231f256c 100644 --- a/lib/utils/src/utils/graph/serial_parallel/get_serial_parallel_decomposition.cc +++ b/lib/utils/src/utils/graph/series_parallel/get_series_parallel_decomposition.cc @@ -1,23 +1,28 @@ -#include "utils/graph/serial_parallel/get_serial_parallel_decomposition.h" +#include "utils/graph/series_parallel/get_series_parallel_decomposition.h" #include "utils/containers/get_only.h" #include "utils/containers/map_values.h" #include "utils/containers/transform.h" #include "utils/graph/digraph/algorithms/inverse_line_graph/get_inverse_line_graph.h" +#include "utils/graph/digraph/algorithms/transitive_reduction.h" #include "utils/graph/instances/adjacency_multidigraph.h" #include "utils/graph/multidigraph/algorithms/get_edges.h" #include "utils/graph/node/algorithms.h" -#include "utils/graph/serial_parallel/parallel_reduction.h" -#include "utils/graph/serial_parallel/serial_parallel_decomposition.h" -#include "utils/graph/serial_parallel/series_reduction.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/nary_sp_tree_from_binary.h" +#include "utils/graph/series_parallel/parallel_reduction.h" +#include "utils/graph/series_parallel/series_parallel_decomposition.h" +#include "utils/graph/series_parallel/series_reduction.h" namespace FlexFlow { -std::optional - get_serial_parallel_decomposition(DiGraphView const &g) { +std::optional + get_series_parallel_decomposition(DiGraphView const &g) { + + DiGraphView transitively_reduced = transitive_reduction(g); InverseLineGraphResult inverse_line_graph_result = ({ std::optional maybe_line_graph = - get_inverse_line_graph(g); + get_inverse_line_graph(transitively_reduced); if (!maybe_line_graph.has_value()) { return std::nullopt; } @@ -27,14 +32,11 @@ std::optional MultiDiGraph ttsp = MultiDiGraph::materialize_copy_of( inverse_line_graph_result.graph); - std::unordered_map> - ttsp_edge_to_sp_tree = map_values( - inverse_line_graph_result.inverse_edge_to_line_node_bidict - .as_unordered_map(), - [](Node const &n) { - return std::variant{n}; - }); + std::unordered_map + ttsp_edge_to_sp_tree = + map_values(inverse_line_graph_result.inverse_edge_to_line_node_bidict + .as_unordered_map(), + [](Node const &n) { return make_leaf_node(n); }); while (true) { assert(ttsp_edge_to_sp_tree.size() == get_edges(ttsp).size()); @@ -44,11 +46,8 @@ std::optional ParallelReduction parallel_reduction = maybe_parallel_reduction.value(); auto [e1, e2] = parallel_reduction.edges.ordered(); MultiDiEdge merged = apply_parallel_reduction(ttsp, parallel_reduction); - std::variant new_tree = - IntermediateSpDecompositionTree{ - SplitType::PARALLEL, - {ttsp_edge_to_sp_tree.at(e1), ttsp_edge_to_sp_tree.at(e2)}, - }; + BinarySPDecompositionTree new_tree = make_parallel_split( + ttsp_edge_to_sp_tree.at(e1), ttsp_edge_to_sp_tree.at(e2)); ttsp_edge_to_sp_tree.erase(e1); ttsp_edge_to_sp_tree.erase(e2); ttsp_edge_to_sp_tree.insert({merged, new_tree}); @@ -63,11 +62,8 @@ std::optional MultiDiEdge e1 = series_reduction.first; MultiDiEdge e2 = series_reduction.second; MultiDiEdge merged = apply_series_reduction(ttsp, series_reduction); - std::variant new_tree = - IntermediateSpDecompositionTree{ - SplitType::SERIAL, - {ttsp_edge_to_sp_tree.at(e1), ttsp_edge_to_sp_tree.at(e2)}, - }; + BinarySPDecompositionTree new_tree = make_series_split( + ttsp_edge_to_sp_tree.at(e1), ttsp_edge_to_sp_tree.at(e2)); ttsp_edge_to_sp_tree.erase(e1); ttsp_edge_to_sp_tree.erase(e2); ttsp_edge_to_sp_tree.insert({merged, new_tree}); @@ -83,7 +79,7 @@ std::optional MultiDiEdge e = get_only(get_edges(ttsp)); if (ttsp.get_multidiedge_src(e) != ttsp.get_multidiedge_dst(e)) { - return to_final_ast(ttsp_edge_to_sp_tree.at(e)); + return nary_sp_tree_from_binary(ttsp_edge_to_sp_tree.at(e)); } } } diff --git a/lib/utils/src/utils/graph/serial_parallel/graph_generation.cc b/lib/utils/src/utils/graph/series_parallel/graph_generation.cc similarity index 79% rename from lib/utils/src/utils/graph/serial_parallel/graph_generation.cc rename to lib/utils/src/utils/graph/series_parallel/graph_generation.cc index 4c9eb9d3ef..7070d04c4a 100644 --- a/lib/utils/src/utils/graph/serial_parallel/graph_generation.cc +++ b/lib/utils/src/utils/graph/series_parallel/graph_generation.cc @@ -1,4 +1,4 @@ -#include "utils/graph/serial_parallel/graph_generation.h" +#include "utils/graph/series_parallel/graph_generation.h" #include "utils/graph/dataflow_graph/algorithms.h" #include "utils/graph/instances/unordered_set_dataflow_graph.h" #include "utils/graph/node/algorithms.h" @@ -12,7 +12,7 @@ void parallel_extend_unsafe(DataflowGraph &g, DataflowGraphView const &ext) { } } -void serial_extend_unsafe(DataflowGraph &g, DataflowGraphView const &ext) { +void series_extend_unsafe(DataflowGraph &g, DataflowGraphView const &ext) { // TODO(@lockshaw): This function signature is impossible to implement in // general, as there is no guarantee that the graph view ext actually has // source nodes with inputs Either the signature should be changed, or an @@ -22,11 +22,11 @@ void serial_extend_unsafe(DataflowGraph &g, DataflowGraphView const &ext) { NOT_IMPLEMENTED(); } -DataflowGraph serial_composition(DataflowGraphView const &g1, +DataflowGraph series_composition(DataflowGraphView const &g1, DataflowGraphView const &g2) { DataflowGraph g = DataflowGraph::create_copy_of(g1); - serial_extend_unsafe(g, g2); + series_extend_unsafe(g, g2); return g; } @@ -39,8 +39,8 @@ DataflowGraph parallel_composition(DataflowGraphView const &g1, } DataflowGraph dataflow_graph_from_sp_decomposition( - SerialParallelDecomposition const &sp_decomposition) { - // TODO(@lockshaw): see existing concerns about serial_extend_unsafe + SeriesParallelDecomposition const &sp_decomposition) { + // TODO(@lockshaw): see existing concerns about series_extend_unsafe NOT_IMPLEMENTED(); } diff --git a/lib/utils/src/utils/graph/series_parallel/intermediate_sp_decomposition_tree.cc b/lib/utils/src/utils/graph/series_parallel/intermediate_sp_decomposition_tree.cc new file mode 100644 index 0000000000..48c936ec39 --- /dev/null +++ b/lib/utils/src/utils/graph/series_parallel/intermediate_sp_decomposition_tree.cc @@ -0,0 +1,84 @@ +#include "utils/graph/series_parallel/intermediate_sp_decomposition_tree.h" +#include "utils/containers/extend.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_left_child.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_right_child.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/visit.h" +#include "utils/overload.h" + +namespace FlexFlow { + +struct FlattenAST { + void add_flattened_child_to_parent( + IntermediateSpDecompositionTree &parent, + std::variant const &child) { + if (std::holds_alternative(child)) { + parent.children.push_back(child); + return; + } + + IntermediateSpDecompositionTree child_node = + std::get(child); + + if (parent.type == child_node.type) { + extend(parent.children, child_node.children); + } else { + parent.children.push_back(child); + } + } + + std::variant + operator()(IntermediateSpDecompositionTree const &ast_node) { + IntermediateSpDecompositionTree result(ast_node.type, {}); + for (std::variant const &child : + ast_node.children) { + std::variant flattened_child = + flatten_ast(child); + add_flattened_child_to_parent(result, flattened_child); + } + return result; + } + + std::variant + operator()(Node const &ast_node) { + return ast_node; + } +}; + +std::variant flatten_ast( + std::variant const &ast) { + return std::visit(FlattenAST{}, ast); +} + +std::variant + from_binary_sp_tree(GenericBinarySPDecompositionTree const &binary) { + return visit>( + binary, + overload{ + [](Node const &n) { return n; }, + [](GenericBinarySeriesSplit const &s) { + return IntermediateSpDecompositionTree{ + SplitType::SERIES, + { + from_binary_sp_tree(get_left_child(s)), + from_binary_sp_tree(get_right_child(s)), + }, + }; + }, + [](GenericBinaryParallelSplit const &p) { + return IntermediateSpDecompositionTree{ + SplitType::PARALLEL, + { + from_binary_sp_tree(get_left_child(p)), + from_binary_sp_tree(get_right_child(p)), + }, + }; + }, + }); +} + +std::variant + from_binary_sp_tree(BinarySPDecompositionTree const &binary) { + return from_binary_sp_tree(binary.raw_tree); +} + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/serial_parallel/parallel_reduction.cc b/lib/utils/src/utils/graph/series_parallel/parallel_reduction.cc similarity index 93% rename from lib/utils/src/utils/graph/serial_parallel/parallel_reduction.cc rename to lib/utils/src/utils/graph/series_parallel/parallel_reduction.cc index 30aa10edd7..12a6630bf0 100644 --- a/lib/utils/src/utils/graph/serial_parallel/parallel_reduction.cc +++ b/lib/utils/src/utils/graph/series_parallel/parallel_reduction.cc @@ -1,4 +1,4 @@ -#include "utils/graph/serial_parallel/parallel_reduction.h" +#include "utils/graph/series_parallel/parallel_reduction.h" #include "utils/graph/multidigraph/algorithms/get_edges.h" namespace FlexFlow { diff --git a/lib/utils/src/utils/graph/serial_parallel/serial_parallel_decomposition.cc b/lib/utils/src/utils/graph/series_parallel/series_parallel_decomposition.cc similarity index 52% rename from lib/utils/src/utils/graph/serial_parallel/serial_parallel_decomposition.cc rename to lib/utils/src/utils/graph/series_parallel/series_parallel_decomposition.cc index 666bf40f10..e697533054 100644 --- a/lib/utils/src/utils/graph/serial_parallel/serial_parallel_decomposition.cc +++ b/lib/utils/src/utils/graph/series_parallel/series_parallel_decomposition.cc @@ -1,18 +1,20 @@ -#include "utils/graph/serial_parallel/serial_parallel_decomposition.h" +#include "utils/graph/series_parallel/series_parallel_decomposition.h" +#include "utils/containers/multiset_union.h" #include "utils/containers/set_union.h" #include "utils/containers/transform.h" -#include "utils/containers/unordered_set_of.h" -#include "utils/graph/serial_parallel/intermediate_sp_decomposition_tree.h" +#include "utils/containers/unordered_multiset_of.h" +#include "utils/containers/vector_of.h" +#include "utils/graph/series_parallel/intermediate_sp_decomposition_tree.h" #include "utils/hash/unordered_set.h" #include "utils/variant.h" namespace FlexFlow { struct ToFinalAST { - std::variant + std::variant operator()(IntermediateSpDecompositionTree const &node) { - if (node.type == SplitType::SERIAL) { - return SerialSplit{transform( + if (node.type == SplitType::SERIES) { + return SeriesSplit{transform( node.children, [](std::variant const &s) { return narrow>( @@ -20,54 +22,55 @@ struct ToFinalAST { .value(); })}; } else { - return ParallelSplit{unordered_set_of(transform( + return ParallelSplit{unordered_multiset_of(transform( node.children, [](std::variant const &s) { - return narrow>( + return narrow>( internal_to_final_ast(s)) .value(); }))}; } } - std::variant operator()(Node const &node) { + std::variant operator()(Node const &node) { return node; } }; -std::variant internal_to_final_ast( +std::variant internal_to_final_ast( std::variant const &ast) { return std::visit(ToFinalAST{}, flatten_ast(ast)); } -SerialParallelDecomposition to_final_ast( +SeriesParallelDecomposition to_final_ast( std::variant const &ast) { - return std::visit([](auto &&x) { return SerialParallelDecomposition{x}; }, + return std::visit([](auto &&x) { return SeriesParallelDecomposition{x}; }, internal_to_final_ast(ast)); } -std::unordered_set get_nodes(SerialParallelDecomposition const &sp) { - return sp.visit>( +std::unordered_multiset get_nodes(SeriesParallelDecomposition const &sp) { + return sp.visit>( [](auto &&t) { return get_nodes(t); }); } -std::unordered_set get_nodes(SerialSplit const &serial) { - return set_union(transform( +std::unordered_multiset get_nodes(SeriesSplit const &serial) { + return multiset_union(transform( serial.children, [](std::variant const &child) - -> std::unordered_set { + -> std::unordered_multiset { return std::visit([](auto &&t) { return get_nodes(t); }, child); })); } -std::unordered_set get_nodes(ParallelSplit const ¶llel) { - return set_union(transform( - parallel.children, [](std::variant const &child) { +std::unordered_multiset get_nodes(ParallelSplit const ¶llel) { + return multiset_union(transform( + vector_of(parallel.children), + [](std::variant const &child) { return std::visit([](auto &&t) { return get_nodes(t); }, child); })); } -std::unordered_set get_nodes(Node const &node) { +std::unordered_multiset get_nodes(Node const &node) { return {node}; } diff --git a/lib/utils/src/utils/graph/serial_parallel/serial_parallel_splits.cc b/lib/utils/src/utils/graph/series_parallel/series_parallel_splits.cc similarity index 65% rename from lib/utils/src/utils/graph/serial_parallel/serial_parallel_splits.cc rename to lib/utils/src/utils/graph/series_parallel/series_parallel_splits.cc index 8fa42d4b22..0e04a4f904 100644 --- a/lib/utils/src/utils/graph/serial_parallel/serial_parallel_splits.cc +++ b/lib/utils/src/utils/graph/series_parallel/series_parallel_splits.cc @@ -1,47 +1,47 @@ -#include "utils/graph/serial_parallel/serial_parallel_splits.h" -#include "utils/fmt/unordered_set.h" +#include "utils/graph/series_parallel/series_parallel_splits.h" +#include "utils/fmt/unordered_multiset.h" #include "utils/fmt/variant.h" #include "utils/fmt/vector.h" #include "utils/hash-utils.h" -#include "utils/hash/unordered_set.h" +#include "utils/hash/unordered_multiset.h" #include "utils/hash/vector.h" namespace FlexFlow { -SerialSplit::SerialSplit( +SeriesSplit::SeriesSplit( std::vector> const &children) : children(children) {} -SerialSplit::SerialSplit( +SeriesSplit::SeriesSplit( std::initializer_list> const &children) : children(children) {} -bool SerialSplit::operator==(SerialSplit const &other) const { +bool SeriesSplit::operator==(SeriesSplit const &other) const { return this->tie() == other.tie(); } -bool SerialSplit::operator!=(SerialSplit const &other) const { +bool SeriesSplit::operator!=(SeriesSplit const &other) const { return this->tie() != other.tie(); } -SerialSplit::Tie SerialSplit::tie() const { +SeriesSplit::Tie SeriesSplit::tie() const { return std::tie(this->children); } -std::string format_as(SerialSplit const &split) { - return fmt::format("", split.children); +std::string format_as(SeriesSplit const &split) { + return fmt::format("", split.children); } -std::ostream &operator<<(std::ostream &s, SerialSplit const &split) { +std::ostream &operator<<(std::ostream &s, SeriesSplit const &split) { return s << fmt::to_string(split); } ParallelSplit::ParallelSplit( - std::unordered_set> const &children) + std::unordered_multiset> const &children) : children(children) {} ParallelSplit::ParallelSplit( - std::initializer_list> const &children) + std::initializer_list> const &children) : children(children) {} bool ParallelSplit::operator==(ParallelSplit const &other) const { @@ -68,8 +68,8 @@ std::ostream &operator<<(std::ostream &s, ParallelSplit const &split) { namespace std { -size_t hash<::FlexFlow::SerialSplit>::operator()( - ::FlexFlow::SerialSplit const &s) const { +size_t hash<::FlexFlow::SeriesSplit>::operator()( + ::FlexFlow::SeriesSplit const &s) const { size_t result = 0; ::FlexFlow::hash_combine(result, s.children); return result; diff --git a/lib/utils/src/utils/graph/serial_parallel/series_reduction.cc b/lib/utils/src/utils/graph/series_parallel/series_reduction.cc similarity index 97% rename from lib/utils/src/utils/graph/serial_parallel/series_reduction.cc rename to lib/utils/src/utils/graph/series_parallel/series_reduction.cc index e26f460e0e..7300c93fb0 100644 --- a/lib/utils/src/utils/graph/serial_parallel/series_reduction.cc +++ b/lib/utils/src/utils/graph/series_parallel/series_reduction.cc @@ -1,4 +1,4 @@ -#include "utils/graph/serial_parallel/series_reduction.h" +#include "utils/graph/series_parallel/series_reduction.h" #include "utils/containers/require_same.h" #include "utils/graph/multidigraph/algorithms/get_edges.h" #include "utils/graph/multidigraph/algorithms/get_incoming_edges.h" diff --git a/lib/utils/src/utils/graph/undirected/algorithms/get_edges.cc b/lib/utils/src/utils/graph/undirected/algorithms/get_edges.cc new file mode 100644 index 0000000000..8ae825c1ab --- /dev/null +++ b/lib/utils/src/utils/graph/undirected/algorithms/get_edges.cc @@ -0,0 +1,10 @@ +#include "utils/graph/undirected/algorithms/get_edges.h" +#include "utils/graph/undirected/undirected_edge_query.h" + +namespace FlexFlow { + +std::unordered_set get_edges(UndirectedGraphView const &g) { + return g.query_edges(undirected_edge_query_all()); +} + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/undirected/algorithms/get_neighboring_nodes.cc b/lib/utils/src/utils/graph/undirected/algorithms/get_neighboring_nodes.cc new file mode 100644 index 0000000000..3c05b9d5d5 --- /dev/null +++ b/lib/utils/src/utils/graph/undirected/algorithms/get_neighboring_nodes.cc @@ -0,0 +1,19 @@ +#include "utils/graph/undirected/algorithms/get_neighboring_nodes.h" +#include "utils/containers/vector_of.h" + +namespace FlexFlow { + +std::unordered_set get_neighboring_nodes(UndirectedGraphView const &g, + Node const &n) { + std::unordered_set edges = + g.query_edges(UndirectedEdgeQuery{query_set{n}}); + + std::unordered_set result = + set_union(transform(vector_of(edges), [](UndirectedEdge const &e) { + return std::unordered_set{e.bigger, e.smaller}; + })); + result.erase(n); + return result; +} + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/undirected/undirected_edge_query.cc b/lib/utils/src/utils/graph/undirected/undirected_edge_query.cc index 5c41eef7da..3cccf1c6eb 100644 --- a/lib/utils/src/utils/graph/undirected/undirected_edge_query.cc +++ b/lib/utils/src/utils/graph/undirected/undirected_edge_query.cc @@ -6,6 +6,10 @@ UndirectedEdgeQuery undirected_edge_query_all() { return UndirectedEdgeQuery{matchall()}; } +bool matches_edge(UndirectedEdgeQuery const &q, UndirectedEdge const &e) { + return includes(q.nodes, e.bigger) && includes(q.nodes, e.smaller); +} + UndirectedEdgeQuery query_intersection(UndirectedEdgeQuery const &lhs, UndirectedEdgeQuery const &rhs) { return UndirectedEdgeQuery{ diff --git a/lib/utils/src/utils/hash/multiset.cc b/lib/utils/src/utils/hash/multiset.cc new file mode 100644 index 0000000000..d84ca7d614 --- /dev/null +++ b/lib/utils/src/utils/hash/multiset.cc @@ -0,0 +1 @@ +#include "utils/hash/multiset.h" diff --git a/lib/utils/src/utils/hash/unordered_multiset.cc b/lib/utils/src/utils/hash/unordered_multiset.cc new file mode 100644 index 0000000000..7f6f73f428 --- /dev/null +++ b/lib/utils/src/utils/hash/unordered_multiset.cc @@ -0,0 +1 @@ +#include "utils/hash/unordered_multiset.h" diff --git a/lib/utils/src/utils/json/check_is_jsonable.cc b/lib/utils/src/utils/json/check_is_jsonable.cc new file mode 100644 index 0000000000..1e78fdb21f --- /dev/null +++ b/lib/utils/src/utils/json/check_is_jsonable.cc @@ -0,0 +1 @@ +#include "utils/json/check_is_jsonable.h" diff --git a/lib/utils/src/utils/json/is_json_deserializable.cc b/lib/utils/src/utils/json/is_json_deserializable.cc new file mode 100644 index 0000000000..17df41433d --- /dev/null +++ b/lib/utils/src/utils/json/is_json_deserializable.cc @@ -0,0 +1 @@ +#include "utils/json/is_json_deserializable.h" diff --git a/lib/utils/src/utils/json/is_json_serializable.cc b/lib/utils/src/utils/json/is_json_serializable.cc new file mode 100644 index 0000000000..883ee9f51a --- /dev/null +++ b/lib/utils/src/utils/json/is_json_serializable.cc @@ -0,0 +1 @@ +#include "utils/json/is_json_serializable.h" diff --git a/lib/utils/src/utils/json/is_jsonable.cc b/lib/utils/src/utils/json/is_jsonable.cc new file mode 100644 index 0000000000..3f819f8556 --- /dev/null +++ b/lib/utils/src/utils/json/is_jsonable.cc @@ -0,0 +1 @@ +#include "utils/json/is_jsonable.h" diff --git a/lib/utils/src/utils/json/optional.cc b/lib/utils/src/utils/json/optional.cc new file mode 100644 index 0000000000..c8f0fd2e3c --- /dev/null +++ b/lib/utils/src/utils/json/optional.cc @@ -0,0 +1 @@ +#include "utils/json/optional.h" diff --git a/lib/utils/src/utils/rapidcheck/optional.cc b/lib/utils/src/utils/rapidcheck/optional.cc new file mode 100644 index 0000000000..6d62532e7e --- /dev/null +++ b/lib/utils/src/utils/rapidcheck/optional.cc @@ -0,0 +1 @@ +#include "utils/rapidcheck/optional.h" diff --git a/lib/utils/test/common/include/test/utils/all.h b/lib/utils/test/common/include/test/utils/all.h deleted file mode 100644 index ced1c9ce38..0000000000 --- a/lib/utils/test/common/include/test/utils/all.h +++ /dev/null @@ -1,2 +0,0 @@ -#include "test/utils/doctest.h" -#include "test/utils/rapidcheck.h" diff --git a/lib/utils/test/common/include/test/utils/doctest.h b/lib/utils/test/common/include/test/utils/doctest/check_without_stringify.h similarity index 100% rename from lib/utils/test/common/include/test/utils/doctest.h rename to lib/utils/test/common/include/test/utils/doctest/check_without_stringify.h diff --git a/lib/utils/test/common/include/test/utils/doctest/fmt/expected.h b/lib/utils/test/common/include/test/utils/doctest/fmt/expected.h new file mode 100644 index 0000000000..8333ac4777 --- /dev/null +++ b/lib/utils/test/common/include/test/utils/doctest/fmt/expected.h @@ -0,0 +1,18 @@ +#ifndef _FLEXFLOW_LIB_UTILS_TEST_COMMON_INCLUDE_TEST_UTILS_DOCTEST_FMT_EXPECTED_H +#define _FLEXFLOW_LIB_UTILS_TEST_COMMON_INCLUDE_TEST_UTILS_DOCTEST_FMT_EXPECTED_H + +#include "utils/fmt/expected.h" +#include + +namespace doctest { + +template +struct StringMaker> { + static String convert(tl::expected const &m) { + return toString(fmt::to_string(m)); + } +}; + +} // namespace doctest + +#endif diff --git a/lib/utils/test/common/include/test/utils/doctest/fmt/map.h b/lib/utils/test/common/include/test/utils/doctest/fmt/map.h new file mode 100644 index 0000000000..d20dbe6943 --- /dev/null +++ b/lib/utils/test/common/include/test/utils/doctest/fmt/map.h @@ -0,0 +1,18 @@ +#ifndef _FLEXFLOW_LIB_UTILS_TEST_COMMON_INCLUDE_TEST_UTILS_DOCTEST_FMT_MAP_H +#define _FLEXFLOW_LIB_UTILS_TEST_COMMON_INCLUDE_TEST_UTILS_DOCTEST_FMT_MAP_H + +#include "utils/fmt/map.h" +#include + +namespace doctest { + +template +struct StringMaker> { + static String convert(std::map const &m) { + return toString(fmt::to_string(m)); + } +}; + +} // namespace doctest + +#endif diff --git a/lib/utils/test/common/include/test/utils/doctest/fmt/multiset.h b/lib/utils/test/common/include/test/utils/doctest/fmt/multiset.h new file mode 100644 index 0000000000..b26eee28ba --- /dev/null +++ b/lib/utils/test/common/include/test/utils/doctest/fmt/multiset.h @@ -0,0 +1,18 @@ +#ifndef _FLEXFLOW_LIB_UTILS_TEST_COMMON_INCLUDE_TEST_UTILS_DOCTEST_FMT_MULTISET_H +#define _FLEXFLOW_LIB_UTILS_TEST_COMMON_INCLUDE_TEST_UTILS_DOCTEST_FMT_MULTISET_H + +#include "utils/fmt/multiset.h" +#include + +namespace doctest { + +template +struct StringMaker> { + static String convert(std::multiset const &m) { + return toString(fmt::to_string(m)); + } +}; + +} // namespace doctest + +#endif diff --git a/lib/utils/test/common/include/test/utils/doctest/fmt/optional.h b/lib/utils/test/common/include/test/utils/doctest/fmt/optional.h new file mode 100644 index 0000000000..519cde7d74 --- /dev/null +++ b/lib/utils/test/common/include/test/utils/doctest/fmt/optional.h @@ -0,0 +1,18 @@ +#ifndef _FLEXFLOW_LIB_UTILS_TEST_COMMON_INCLUDE_TEST_UTILS_DOCTEST_FMT_OPTIONAL_H +#define _FLEXFLOW_LIB_UTILS_TEST_COMMON_INCLUDE_TEST_UTILS_DOCTEST_FMT_OPTIONAL_H + +#include "utils/fmt/optional.h" +#include + +namespace doctest { + +template +struct StringMaker> { + static String convert(std::optional const &m) { + return toString(fmt::to_string(m)); + } +}; + +} // namespace doctest + +#endif diff --git a/lib/utils/test/common/include/test/utils/doctest/fmt/pair.h b/lib/utils/test/common/include/test/utils/doctest/fmt/pair.h new file mode 100644 index 0000000000..db0ed24f13 --- /dev/null +++ b/lib/utils/test/common/include/test/utils/doctest/fmt/pair.h @@ -0,0 +1,18 @@ +#ifndef _FLEXFLOW_LIB_UTILS_TEST_COMMON_INCLUDE_TEST_UTILS_DOCTEST_FMT_PAIR_H +#define _FLEXFLOW_LIB_UTILS_TEST_COMMON_INCLUDE_TEST_UTILS_DOCTEST_FMT_PAIR_H + +#include "utils/fmt/pair.h" +#include + +namespace doctest { + +template +struct StringMaker> { + static String convert(std::pair const &m) { + return toString(fmt::to_string(m)); + } +}; + +} // namespace doctest + +#endif diff --git a/lib/utils/test/common/include/test/utils/doctest/fmt/set.h b/lib/utils/test/common/include/test/utils/doctest/fmt/set.h new file mode 100644 index 0000000000..3dd386645c --- /dev/null +++ b/lib/utils/test/common/include/test/utils/doctest/fmt/set.h @@ -0,0 +1,18 @@ +#ifndef _FLEXFLOW_LIB_UTILS_TEST_COMMON_INCLUDE_TEST_UTILS_DOCTEST_FMT_SET_H +#define _FLEXFLOW_LIB_UTILS_TEST_COMMON_INCLUDE_TEST_UTILS_DOCTEST_FMT_SET_H + +#include "utils/fmt/set.h" +#include + +namespace doctest { + +template +struct StringMaker> { + static String convert(std::set const &m) { + return toString(fmt::to_string(m)); + } +}; + +} // namespace doctest + +#endif diff --git a/lib/utils/test/common/include/test/utils/doctest/fmt/unordered_map.h b/lib/utils/test/common/include/test/utils/doctest/fmt/unordered_map.h new file mode 100644 index 0000000000..4fd5d15009 --- /dev/null +++ b/lib/utils/test/common/include/test/utils/doctest/fmt/unordered_map.h @@ -0,0 +1,18 @@ +#ifndef _FLEXFLOW_LIB_UTILS_TEST_COMMON_INCLUDE_TEST_UTILS_DOCTEST_FMT_UNORDERED_MAP_H +#define _FLEXFLOW_LIB_UTILS_TEST_COMMON_INCLUDE_TEST_UTILS_DOCTEST_FMT_UNORDERED_MAP_H + +#include "utils/fmt/unordered_map.h" +#include + +namespace doctest { + +template +struct StringMaker> { + static String convert(std::unordered_map const &m) { + return toString(fmt::to_string(m)); + } +}; + +} // namespace doctest + +#endif diff --git a/lib/utils/test/common/include/test/utils/doctest/fmt/unordered_multiset.h b/lib/utils/test/common/include/test/utils/doctest/fmt/unordered_multiset.h new file mode 100644 index 0000000000..94dae42239 --- /dev/null +++ b/lib/utils/test/common/include/test/utils/doctest/fmt/unordered_multiset.h @@ -0,0 +1,18 @@ +#ifndef _FLEXFLOW_LIB_UTILS_TEST_COMMON_INCLUDE_TEST_UTILS_DOCTEST_FMT_UNORDERED_MULTISET_H +#define _FLEXFLOW_LIB_UTILS_TEST_COMMON_INCLUDE_TEST_UTILS_DOCTEST_FMT_UNORDERED_MULTISET_H + +#include "utils/fmt/unordered_multiset.h" +#include + +namespace doctest { + +template +struct StringMaker> { + static String convert(std::unordered_multiset const &m) { + return toString(fmt::to_string(m)); + } +}; + +} // namespace doctest + +#endif diff --git a/lib/utils/test/common/include/test/utils/doctest/fmt/unordered_set.h b/lib/utils/test/common/include/test/utils/doctest/fmt/unordered_set.h new file mode 100644 index 0000000000..441590365d --- /dev/null +++ b/lib/utils/test/common/include/test/utils/doctest/fmt/unordered_set.h @@ -0,0 +1,18 @@ +#ifndef _FLEXFLOW_LIB_UTILS_TEST_COMMON_INCLUDE_TEST_UTILS_DOCTEST_FMT_UNORDERED_SET_H +#define _FLEXFLOW_LIB_UTILS_TEST_COMMON_INCLUDE_TEST_UTILS_DOCTEST_FMT_UNORDERED_SET_H + +#include "utils/fmt/unordered_set.h" +#include + +namespace doctest { + +template +struct StringMaker> { + static String convert(std::unordered_set const &m) { + return toString(fmt::to_string(m)); + } +}; + +} // namespace doctest + +#endif diff --git a/lib/utils/test/common/include/test/utils/doctest/fmt/variant.h b/lib/utils/test/common/include/test/utils/doctest/fmt/variant.h new file mode 100644 index 0000000000..c30862274a --- /dev/null +++ b/lib/utils/test/common/include/test/utils/doctest/fmt/variant.h @@ -0,0 +1,18 @@ +#ifndef _FLEXFLOW_LIB_UTILS_TEST_COMMON_INCLUDE_TEST_UTILS_DOCTEST_FMT_VARIANT_H +#define _FLEXFLOW_LIB_UTILS_TEST_COMMON_INCLUDE_TEST_UTILS_DOCTEST_FMT_VARIANT_H + +#include "utils/fmt/variant.h" +#include + +namespace doctest { + +template +struct StringMaker> { + static String convert(std::variant const &m) { + return toString(fmt::to_string(m)); + } +}; + +} // namespace doctest + +#endif diff --git a/lib/utils/test/common/include/test/utils/doctest/fmt/vector.h b/lib/utils/test/common/include/test/utils/doctest/fmt/vector.h new file mode 100644 index 0000000000..56198a7558 --- /dev/null +++ b/lib/utils/test/common/include/test/utils/doctest/fmt/vector.h @@ -0,0 +1,18 @@ +#ifndef _FLEXFLOW_LIB_UTILS_TEST_COMMON_INCLUDE_TEST_UTILS_DOCTEST_FMT_VECTOR_H +#define _FLEXFLOW_LIB_UTILS_TEST_COMMON_INCLUDE_TEST_UTILS_DOCTEST_FMT_VECTOR_H + +#include "utils/fmt/vector.h" +#include + +namespace doctest { + +template +struct StringMaker> { + static String convert(std::vector const &m) { + return toString(fmt::to_string(m)); + } +}; + +} // namespace doctest + +#endif diff --git a/lib/utils/test/common/src/common.cc b/lib/utils/test/common/src/common.cc deleted file mode 100644 index 51e981b1f5..0000000000 --- a/lib/utils/test/common/src/common.cc +++ /dev/null @@ -1 +0,0 @@ -#include "test/utils/all.h" diff --git a/lib/utils/test/common/src/test/utils/doctest/fmt/expected.cc b/lib/utils/test/common/src/test/utils/doctest/fmt/expected.cc new file mode 100644 index 0000000000..1cff2195db --- /dev/null +++ b/lib/utils/test/common/src/test/utils/doctest/fmt/expected.cc @@ -0,0 +1 @@ +#include "test/utils/doctest/fmt/expected.h" diff --git a/lib/utils/test/common/src/test/utils/doctest/fmt/map.cc b/lib/utils/test/common/src/test/utils/doctest/fmt/map.cc new file mode 100644 index 0000000000..976e65cfca --- /dev/null +++ b/lib/utils/test/common/src/test/utils/doctest/fmt/map.cc @@ -0,0 +1 @@ +#include "test/utils/doctest/fmt/map.h" diff --git a/lib/utils/test/common/src/test/utils/doctest/fmt/multiset.cc b/lib/utils/test/common/src/test/utils/doctest/fmt/multiset.cc new file mode 100644 index 0000000000..9c5b2f4d1e --- /dev/null +++ b/lib/utils/test/common/src/test/utils/doctest/fmt/multiset.cc @@ -0,0 +1 @@ +#include "test/utils/doctest/fmt/multiset.h" diff --git a/lib/utils/test/common/src/test/utils/doctest/fmt/optional.cc b/lib/utils/test/common/src/test/utils/doctest/fmt/optional.cc new file mode 100644 index 0000000000..8a3f7f158e --- /dev/null +++ b/lib/utils/test/common/src/test/utils/doctest/fmt/optional.cc @@ -0,0 +1 @@ +#include "test/utils/doctest/fmt/optional.h" diff --git a/lib/utils/test/common/src/test/utils/doctest/fmt/pair.cc b/lib/utils/test/common/src/test/utils/doctest/fmt/pair.cc new file mode 100644 index 0000000000..106fb1c900 --- /dev/null +++ b/lib/utils/test/common/src/test/utils/doctest/fmt/pair.cc @@ -0,0 +1 @@ +#include "test/utils/doctest/fmt/pair.h" diff --git a/lib/utils/test/common/src/test/utils/doctest/fmt/set.cc b/lib/utils/test/common/src/test/utils/doctest/fmt/set.cc new file mode 100644 index 0000000000..9ec70698bc --- /dev/null +++ b/lib/utils/test/common/src/test/utils/doctest/fmt/set.cc @@ -0,0 +1 @@ +#include "test/utils/doctest/fmt/set.h" diff --git a/lib/utils/test/common/src/test/utils/doctest/fmt/unordered_map.cc b/lib/utils/test/common/src/test/utils/doctest/fmt/unordered_map.cc new file mode 100644 index 0000000000..b893e632ed --- /dev/null +++ b/lib/utils/test/common/src/test/utils/doctest/fmt/unordered_map.cc @@ -0,0 +1 @@ +#include "test/utils/doctest/fmt/unordered_map.h" diff --git a/lib/utils/test/common/src/test/utils/doctest/fmt/unordered_multiset.cc b/lib/utils/test/common/src/test/utils/doctest/fmt/unordered_multiset.cc new file mode 100644 index 0000000000..55d2e69056 --- /dev/null +++ b/lib/utils/test/common/src/test/utils/doctest/fmt/unordered_multiset.cc @@ -0,0 +1 @@ +#include "test/utils/doctest/fmt/unordered_multiset.h" diff --git a/lib/utils/test/common/src/test/utils/doctest/fmt/unordered_set.cc b/lib/utils/test/common/src/test/utils/doctest/fmt/unordered_set.cc new file mode 100644 index 0000000000..13ad811e63 --- /dev/null +++ b/lib/utils/test/common/src/test/utils/doctest/fmt/unordered_set.cc @@ -0,0 +1 @@ +#include "test/utils/doctest/fmt/unordered_set.h" diff --git a/lib/utils/test/common/src/test/utils/doctest/fmt/variant.cc b/lib/utils/test/common/src/test/utils/doctest/fmt/variant.cc new file mode 100644 index 0000000000..b6cc4f54e4 --- /dev/null +++ b/lib/utils/test/common/src/test/utils/doctest/fmt/variant.cc @@ -0,0 +1 @@ +#include "test/utils/doctest/fmt/variant.h" diff --git a/lib/utils/test/common/src/test/utils/doctest/fmt/vector.cc b/lib/utils/test/common/src/test/utils/doctest/fmt/vector.cc new file mode 100644 index 0000000000..0102cd86da --- /dev/null +++ b/lib/utils/test/common/src/test/utils/doctest/fmt/vector.cc @@ -0,0 +1 @@ +#include "test/utils/doctest/fmt/vector.h" diff --git a/lib/utils/test/src/test_algorithms.cc b/lib/utils/test/src/test_algorithms.cc index a1dd75504e..44f602f3bc 100644 --- a/lib/utils/test/src/test_algorithms.cc +++ b/lib/utils/test/src/test_algorithms.cc @@ -1,10 +1,10 @@ -#include "test/utils/doctest.h" #include "utils/graph/algorithms.h" #include "utils/graph/construction.h" #include "utils/graph/hashmap_undirected_graph.h" #include "utils/graph/instances/adjacency_digraph.h" #include "utils/graph/undirected.h" #include +#include #include #include #include diff --git a/lib/utils/test/src/test_containers.cc b/lib/utils/test/src/test_containers.cc index af7792dc6d..dca500ced5 100644 --- a/lib/utils/test/src/test_containers.cc +++ b/lib/utils/test/src/test_containers.cc @@ -1,5 +1,5 @@ -#include "test/utils/doctest.h" #include "utils/containers.h" +#include #include #include #include @@ -275,9 +275,9 @@ TEST_SUITE(FF_TEST_SUITE) { CHECK(result == std::vector({2, 4, 6})); } - TEST_CASE("as_vector") { + TEST_CASE("vector_of") { std::unordered_set s = {1, 2, 3}; - std::vector result = as_vector(s); + std::vector result = vector_of(s); CHECK(result == std::vector({3, 2, 1})); } diff --git a/lib/utils/test/src/test_deduplicated_priority_queue.cc b/lib/utils/test/src/test_deduplicated_priority_queue.cc index 66cfd395bc..048e95acb7 100644 --- a/lib/utils/test/src/test_deduplicated_priority_queue.cc +++ b/lib/utils/test/src/test_deduplicated_priority_queue.cc @@ -1,5 +1,5 @@ -#include "test/utils/doctest.h" #include "utils/deduplicated_priority_queue.h" +#include TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("DeduplicatedPriorityQueue push and pop") { diff --git a/lib/utils/test/src/test_disjoint_set.cc b/lib/utils/test/src/test_disjoint_set.cc index 80fcf87d6b..65037be3dd 100644 --- a/lib/utils/test/src/test_disjoint_set.cc +++ b/lib/utils/test/src/test_disjoint_set.cc @@ -1,5 +1,5 @@ -#include "test/utils/doctest.h" #include "utils/disjoint_set.h" +#include using namespace FlexFlow; diff --git a/lib/utils/test/src/test_dot_file.cc b/lib/utils/test/src/test_dot_file.cc index ed4c32bb1c..e409572511 100644 --- a/lib/utils/test/src/test_dot_file.cc +++ b/lib/utils/test/src/test_dot_file.cc @@ -1,5 +1,5 @@ -#include "test/utils/doctest.h" #include "utils/dot_file.h" +#include #include TEST_SUITE(FF_TEST_SUITE) { diff --git a/lib/utils/test/src/test_format.cc b/lib/utils/test/src/test_format.cc index eeed2eae81..f0d396a123 100644 --- a/lib/utils/test/src/test_format.cc +++ b/lib/utils/test/src/test_format.cc @@ -1,5 +1,5 @@ -#include "test/utils/doctest.h" #include "utils/record_formatter.h" +#include std::string formatRecord(RecordFormatter const &formatter) { std::ostringstream oss; diff --git a/lib/utils/test/src/test_hash.cc b/lib/utils/test/src/test_hash.cc index b38c43fe30..decf405e7a 100644 --- a/lib/utils/test/src/test_hash.cc +++ b/lib/utils/test/src/test_hash.cc @@ -1,5 +1,5 @@ -#include "test/utils/doctest.h" #include "utils/hash-utils.h" +#include using namespace FlexFlow; diff --git a/lib/utils/test/src/test_multidigraph.cc b/lib/utils/test/src/test_multidigraph.cc index 90e1bb2187..cc7ac1de32 100644 --- a/lib/utils/test/src/test_multidigraph.cc +++ b/lib/utils/test/src/test_multidigraph.cc @@ -1,7 +1,7 @@ -#include "test/utils/doctest.h" #include "utils/graph/adjacency_multidigraph.h" #include "utils/graph/multidiedge.h" #include "utils/graph/multidigraph_interfaces.h" +#include using namespace FlexFlow; diff --git a/lib/utils/test/src/test_random_utils.cc b/lib/utils/test/src/test_random_utils.cc index 88a566a198..2b816eea4f 100644 --- a/lib/utils/test/src/test_random_utils.cc +++ b/lib/utils/test/src/test_random_utils.cc @@ -1,6 +1,6 @@ -#include "test/utils/doctest.h" #include "utils/random_utils.h" #include +#include void checkProbabilities(std::vector const &counts, int numIterations, diff --git a/lib/utils/test/src/test_sequence.cc b/lib/utils/test/src/test_sequence.cc index ee72febe05..a758476fd9 100644 --- a/lib/utils/test/src/test_sequence.cc +++ b/lib/utils/test/src/test_sequence.cc @@ -1,5 +1,5 @@ -#include "test/utils/doctest.h" #include "utils/sequence.h" +#include using namespace FlexFlow; diff --git a/lib/utils/test/src/test_stack_map.cc b/lib/utils/test/src/test_stack_map.cc index 21c1b07d1b..f117820c5d 100644 --- a/lib/utils/test/src/test_stack_map.cc +++ b/lib/utils/test/src/test_stack_map.cc @@ -1,5 +1,5 @@ -#include "test/utils/doctest.h" #include "utils/stack_map.h" +#include using namespace FlexFlow; diff --git a/lib/utils/test/src/test_stack_string.cc b/lib/utils/test/src/test_stack_string.cc index a044f85fe3..b89e3277cd 100644 --- a/lib/utils/test/src/test_stack_string.cc +++ b/lib/utils/test/src/test_stack_string.cc @@ -1,6 +1,6 @@ -#include "test/utils/doctest.h" #include "test/utils/rapidcheck.h" #include "utils/stack_string.h" +#include using namespace FlexFlow; diff --git a/lib/utils/test/src/test_stack_vector.cc b/lib/utils/test/src/test_stack_vector.cc index 1af43b6993..577e61092c 100644 --- a/lib/utils/test/src/test_stack_vector.cc +++ b/lib/utils/test/src/test_stack_vector.cc @@ -1,6 +1,6 @@ -#include "test/utils/doctest.h" #include "test/utils/rapidcheck.h" #include "utils/stack_vector.h" +#include #include using namespace FlexFlow; diff --git a/lib/utils/test/src/test_tuple.cc b/lib/utils/test/src/test_tuple.cc index 31308dec2c..96171510a7 100644 --- a/lib/utils/test/src/test_tuple.cc +++ b/lib/utils/test/src/test_tuple.cc @@ -1,5 +1,5 @@ -#include "test/utils/doctest.h" #include "utils/tuple.h" +#include #include #include diff --git a/lib/utils/test/src/test_type_index.cc b/lib/utils/test/src/test_type_index.cc index b2d8aea848..e7ce12346a 100644 --- a/lib/utils/test/src/test_type_index.cc +++ b/lib/utils/test/src/test_type_index.cc @@ -1,5 +1,5 @@ -#include "test/utils/doctest.h" #include "utils/type_index.h" +#include #include using namespace FlexFlow; diff --git a/lib/utils/test/src/test_undirected_graph.cc b/lib/utils/test/src/test_undirected_graph.cc index 33b102bd3b..ea519478d3 100644 --- a/lib/utils/test/src/test_undirected_graph.cc +++ b/lib/utils/test/src/test_undirected_graph.cc @@ -1,7 +1,8 @@ -#include "test/utils/all.h" +#include "test/utils/rapidcheck.h" #include "test/utils/rapidcheck/visitable.h" #include "utils/graph/hashmap_undirected_graph.h" #include "utils/graph/undirected.h" +#include /* namespace rc { */ diff --git a/lib/utils/test/src/test_variant.cc b/lib/utils/test/src/test_variant.cc index 98b28a48e9..0bd01b8dfe 100644 --- a/lib/utils/test/src/test_variant.cc +++ b/lib/utils/test/src/test_variant.cc @@ -1,6 +1,6 @@ -#include "test/utils/doctest.h" #include "test/utils/rapidcheck.h" #include "utils/variant.h" +#include TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("widen and narrow functions") { diff --git a/lib/utils/test/src/test_vector.cc b/lib/utils/test/src/test_vector.cc index 4bdc724dd8..c6eb0828b8 100644 --- a/lib/utils/test/src/test_vector.cc +++ b/lib/utils/test/src/test_vector.cc @@ -1,5 +1,5 @@ -#include "test/utils/doctest.h" #include "utils/vector.h" +#include TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("concat function") { diff --git a/lib/utils/test/src/utils/bidict/algorithms/bidict_from_enumerating.cc b/lib/utils/test/src/utils/bidict/algorithms/bidict_from_enumerating.cc index 6e3ac8c155..b5a373e5c9 100644 --- a/lib/utils/test/src/utils/bidict/algorithms/bidict_from_enumerating.cc +++ b/lib/utils/test/src/utils/bidict/algorithms/bidict_from_enumerating.cc @@ -1,7 +1,7 @@ #include "utils/bidict/algorithms/bidict_from_enumerating.h" +#include "test/utils/doctest/fmt/unordered_set.h" #include "utils/bidict/algorithms/left_entries.h" #include "utils/bidict/algorithms/right_entries.h" -#include "utils/fmt/unordered_set.h" #include using namespace ::FlexFlow; diff --git a/lib/utils/test/src/utils/bidict/bidict.cc b/lib/utils/test/src/utils/bidict/bidict.cc index 5c2ffd5bba..fed655013f 100644 --- a/lib/utils/test/src/utils/bidict/bidict.cc +++ b/lib/utils/test/src/utils/bidict/bidict.cc @@ -1,6 +1,8 @@ #include "utils/bidict/bidict.h" -#include "test/utils/doctest.h" -#include "utils/fmt/vector.h" +#include "test/utils/doctest/check_without_stringify.h" +#include "test/utils/doctest/fmt/unordered_map.h" +#include "test/utils/doctest/fmt/vector.h" +#include using namespace FlexFlow; diff --git a/lib/utils/test/src/utils/bidict/try_merge_nondisjoint_bidicts.cc b/lib/utils/test/src/utils/bidict/try_merge_nondisjoint_bidicts.cc index 2eb8f869f9..49fed81b29 100644 --- a/lib/utils/test/src/utils/bidict/try_merge_nondisjoint_bidicts.cc +++ b/lib/utils/test/src/utils/bidict/try_merge_nondisjoint_bidicts.cc @@ -1,6 +1,6 @@ #include "utils/bidict/try_merge_nondisjoint_bidicts.h" -#include "test/utils/doctest.h" -#include "utils/fmt/optional.h" +#include "test/utils/doctest/fmt/optional.h" +#include using namespace ::FlexFlow; diff --git a/lib/utils/test/src/utils/cli/cli_get_help_message.cc b/lib/utils/test/src/utils/cli/cli_get_help_message.cc new file mode 100644 index 0000000000..b3ee4d3318 --- /dev/null +++ b/lib/utils/test/src/utils/cli/cli_get_help_message.cc @@ -0,0 +1,519 @@ +#include "utils/cli/cli_get_help_message.h" +#include "utils/join_strings.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("cli_get_help_message(std::string, CLISpec)") { + std::string program_name = "prog_name"; + + SUBCASE("no flags or positional arguments") { + CLISpec cli = CLISpec{ + {}, + {}, + }; + + std::string result = cli_get_help_message(program_name, cli); + std::string correct = ("usage: prog_name\n"); + + CHECK(result == correct); + } + + SUBCASE("no flags") { + CLISpec cli = CLISpec{ + {}, + { + CLIPositionalArgumentSpec{ + "pos-arg-1", + std::nullopt, + std::nullopt, + }, + }, + }; + + std::string result = cli_get_help_message(program_name, cli); + std::string correct = ("usage: prog_name pos-arg-1\n" + "\n" + "positional arguments:\n" + " pos-arg-1\n"); + + CHECK(result == correct); + } + + SUBCASE("no positional arguments") { + CLISpec cli = CLISpec{ + { + CLIFlagSpec{ + "flag-1", + 'f', + std::nullopt, + }, + }, + {}, + }; + + std::string result = cli_get_help_message(program_name, cli); + std::string correct = ("usage: prog_name [-f]\n" + "\n" + "options:\n" + " -f, --flag-1\n"); + + CHECK(result == correct); + } + + SUBCASE("flag formatting") { + SUBCASE("flag with shortname") { + CLISpec cli = CLISpec{ + { + CLIFlagSpec{ + "flag", + 'f', + std::nullopt, + }, + }, + {}, + }; + + std::string result = cli_get_help_message(program_name, cli); + std::string correct = ("usage: prog_name [-f]\n" + "\n" + "options:\n" + " -f, --flag\n"); + + CHECK(result == correct); + } + + SUBCASE("flag without shortname") { + CLISpec cli = CLISpec{ + { + CLIFlagSpec{ + "flag", + std::nullopt, + std::nullopt, + }, + }, + {}, + }; + + std::string result = cli_get_help_message(program_name, cli); + std::string correct = ("usage: prog_name [--flag]\n" + "\n" + "options:\n" + " --flag\n"); + + CHECK(result == correct); + } + + SUBCASE("flags are displayed in provided order") { + CLISpec cli = CLISpec{ + { + CLIFlagSpec{ + "flag2", + std::nullopt, + std::nullopt, + }, + CLIFlagSpec{ + "flag1", + std::nullopt, + std::nullopt, + }, + }, + {}, + }; + + std::string result = cli_get_help_message(program_name, cli); + std::string correct = ("usage: prog_name [--flag2] [--flag1]\n" + "\n" + "options:\n" + " --flag2\n" + " --flag1\n"); + + CHECK(result == correct); + } + } + + SUBCASE("positional argument formatting") { + SUBCASE("without choices") { + CLISpec cli = CLISpec{ + {}, + { + CLIPositionalArgumentSpec{ + "posarg", + std::nullopt, + std::nullopt, + }, + }, + }; + + std::string result = cli_get_help_message(program_name, cli); + std::string correct = ("usage: prog_name posarg\n" + "\n" + "positional arguments:\n" + " posarg\n"); + + CHECK(result == correct); + } + + SUBCASE("with choices") { + SUBCASE("choices are not empty") { + CLISpec cli = CLISpec{ + {}, + { + CLIPositionalArgumentSpec{ + "posarg", + std::vector{"red", "blue", "green"}, + std::nullopt, + }, + }, + }; + + std::string result = cli_get_help_message(program_name, cli); + std::string correct = ("usage: prog_name {red,blue,green}\n" + "\n" + "positional arguments:\n" + " {red,blue,green}\n"); + + CHECK(result == correct); + } + + SUBCASE("choices are empty") { + CLISpec cli = CLISpec{ + {}, + { + CLIPositionalArgumentSpec{ + "posarg", + std::vector{}, + std::nullopt, + }, + }, + }; + + std::string result = cli_get_help_message(program_name, cli); + std::string correct = ("usage: prog_name {}\n" + "\n" + "positional arguments:\n" + " {}\n"); + + CHECK(result == correct); + } + } + + SUBCASE("are displayed in provided order") { + CLISpec cli = CLISpec{ + {}, + { + CLIPositionalArgumentSpec{ + "posarg2", + std::nullopt, + std::nullopt, + }, + CLIPositionalArgumentSpec{ + "posarg1", + std::nullopt, + std::nullopt, + }, + }, + }; + + std::string result = cli_get_help_message(program_name, cli); + std::string correct = ("usage: prog_name posarg2 posarg1\n" + "\n" + "positional arguments:\n" + " posarg2\n" + " posarg1\n"); + + CHECK(result == correct); + } + } + + SUBCASE("flag and positional argument alignment") { + SUBCASE("flags are longer") { + CLISpec cli = CLISpec{ + { + CLIFlagSpec{ + "flag1", + '1', + "flag1 description", + }, + CLIFlagSpec{ + "flag2-is-long", + std::nullopt, + "flag2-is-long description", + }, + }, + { + CLIPositionalArgumentSpec{ + "posarg", + std::nullopt, + "help text for posarg", + }, + }, + }; + + std::string result = cli_get_help_message(program_name, cli); + std::string correct = + ("usage: prog_name [-1] [--flag2-is-long] posarg\n" + "\n" + "positional arguments:\n" + " posarg help text for posarg\n" + "\n" + "options:\n" + " -1, --flag1 flag1 description\n" + " --flag2-is-long flag2-is-long description\n"); + + CHECK(result == correct); + } + + SUBCASE("pos args are longer") { + CLISpec cli = CLISpec{ + { + CLIFlagSpec{ + "flag1", + '1', + "flag1 description", + }, + }, + { + CLIPositionalArgumentSpec{ + "posarg1-is-very-long", + std::nullopt, + "help text for posarg1-is-very-long", + }, + CLIPositionalArgumentSpec{ + "posarg2", + std::nullopt, + "help text for posarg2", + }, + }, + }; + + std::string result = cli_get_help_message(program_name, cli); + std::string correct = + ("usage: prog_name [-1] posarg1-is-very-long posarg2\n" + "\n" + "positional arguments:\n" + " posarg1-is-very-long help text for posarg1-is-very-long\n" + " posarg2 help text for posarg2\n" + "\n" + "options:\n" + " -1, --flag1 flag1 description\n"); + + CHECK(result == correct); + } + + SUBCASE("line break behavior") { + SUBCASE("line breaks max out other argument alignments") { + CLISpec cli = CLISpec{ + { + CLIFlagSpec{ + "flag", + 'f', + "flag help text", + }, + }, + { + CLIPositionalArgumentSpec{ + "abcdefghijklmnopqrstuvwxyz0123456789", + std::nullopt, + "long arg help text", + }, + CLIPositionalArgumentSpec{ + "posarg", + std::nullopt, + "posarg help text", + }, + }, + }; + + std::string result = cli_get_help_message(program_name, cli); + std::string correct = ("usage: prog_name [-f] " + "abcdefghijklmnopqrstuvwxyz0123456789 posarg\n" + "\n" + "positional arguments:\n" + " abcdefghijklmnopqrstuvwxyz0123456789\n" + " long arg help text\n" + " posarg posarg help text\n" + "\n" + "options:\n" + " -f, --flag flag help text\n"); + + CHECK(result == correct); + } + SUBCASE("positional argument line break behavior") { + SUBCASE("positional arguments cause a line break at or above " + "formatted-length 22") { + std::string arg_name = "aaaaaaaaaaaaaaaaaaaaaa"; + REQUIRE(arg_name.size() == 22); + + CLISpec cli = CLISpec{ + {}, + { + CLIPositionalArgumentSpec{ + arg_name, + std::nullopt, + "help text", + }, + }, + }; + + std::string result = cli_get_help_message(program_name, cli); + std::string correct = ("usage: prog_name aaaaaaaaaaaaaaaaaaaaaa\n" + "\n" + "positional arguments:\n" + " aaaaaaaaaaaaaaaaaaaaaa\n" + " help text\n"); + + CHECK(result == correct); + } + + SUBCASE("positional arguments do not cause a line break below " + "formatted-length 22") { + std::string arg_name = "aaaaaaaaaaaaaaaaaaaaa"; + REQUIRE(arg_name.size() == 21); + + CLISpec cli = CLISpec{ + {}, + { + CLIPositionalArgumentSpec{ + arg_name, + std::nullopt, + "help text", + }, + }, + }; + + std::string result = cli_get_help_message(program_name, cli); + std::string correct = ("usage: prog_name aaaaaaaaaaaaaaaaaaaaa\n" + "\n" + "positional arguments:\n" + " aaaaaaaaaaaaaaaaaaaaa\n" + " help text\n"); + } + } + + SUBCASE("flag line break behavior") { + SUBCASE("flags cause a line break at or above formatted-length 21") { + std::string arg_name = "bbbbbbbbbbbbbbb"; + { + std::string formatted = "-b, --" + arg_name; + REQUIRE(formatted.size() == 21); + } + + CLISpec cli = CLISpec{ + { + CLIFlagSpec{ + arg_name, + 'b', + "flag description", + }, + }, + {}, + }; + + std::string result = cli_get_help_message(program_name, cli); + std::string correct = + ("usage: prog_name [-b]\n" + "\n" + "options:\n" + " -b, --bbbbbbbbbbbbbbb\n" + " flag description\n"); + + CHECK(result == correct); + } + + SUBCASE("flags do not cause a line break below formatted-length 21") { + std::string arg_name = "bbbbbbbbbbbbbb"; + { + std::string formatted = "-b, --" + arg_name; + REQUIRE(formatted.size() == 20); + } + + CLISpec cli = CLISpec{ + { + CLIFlagSpec{ + arg_name, + 'b', + "flag description", + }, + }, + {}, + }; + + std::string result = cli_get_help_message(program_name, cli); + std::string correct = + ("usage: prog_name [-b]\n" + "\n" + "options:\n" + " -b, --bbbbbbbbbbbbbb flag description\n"); + + CHECK(result == correct); + } + } + + SUBCASE("choice line breakpoint formatting") { + SUBCASE( + "choices cause a line break at or above formatted-length 21") { + std::vector choices = { + "a", "b", "c", "d", "e", "fffffffff"}; + { + std::string formatted_choices = + "{" + join_strings(choices, ",") + "}"; + REQUIRE(formatted_choices.size() == 21); + } + + CLISpec cli = CLISpec{ + {}, + { + CLIPositionalArgumentSpec{ + "posarg", + choices, + "help text", + }, + }, + }; + + std::string result = cli_get_help_message(program_name, cli); + std::string correct = ("usage: prog_name {a,b,c,d,e,fffffffff}\n" + "\n" + "positional arguments:\n" + " {a,b,c,d,e,fffffffff}\n" + " help text\n"); + + CHECK(result == correct); + } + + SUBCASE( + "choices do not cause a line break below formatted-length 21") { + std::vector choices = { + "a", "b", "c", "d", "e", "ffffffff"}; + { + std::string formatted_choices = + "{" + join_strings(choices, ",") + "}"; + REQUIRE(formatted_choices.size() == 20); + } + + CLISpec cli = CLISpec{ + {}, + { + CLIPositionalArgumentSpec{ + "posarg", + choices, + "help text", + }, + }, + }; + + std::string result = cli_get_help_message(program_name, cli); + std::string correct = ("usage: prog_name {a,b,c,d,e,ffffffff}\n" + "\n" + "positional arguments:\n" + " {a,b,c,d,e,ffffffff} help text\n"); + + CHECK(result == correct); + } + } + } + } + } +} diff --git a/lib/utils/test/src/utils/cli/cli_parse.cc b/lib/utils/test/src/utils/cli/cli_parse.cc new file mode 100644 index 0000000000..40dea86ae0 --- /dev/null +++ b/lib/utils/test/src/utils/cli/cli_parse.cc @@ -0,0 +1,477 @@ +#include "utils/cli/cli_parse.h" +#include "test/utils/doctest/fmt/expected.h" +#include "test/utils/doctest/fmt/optional.h" +#include "utils/expected.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("cli_parse_flag(CLISpec, std::string)") { + CLISpec cli = CLISpec{ + { + CLIFlagSpec{ + "flag1", + std::nullopt, + std::nullopt, + }, + CLIFlagSpec{ + "flag2", + '2', + std::nullopt, + }, + }, + {}, + }; + + CLIFlagKey key_flag1 = CLIFlagKey{0}; + CLIFlagKey key_flag2 = CLIFlagKey{1}; + + SUBCASE("correctly parses short flag") { + std::string input = "-2"; + + tl::expected result = cli_parse_flag(cli, input); + tl::expected correct = key_flag2; + + CHECK(result == correct); + } + + SUBCASE("correctly parses long flag") { + std::string input = "--flag1"; + + tl::expected result = cli_parse_flag(cli, input); + tl::expected correct = key_flag1; + + CHECK(result == correct); + } + + SUBCASE("fails on unknown flag") { + std::string input = "--not-real"; + + tl::expected result = cli_parse_flag(cli, input); + tl::expected correct = + tl::unexpected("Encountered unknown flag --not-real"); + + CHECK(result == correct); + } + + SUBCASE("fails on non-flag") { + std::string input = "-flag1"; + + std::optional result = + optional_from_expected(cli_parse_flag(cli, input)); + std::optional correct = std::nullopt; + + CHECK(result == correct); + } + } + + TEST_CASE("cli_parse(CLISpec, std::vector)") { + SUBCASE("works even if cli is empty") { + CLISpec cli = CLISpec{{}, {}}; + std::vector inputs = {"prog_name"}; + + tl::expected result = cli_parse(cli, inputs); + tl::expected correct = + CLIParseResult{{}, {}}; + + CHECK(result == correct); + } + + SUBCASE("flag parsing") { + CLISpec cli = CLISpec{ + { + CLIFlagSpec{ + "flag1", + std::nullopt, + std::nullopt, + }, + CLIFlagSpec{ + "flag2", + '2', + std::nullopt, + }, + }, + {}, + }; + CLIFlagKey key_flag1 = CLIFlagKey{0}; + CLIFlagKey key_flag2 = CLIFlagKey{1}; + + SUBCASE("parses flags in any order") { + std::vector inputs = {"prog_name", "-2", "--flag1"}; + + tl::expected result = + cli_parse(cli, inputs); + tl::expected correct = CLIParseResult{ + { + {key_flag1, true}, + {key_flag2, true}, + }, + {}, + }; + + CHECK(result == correct); + } + + SUBCASE("is fine if some are not present") { + std::vector inputs = {"prog_name", "-2"}; + + tl::expected result = + cli_parse(cli, inputs); + tl::expected correct = CLIParseResult{ + { + {key_flag1, false}, + {key_flag2, true}, + }, + {}, + }; + + CHECK(result == correct); + } + + SUBCASE("is fine if none are present") { + std::vector inputs = {"prog_name"}; + + tl::expected result = + cli_parse(cli, inputs); + tl::expected correct = CLIParseResult{ + { + {key_flag1, false}, + {key_flag2, false}, + }, + {}, + }; + + CHECK(result == correct); + } + + SUBCASE("is fine even if the program name is a flag") { + std::vector inputs = {"--flag1", "-2"}; + + tl::expected result = + cli_parse(cli, inputs); + tl::expected correct = CLIParseResult{ + { + {key_flag1, false}, + {key_flag2, true}, + }, + {}, + }; + + CHECK(result == correct); + } + } + + SUBCASE("positional argument parsing") { + SUBCASE("without choices") { + CLISpec cli = CLISpec{ + {}, + { + CLIPositionalArgumentSpec{ + "posarg1", + std::nullopt, + std::nullopt, + }, + CLIPositionalArgumentSpec{ + "posarg2", + std::nullopt, + std::nullopt, + }, + }, + }; + + CLIPositionalArgumentKey key_posarg1 = CLIPositionalArgumentKey{0}; + CLIPositionalArgumentKey key_posarg2 = CLIPositionalArgumentKey{1}; + + SUBCASE("can parse multiple positional arguments") { + std::vector inputs = {"prog_name", "hello", "world"}; + + tl::expected result = + cli_parse(cli, inputs); + tl::expected correct = + CLIParseResult{{}, + { + {key_posarg1, "hello"}, + {key_posarg2, "world"}, + }}; + + CHECK(result == correct); + } + + SUBCASE("requires all positional arguments to be present") { + std::vector inputs = {"prog_name", "hello"}; + + tl::expected result = + cli_parse(cli, inputs); + tl::expected correct = tl::unexpected( + "Not enough positional arguments: found 1, expected 2"); + + CHECK(result == correct); + } + + SUBCASE("requires no extra positional arguments to be present") { + std::vector inputs = { + "prog_name", "hello", "there", "world"}; + + tl::expected result = + cli_parse(cli, inputs); + tl::expected correct = + tl::unexpected("Too many positional arguments: expected 2"); + + CHECK(result == correct); + } + + SUBCASE("allows arguments to contain spaces") { + std::vector inputs = { + "prog_name", "hello there", "world"}; + + tl::expected result = + cli_parse(cli, inputs); + tl::expected correct = + CLIParseResult{{}, + { + {key_posarg1, "hello there"}, + {key_posarg2, "world"}, + }}; + + CHECK(result == correct); + } + + SUBCASE("allows arguments to be empty") { + std::vector inputs = {"prog_name", "hello", ""}; + + tl::expected result = + cli_parse(cli, inputs); + tl::expected correct = + CLIParseResult{{}, + { + {key_posarg1, "hello"}, + {key_posarg2, ""}, + }}; + + CHECK(result == correct); + } + } + + SUBCASE("with choices") { + SUBCASE("choices is non-empty") { + CLISpec cli = CLISpec{ + {}, + { + CLIPositionalArgumentSpec{ + "posarg", + std::vector{"red", "blue", "green"}, + std::nullopt, + }, + }, + }; + + CLIPositionalArgumentKey key_posarg = CLIPositionalArgumentKey{0}; + + SUBCASE( + "succeeds if a positional argument is set to a valid choice") { + std::vector inputs = {"prog_name", "blue"}; + + tl::expected result = + cli_parse(cli, inputs); + tl::expected correct = CLIParseResult{ + {}, + { + {key_posarg, "red"}, + }, + }; + } + + SUBCASE( + "fails if a positional argument is set to an invalid choice") { + std::vector inputs = {"prog_name", " red"}; + + tl::expected result = + cli_parse(cli, inputs); + tl::expected correct = tl::unexpected( + "Invalid option for positional argument \"posarg\": \" red\""); + + CHECK(result == correct); + } + } + + SUBCASE("if choices is empty, rejects everything") { + CLISpec cli = CLISpec{ + {}, + { + CLIPositionalArgumentSpec{ + "posarg", + std::vector{}, + std::nullopt, + }, + }, + }; + + std::vector inputs = {"prog_name", ""}; + + tl::expected result = + cli_parse(cli, inputs); + tl::expected correct = tl::unexpected( + "Invalid option for positional argument \"posarg\": \"\""); + + CHECK(result == correct); + } + } + } + + SUBCASE("correctly differentiates mixed arguments/flags") { + CLISpec cli = CLISpec{ + { + CLIFlagSpec{ + "flag1", + 'f', + std::nullopt, + }, + CLIFlagSpec{ + "flag2", + std::nullopt, + std::nullopt, + }, + CLIFlagSpec{ + "flag3", + 'a', + std::nullopt, + }, + }, + { + CLIPositionalArgumentSpec{ + "posarg1", + std::vector{"red", "blue", "green"}, + std::nullopt, + }, + CLIPositionalArgumentSpec{ + "posarg2", + std::nullopt, + std::nullopt, + }, + }, + }; + CLIFlagKey key_flag1 = CLIFlagKey{0}; + CLIFlagKey key_flag2 = CLIFlagKey{1}; + CLIFlagKey key_flag3 = CLIFlagKey{2}; + CLIPositionalArgumentKey key_posarg1 = CLIPositionalArgumentKey{0}; + CLIPositionalArgumentKey key_posarg2 = CLIPositionalArgumentKey{1}; + + SUBCASE("works if flags are before positional arguments") { + std::vector inputs = { + "prog_name", "-f", "--flag3", "red", "world"}; + + tl::expected result = + cli_parse(cli, inputs); + tl::expected correct = CLIParseResult{ + { + {key_flag1, true}, + {key_flag2, false}, + {key_flag3, true}, + }, + { + {key_posarg1, "red"}, + {key_posarg2, "world"}, + }, + }; + + CHECK(result == correct); + } + + SUBCASE("works if flags are interspersed") { + std::vector inputs = { + "prog_name", "red", "-f", "world", "--flag3"}; + + tl::expected result = + cli_parse(cli, inputs); + tl::expected correct = CLIParseResult{ + { + {key_flag1, true}, + {key_flag2, false}, + {key_flag3, true}, + }, + { + {key_posarg1, "red"}, + {key_posarg2, "world"}, + }, + }; + + CHECK(result == correct); + } + + SUBCASE("detects if posargs are missing instead of treating flags as " + "posarg values") { + std::vector inputs = {"prog_name", "-f", "red", "--flag2"}; + + tl::expected result = + cli_parse(cli, inputs); + tl::expected correct = tl::unexpected( + "Not enough positional arguments: found 1, expected 2"); + + CHECK(result == correct); + } + } + } + + TEST_CASE("cli_parse(CLISpec, int argc, char const * const *argv)") { + // most cases are checked in the other overload, + // i.e., cli_parse(CLISpec, std::vector), + // so here we just throw in a single check to make sure + // nothing has unexpectedly gone wrong + CLISpec cli = CLISpec{ + { + CLIFlagSpec{ + "flag1", + 'f', + std::nullopt, + }, + CLIFlagSpec{ + "flag2", + std::nullopt, + std::nullopt, + }, + CLIFlagSpec{ + "flag3", + 'a', + std::nullopt, + }, + }, + { + CLIPositionalArgumentSpec{ + "posarg1", + std::vector{"red", "blue", "green"}, + std::nullopt, + }, + CLIPositionalArgumentSpec{ + "posarg2", + std::nullopt, + std::nullopt, + }, + }, + }; + CLIFlagKey key_flag1 = CLIFlagKey{0}; + CLIFlagKey key_flag2 = CLIFlagKey{1}; + CLIFlagKey key_flag3 = CLIFlagKey{2}; + CLIPositionalArgumentKey key_posarg1 = CLIPositionalArgumentKey{0}; + CLIPositionalArgumentKey key_posarg2 = CLIPositionalArgumentKey{1}; + + int argc = 5; + char const *argv[] = {"prog_name", "red", "-f", "world", "--flag3"}; + + tl::expected result = + cli_parse(cli, argc, argv); + tl::expected correct = CLIParseResult{ + { + {key_flag1, true}, + {key_flag2, false}, + {key_flag3, true}, + }, + { + {key_posarg1, "red"}, + {key_posarg2, "world"}, + }, + }; + + CHECK(result == correct); + } +} diff --git a/lib/utils/test/src/utils/containers/contains_key.cc b/lib/utils/test/src/utils/containers/contains_key.cc index acc6551cd4..da099113a6 100644 --- a/lib/utils/test/src/utils/containers/contains_key.cc +++ b/lib/utils/test/src/utils/containers/contains_key.cc @@ -1,8 +1,11 @@ #include "utils/containers/contains_key.h" -#include "test/utils/doctest.h" +#include #include +#include #include +using namespace ::FlexFlow; + TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("contains_key(std::unordered_map, K)") { std::unordered_map m = { diff --git a/lib/utils/test/src/utils/containers/enumerate.cc b/lib/utils/test/src/utils/containers/enumerate.cc index 2be5f1ef93..c6ce9942e9 100644 --- a/lib/utils/test/src/utils/containers/enumerate.cc +++ b/lib/utils/test/src/utils/containers/enumerate.cc @@ -1,8 +1,12 @@ #include "utils/containers/enumerate.h" -#include "utils/containers/as_vector.h" -#include "utils/fmt/map.h" -#include "utils/fmt/pair.h" -#include "utils/fmt/vector.h" +#include "test/utils/doctest/fmt/map.h" +#include "test/utils/doctest/fmt/pair.h" +#include "test/utils/doctest/fmt/unordered_set.h" +#include "test/utils/doctest/fmt/vector.h" +#include "utils/containers/keys.h" +#include "utils/containers/unordered_set_of.h" +#include "utils/containers/values.h" +#include "utils/containers/vector_of.h" #include #include @@ -25,7 +29,7 @@ TEST_SUITE(FF_TEST_SUITE) { SUBCASE("check iteration order") { std::vector> iterated_result = - as_vector(result); + vector_of(result); std::vector> correct_iteration_order = { {0, "zero"}, {1, "one"}, @@ -46,5 +50,17 @@ TEST_SUITE(FF_TEST_SUITE) { {2, "two"}, {3, "three"}, }; + + std::map result = enumerate(input); + + std::unordered_set result_keys = keys(correct); + std::unordered_set result_values = + unordered_set_of(values(correct)); + + std::unordered_set correct_keys = {0, 1, 2, 3}; + std::unordered_set correct_values = input; + + CHECK(result_keys == correct_keys); + CHECK(result_values == correct_values); } } diff --git a/lib/utils/test/src/utils/containers/extend.cc b/lib/utils/test/src/utils/containers/extend.cc index e0d156a3fc..ef2a67725c 100644 --- a/lib/utils/test/src/utils/containers/extend.cc +++ b/lib/utils/test/src/utils/containers/extend.cc @@ -1,6 +1,6 @@ #include "utils/containers/extend.h" -#include "utils/fmt/unordered_set.h" -#include "utils/fmt/vector.h" +#include "test/utils/doctest/fmt/unordered_set.h" +#include "test/utils/doctest/fmt/vector.h" #include using namespace ::FlexFlow; diff --git a/lib/utils/test/src/utils/containers/filter.cc b/lib/utils/test/src/utils/containers/filter.cc index da459094ef..770ad40375 100644 --- a/lib/utils/test/src/utils/containers/filter.cc +++ b/lib/utils/test/src/utils/containers/filter.cc @@ -1,10 +1,10 @@ #include "utils/containers/filter.h" -#include "test/utils/all.h" -#include "utils/fmt/map.h" -#include "utils/fmt/set.h" -#include "utils/fmt/unordered_map.h" -#include "utils/fmt/unordered_set.h" -#include "utils/fmt/vector.h" +#include "test/utils/doctest/fmt/map.h" +#include "test/utils/doctest/fmt/set.h" +#include "test/utils/doctest/fmt/unordered_map.h" +#include "test/utils/doctest/fmt/unordered_set.h" +#include "test/utils/doctest/fmt/vector.h" +#include "test/utils/rapidcheck.h" using namespace ::FlexFlow; diff --git a/lib/utils/test/src/utils/containers/filtermap_keys.cc b/lib/utils/test/src/utils/containers/filtermap_keys.cc index 758264627b..582e94392b 100644 --- a/lib/utils/test/src/utils/containers/filtermap_keys.cc +++ b/lib/utils/test/src/utils/containers/filtermap_keys.cc @@ -1,7 +1,7 @@ #include "utils/containers/filtermap_keys.h" -#include "test/utils/doctest.h" -#include "utils/fmt/map.h" -#include "utils/fmt/unordered_map.h" +#include "test/utils/doctest/fmt/map.h" +#include "test/utils/doctest/fmt/unordered_map.h" +#include using namespace FlexFlow; diff --git a/lib/utils/test/src/utils/containers/filtermap_values.cc b/lib/utils/test/src/utils/containers/filtermap_values.cc index d2b6ddd220..8db6d6a964 100644 --- a/lib/utils/test/src/utils/containers/filtermap_values.cc +++ b/lib/utils/test/src/utils/containers/filtermap_values.cc @@ -1,7 +1,7 @@ #include "utils/containers/filtermap_values.h" -#include "test/utils/doctest.h" -#include "utils/fmt/map.h" -#include "utils/fmt/unordered_map.h" +#include "test/utils/doctest/fmt/map.h" +#include "test/utils/doctest/fmt/unordered_map.h" +#include using namespace FlexFlow; diff --git a/lib/utils/test/src/utils/containers/filtrans.cc b/lib/utils/test/src/utils/containers/filtrans.cc index b8bb832b06..cd1c2f896c 100644 --- a/lib/utils/test/src/utils/containers/filtrans.cc +++ b/lib/utils/test/src/utils/containers/filtrans.cc @@ -1,7 +1,7 @@ #include "utils/containers/filtrans.h" -#include "utils/fmt/set.h" -#include "utils/fmt/unordered_set.h" -#include "utils/fmt/vector.h" +#include "test/utils/doctest/fmt/set.h" +#include "test/utils/doctest/fmt/unordered_set.h" +#include "test/utils/doctest/fmt/vector.h" #include using namespace ::FlexFlow; diff --git a/lib/utils/test/src/utils/containers/foldl1.cc b/lib/utils/test/src/utils/containers/foldl1.cc new file mode 100644 index 0000000000..597aa5e109 --- /dev/null +++ b/lib/utils/test/src/utils/containers/foldl1.cc @@ -0,0 +1,27 @@ +#include "utils/containers/foldl1.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("foldl1(std::vector, F)") { + auto concat = [](std::string const &accum, std::string const &s) { + return accum + s; + }; + + SUBCASE("empty input") { + std::vector input = {}; + CHECK_THROWS(foldl1(input, concat)); + } + + SUBCASE("non-empty input") { + std::vector input = {"a s", "tr", "ing"}; + + std::string result = foldl1(input, concat); + + std::string correct = "a string"; + + CHECK(result == correct); + } + } +} diff --git a/lib/utils/test/src/utils/containers/foldr1.cc b/lib/utils/test/src/utils/containers/foldr1.cc new file mode 100644 index 0000000000..3c9d9b66ae --- /dev/null +++ b/lib/utils/test/src/utils/containers/foldr1.cc @@ -0,0 +1,27 @@ +#include "utils/containers/foldr1.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("foldr1(std::vector, F)") { + auto concat = [](std::string const &accum, std::string const &s) { + return accum + s; + }; + + SUBCASE("empty input") { + std::vector input = {}; + CHECK_THROWS(foldr1(input, concat)); + } + + SUBCASE("non-empty input") { + std::vector input = {"ing", "tr", "a s"}; + + std::string result = foldr1(input, concat); + + std::string correct = "a string"; + + CHECK(result == correct); + } + } +} diff --git a/lib/utils/test/src/utils/containers/get_all_permutations.cc b/lib/utils/test/src/utils/containers/get_all_permutations.cc index 5f22266809..cc5edb4075 100644 --- a/lib/utils/test/src/utils/containers/get_all_permutations.cc +++ b/lib/utils/test/src/utils/containers/get_all_permutations.cc @@ -1,8 +1,7 @@ #include "utils/containers/get_all_permutations.h" -#include "utils/containers/as_vector.h" +#include "test/utils/doctest/fmt/unordered_multiset.h" +#include "test/utils/doctest/fmt/vector.h" #include "utils/containers/unordered_multiset_of.h" -#include "utils/fmt/unordered_multiset.h" -#include "utils/fmt/vector.h" #include "utils/hash/vector.h" #include diff --git a/lib/utils/test/src/utils/containers/get_element_counts.cc b/lib/utils/test/src/utils/containers/get_element_counts.cc index 11e2ef7e05..8fc87dba90 100644 --- a/lib/utils/test/src/utils/containers/get_element_counts.cc +++ b/lib/utils/test/src/utils/containers/get_element_counts.cc @@ -1,5 +1,5 @@ #include "utils/containers/get_element_counts.h" -#include "utils/fmt/unordered_map.h" +#include "test/utils/doctest/fmt/unordered_map.h" #include using namespace ::FlexFlow; diff --git a/lib/utils/test/src/utils/containers/inplace_filter.cc b/lib/utils/test/src/utils/containers/inplace_filter.cc index 7ef9d73339..ac430279b0 100644 --- a/lib/utils/test/src/utils/containers/inplace_filter.cc +++ b/lib/utils/test/src/utils/containers/inplace_filter.cc @@ -1,10 +1,11 @@ #include "utils/containers/inplace_filter.h" -#include "test/utils/all.h" -#include "utils/fmt/map.h" -#include "utils/fmt/set.h" -#include "utils/fmt/unordered_map.h" -#include "utils/fmt/unordered_set.h" -#include "utils/fmt/vector.h" +#include "test/utils/doctest/fmt/map.h" +#include "test/utils/doctest/fmt/set.h" +#include "test/utils/doctest/fmt/unordered_map.h" +#include "test/utils/doctest/fmt/unordered_set.h" +#include "test/utils/doctest/fmt/vector.h" +#include "test/utils/rapidcheck.h" +#include using namespace ::FlexFlow; diff --git a/lib/utils/test/src/utils/containers/intersection.cc b/lib/utils/test/src/utils/containers/intersection.cc index ac9acf5e2b..52de6ee6d3 100644 --- a/lib/utils/test/src/utils/containers/intersection.cc +++ b/lib/utils/test/src/utils/containers/intersection.cc @@ -1,6 +1,6 @@ #include "utils/containers/intersection.h" -#include "utils/fmt/optional.h" -#include "utils/fmt/unordered_set.h" +#include "test/utils/doctest/fmt/optional.h" +#include "test/utils/doctest/fmt/unordered_set.h" #include using namespace ::FlexFlow; diff --git a/lib/utils/test/src/utils/containers/maximum.cc b/lib/utils/test/src/utils/containers/maximum.cc new file mode 100644 index 0000000000..71e7395805 --- /dev/null +++ b/lib/utils/test/src/utils/containers/maximum.cc @@ -0,0 +1,60 @@ +#include "utils/containers/maximum.h" +#include "test/utils/doctest/fmt/multiset.h" +#include "test/utils/doctest/fmt/optional.h" +#include "test/utils/doctest/fmt/set.h" +#include "test/utils/doctest/fmt/unordered_multiset.h" +#include "test/utils/doctest/fmt/unordered_set.h" +#include "test/utils/doctest/fmt/vector.h" +#include +#include +#include +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE_TEMPLATE("maximum(T)", + T, + std::vector, + std::unordered_set, + std::unordered_multiset, + std::set, + std::multiset) { + + SUBCASE("input is empty") { + T input = {}; + + std::optional result = maximum(input); + std::optional correct = std::nullopt; + + CHECK(result == correct); + } + + SUBCASE("input does not have duplicates") { + T input = {1, 3, 2}; + + std::optional result = maximum(input); + std::optional correct = 3; + + CHECK(result == correct); + } + + SUBCASE("input has duplicates") { + T input = {1, 2, 2, 0}; + + std::optional result = maximum(input); + std::optional correct = 2; + + CHECK(result == correct); + } + } + + TEST_CASE("maximum(std::vector)") { + std::vector input = {"hello", "world"}; + + std::optional result = maximum(input); + std::optional correct = "world"; + + CHECK(result == correct); + } +} diff --git a/lib/utils/test/src/utils/containers/multiset_union.cc b/lib/utils/test/src/utils/containers/multiset_union.cc new file mode 100644 index 0000000000..8c40bf55ab --- /dev/null +++ b/lib/utils/test/src/utils/containers/multiset_union.cc @@ -0,0 +1,29 @@ +#include "utils/containers/multiset_union.h" +#include "test/utils/doctest/fmt/multiset.h" +#include "test/utils/doctest/fmt/unordered_multiset.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("multiset_union(std::unordered_multiset, " + "std::unordered_multiset)") { + std::unordered_multiset input_lhs = {1, 2, 2, 3}; + std::unordered_multiset input_rhs = {1, 2, 5}; + + std::unordered_multiset result = multiset_union(input_lhs, input_rhs); + std::unordered_multiset correct = {1, 1, 2, 2, 2, 3, 5}; + + CHECK(result == correct); + } + + TEST_CASE("multiset_union(std::multiset, std::multiset)") { + std::multiset input_lhs = {1, 2, 2, 3}; + std::multiset input_rhs = {1, 2, 5}; + + std::multiset result = multiset_union(input_lhs, input_rhs); + std::multiset correct = {1, 1, 2, 2, 2, 3, 5}; + + CHECK(result == correct); + } +} diff --git a/lib/utils/test/src/utils/containers/repeat.cc b/lib/utils/test/src/utils/containers/repeat.cc index 50e4b3e7c5..d8ffe76a64 100644 --- a/lib/utils/test/src/utils/containers/repeat.cc +++ b/lib/utils/test/src/utils/containers/repeat.cc @@ -1,5 +1,5 @@ #include "utils/containers/repeat.h" -#include "utils/fmt/vector.h" +#include "test/utils/doctest/fmt/vector.h" #include using namespace ::FlexFlow; diff --git a/lib/utils/test/src/utils/containers/require_no_duplicates.cc b/lib/utils/test/src/utils/containers/require_no_duplicates.cc new file mode 100644 index 0000000000..67733d791a --- /dev/null +++ b/lib/utils/test/src/utils/containers/require_no_duplicates.cc @@ -0,0 +1,62 @@ +#include "utils/containers/require_no_duplicates.h" +#include "test/utils/doctest/fmt/multiset.h" +#include "test/utils/doctest/fmt/set.h" +#include "test/utils/doctest/fmt/unordered_multiset.h" +#include "test/utils/doctest/fmt/unordered_set.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("require_no_duplicates(std::unordered_multiset)") { + SUBCASE("empty") { + std::unordered_multiset input = {}; + + std::unordered_set result = require_no_duplicates(input); + std::unordered_set correct = {}; + + CHECK(result == correct); + } + + SUBCASE("input has duplicates") { + std::unordered_multiset input = {1, 2, 2}; + + CHECK_THROWS(require_no_duplicates(input)); + } + + SUBCASE("input does not have duplicates") { + std::unordered_multiset input = {1, 2, 4}; + + std::unordered_set result = require_no_duplicates(input); + std::unordered_set correct = {1, 2, 4}; + + CHECK(result == correct); + } + } + + TEST_CASE("require_no_duplicates(std::multiset)") { + SUBCASE("empty") { + std::multiset input = {}; + + std::set result = require_no_duplicates(input); + std::set correct = {}; + + CHECK(result == correct); + } + + SUBCASE("input has duplicates") { + std::multiset input = {1, 2, 2}; + + CHECK_THROWS(require_no_duplicates(input)); + } + + SUBCASE("input does not have duplicates") { + std::multiset input = {1, 2, 4}; + + std::set result = require_no_duplicates(input); + std::set correct = {1, 2, 4}; + + CHECK(result == correct); + } + } +} diff --git a/lib/utils/test/src/utils/containers/reversed.cc b/lib/utils/test/src/utils/containers/reversed.cc new file mode 100644 index 0000000000..834a497152 --- /dev/null +++ b/lib/utils/test/src/utils/containers/reversed.cc @@ -0,0 +1,27 @@ +#include "utils/containers/reversed.h" +#include "test/utils/doctest/fmt/vector.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("reversed(std::vector)") { + SUBCASE("non-empty input") { + std::vector input = {1, 2, 3, 2}; + + std::vector result = reversed(input); + std::vector correct = {2, 3, 2, 1}; + + CHECK(result == correct); + } + + SUBCASE("empty input") { + std::vector input = {}; + + std::vector result = reversed(input); + std::vector correct = {}; + + CHECK(result == correct); + } + } +} diff --git a/lib/utils/test/src/utils/containers/to_uppercase.cc b/lib/utils/test/src/utils/containers/to_uppercase.cc new file mode 100644 index 0000000000..9729307304 --- /dev/null +++ b/lib/utils/test/src/utils/containers/to_uppercase.cc @@ -0,0 +1,15 @@ +#include "utils/containers/to_uppercase.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("to_uppercase(std::string)") { + std::string input = "Hello World"; + + std::string result = to_uppercase(input); + std::string correct = "HELLO WORLD"; + + CHECK(result == correct); + } +} diff --git a/lib/utils/test/src/utils/containers/transform.cc b/lib/utils/test/src/utils/containers/transform.cc index 916bc20928..3122c67117 100644 --- a/lib/utils/test/src/utils/containers/transform.cc +++ b/lib/utils/test/src/utils/containers/transform.cc @@ -1,7 +1,7 @@ #include "utils/containers/transform.h" -#include "utils/fmt/optional.h" -#include "utils/fmt/unordered_set.h" -#include "utils/fmt/vector.h" +#include "test/utils/doctest/fmt/optional.h" +#include "test/utils/doctest/fmt/unordered_set.h" +#include "test/utils/doctest/fmt/vector.h" #include using namespace ::FlexFlow; diff --git a/lib/utils/test/src/utils/containers/try_merge_nondisjoint_unordered_maps.cc b/lib/utils/test/src/utils/containers/try_merge_nondisjoint_unordered_maps.cc index 6aeab4ae6e..b8a7a85f74 100644 --- a/lib/utils/test/src/utils/containers/try_merge_nondisjoint_unordered_maps.cc +++ b/lib/utils/test/src/utils/containers/try_merge_nondisjoint_unordered_maps.cc @@ -1,7 +1,7 @@ #include "utils/containers/try_merge_nondisjoint_unordered_maps.h" -#include "test/utils/doctest.h" -#include "utils/fmt/optional.h" -#include "utils/fmt/unordered_map.h" +#include "test/utils/doctest/fmt/optional.h" +#include "test/utils/doctest/fmt/unordered_map.h" +#include using namespace ::FlexFlow; diff --git a/lib/utils/test/src/utils/containers/unordered_multiset_of.cc b/lib/utils/test/src/utils/containers/unordered_multiset_of.cc index 0ab0ef1446..becb7fdce0 100644 --- a/lib/utils/test/src/utils/containers/unordered_multiset_of.cc +++ b/lib/utils/test/src/utils/containers/unordered_multiset_of.cc @@ -1,5 +1,5 @@ #include "utils/containers/unordered_multiset_of.h" -#include "utils/fmt/unordered_multiset.h" +#include "test/utils/doctest/fmt/unordered_multiset.h" #include #include diff --git a/lib/utils/test/src/utils/containers/unordered_set_of.cc b/lib/utils/test/src/utils/containers/unordered_set_of.cc index d42b41dd50..b8ca1d1797 100644 --- a/lib/utils/test/src/utils/containers/unordered_set_of.cc +++ b/lib/utils/test/src/utils/containers/unordered_set_of.cc @@ -1,5 +1,5 @@ #include "utils/containers/unordered_set_of.h" -#include "utils/fmt/unordered_set.h" +#include "test/utils/doctest/fmt/unordered_set.h" #include #include diff --git a/lib/utils/test/src/utils/containers/vector_of.cc b/lib/utils/test/src/utils/containers/vector_of.cc new file mode 100644 index 0000000000..8b9353e1b0 --- /dev/null +++ b/lib/utils/test/src/utils/containers/vector_of.cc @@ -0,0 +1,17 @@ +#include "utils/containers/vector_of.h" +#include "test/utils/doctest/fmt/vector.h" +#include +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("vector_of(std::set)") { + std::set input = {2, 3, 1, 4}; + + std::vector result = vector_of(input); + std::vector correct = {1, 2, 3, 4}; + + CHECK(result == correct); + } +} diff --git a/lib/utils/test/src/utils/containers/without_order.cc b/lib/utils/test/src/utils/containers/without_order.cc index 939c6ff108..b4c8663b14 100644 --- a/lib/utils/test/src/utils/containers/without_order.cc +++ b/lib/utils/test/src/utils/containers/without_order.cc @@ -1,5 +1,5 @@ #include "utils/containers/without_order.h" -#include "utils/fmt/unordered_multiset.h" +#include "test/utils/doctest/fmt/unordered_multiset.h" #include #include diff --git a/lib/utils/test/src/utils/expected.cc b/lib/utils/test/src/utils/expected.cc index 14679e0d13..3e5de13d49 100644 --- a/lib/utils/test/src/utils/expected.cc +++ b/lib/utils/test/src/utils/expected.cc @@ -1,6 +1,6 @@ #include "utils/expected.h" -#include "utils/fmt/expected.h" -#include "utils/fmt/optional.h" +#include "test/utils/doctest/fmt/expected.h" +#include "test/utils/doctest/fmt/optional.h" #include using namespace ::FlexFlow; diff --git a/lib/utils/test/src/utils/fmt/expected.cc b/lib/utils/test/src/utils/fmt/expected.cc index fb39732761..48df8634db 100644 --- a/lib/utils/test/src/utils/fmt/expected.cc +++ b/lib/utils/test/src/utils/fmt/expected.cc @@ -1,5 +1,6 @@ #include "utils/fmt/expected.h" -#include "test/utils/doctest.h" +#include "test/utils/doctest/fmt/expected.h" +#include using namespace ::FlexFlow; @@ -19,24 +20,4 @@ TEST_SUITE(FF_TEST_SUITE) { CHECK(result == correct); } } - - TEST_CASE("doctest::toString(tl::expected)") { - SUBCASE("has expected") { - tl::expected input = 3; - - doctest::String result = doctest::toString(input); - doctest::String correct = "expected(3)"; - - CHECK(result == correct); - } - - SUBCASE("has unexpected") { - tl::expected input = tl::make_unexpected("error"); - - doctest::String result = doctest::toString(input); - doctest::String correct = "unexpected(error)"; - - CHECK(result == correct); - } - } } diff --git a/lib/utils/test/src/utils/fmt/map.cc b/lib/utils/test/src/utils/fmt/map.cc index b65b4791ea..19f3a7d5cf 100644 --- a/lib/utils/test/src/utils/fmt/map.cc +++ b/lib/utils/test/src/utils/fmt/map.cc @@ -1,5 +1,5 @@ #include "utils/fmt/map.h" -#include "test/utils/doctest.h" +#include using namespace ::FlexFlow; diff --git a/lib/utils/test/src/utils/fmt/optional.cc b/lib/utils/test/src/utils/fmt/optional.cc index e7815a26ac..1cd79da747 100644 --- a/lib/utils/test/src/utils/fmt/optional.cc +++ b/lib/utils/test/src/utils/fmt/optional.cc @@ -1,5 +1,5 @@ #include "utils/fmt/optional.h" -#include "test/utils/doctest.h" +#include using namespace ::FlexFlow; diff --git a/lib/utils/test/src/utils/fmt/pair.cc b/lib/utils/test/src/utils/fmt/pair.cc index 3d7cc78756..e848eb08c7 100644 --- a/lib/utils/test/src/utils/fmt/pair.cc +++ b/lib/utils/test/src/utils/fmt/pair.cc @@ -1,5 +1,5 @@ #include "utils/fmt/pair.h" -#include "test/utils/doctest.h" +#include using namespace ::FlexFlow; diff --git a/lib/utils/test/src/utils/fmt/set.cc b/lib/utils/test/src/utils/fmt/set.cc index 66824f2b2a..e317954b02 100644 --- a/lib/utils/test/src/utils/fmt/set.cc +++ b/lib/utils/test/src/utils/fmt/set.cc @@ -1,5 +1,5 @@ #include "utils/fmt/set.h" -#include "test/utils/doctest.h" +#include using namespace ::FlexFlow; diff --git a/lib/utils/test/src/utils/fmt/unordered_map.cc b/lib/utils/test/src/utils/fmt/unordered_map.cc index 99752d73f4..c980bc1e52 100644 --- a/lib/utils/test/src/utils/fmt/unordered_map.cc +++ b/lib/utils/test/src/utils/fmt/unordered_map.cc @@ -1,6 +1,7 @@ #include "utils/fmt/unordered_map.h" -#include "test/utils/doctest.h" +#include "test/utils/doctest/fmt/unordered_map.h" #include "utils/containers/get_element_counts.h" +#include using namespace ::FlexFlow; diff --git a/lib/utils/test/src/utils/fmt/unordered_set.cc b/lib/utils/test/src/utils/fmt/unordered_set.cc index 9dc8d236f1..f492ea844d 100644 --- a/lib/utils/test/src/utils/fmt/unordered_set.cc +++ b/lib/utils/test/src/utils/fmt/unordered_set.cc @@ -1,7 +1,7 @@ #include "utils/fmt/unordered_set.h" -#include "test/utils/doctest.h" +#include "test/utils/doctest/fmt/unordered_multiset.h" #include "utils/containers/unordered_multiset_of.h" -#include "utils/fmt/unordered_multiset.h" +#include using namespace ::FlexFlow; diff --git a/lib/utils/test/src/utils/fmt/variant.cc b/lib/utils/test/src/utils/fmt/variant.cc index 3ada166de9..0c8dca35d7 100644 --- a/lib/utils/test/src/utils/fmt/variant.cc +++ b/lib/utils/test/src/utils/fmt/variant.cc @@ -1,5 +1,5 @@ #include "utils/fmt/variant.h" -#include "test/utils/doctest.h" +#include using namespace ::FlexFlow; diff --git a/lib/utils/test/src/utils/fmt/vector.cc b/lib/utils/test/src/utils/fmt/vector.cc index fee3eb34a5..91ef6c9efc 100644 --- a/lib/utils/test/src/utils/fmt/vector.cc +++ b/lib/utils/test/src/utils/fmt/vector.cc @@ -1,5 +1,5 @@ #include "utils/fmt/vector.h" -#include "test/utils/doctest.h" +#include using namespace ::FlexFlow; diff --git a/lib/utils/test/src/utils/graph/cow_ptr_t.cc b/lib/utils/test/src/utils/graph/cow_ptr_t.cc index 65088c19de..e6a6f9661e 100644 --- a/lib/utils/test/src/utils/graph/cow_ptr_t.cc +++ b/lib/utils/test/src/utils/graph/cow_ptr_t.cc @@ -1,5 +1,5 @@ #include "utils/graph/cow_ptr_t.h" -#include "test/utils/doctest.h" +#include #include #include #include diff --git a/lib/utils/test/src/utils/graph/dataflow_graph/algorithms/get_subgraph_incoming_edges.cc b/lib/utils/test/src/utils/graph/dataflow_graph/algorithms/get_subgraph_incoming_edges.cc new file mode 100644 index 0000000000..330628adfd --- /dev/null +++ b/lib/utils/test/src/utils/graph/dataflow_graph/algorithms/get_subgraph_incoming_edges.cc @@ -0,0 +1,43 @@ +#include "utils/graph/dataflow_graph/algorithms/get_subgraph_incoming_edges.h" +#include "utils/containers/get_only.h" +#include "utils/graph/dataflow_graph/dataflow_graph.h" +#include "utils/graph/instances/unordered_set_dataflow_graph.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("get_subgraph_incoming_edges(DataflowGraphView, " + "std::unordered_set") { + DataflowGraph g = DataflowGraph::create(); + + NodeAddedResult n1_added = g.add_node({}, 1); + Node n1 = n1_added.node; + DataflowOutput o1 = get_only(n1_added.outputs); + + NodeAddedResult n2_added = g.add_node({o1}, 1); + Node n2 = n2_added.node; + DataflowOutput o2 = get_only(n2_added.outputs); + + NodeAddedResult n3_added = g.add_node({o1, o2, o1}, 1); + Node n3 = n3_added.node; + DataflowOutput o3 = get_only(n3_added.outputs); + + NodeAddedResult n4_added = g.add_node({o2, o3}, 1); + Node n4 = n4_added.node; + DataflowOutput o4 = get_only(n4_added.outputs); + + std::unordered_set input_node_set = {n2, n3}; + + std::unordered_set result = + get_subgraph_incoming_edges(g, input_node_set); + + std::unordered_set correct = { + DataflowEdge{o1, DataflowInput{n2, 0}}, + DataflowEdge{o1, DataflowInput{n3, 0}}, + DataflowEdge{o1, DataflowInput{n3, 2}}, + }; + + CHECK(result == correct); + } +} diff --git a/lib/utils/test/src/utils/graph/dataflow_graph/algorithms/get_subgraph_outgoing_edges.cc b/lib/utils/test/src/utils/graph/dataflow_graph/algorithms/get_subgraph_outgoing_edges.cc index 7e02686dde..779d0a9560 100644 --- a/lib/utils/test/src/utils/graph/dataflow_graph/algorithms/get_subgraph_outgoing_edges.cc +++ b/lib/utils/test/src/utils/graph/dataflow_graph/algorithms/get_subgraph_outgoing_edges.cc @@ -7,7 +7,8 @@ using namespace ::FlexFlow; TEST_SUITE(FF_TEST_SUITE) { - TEST_CASE("get_outgoing_edges(DataflowGraphView, std::unordered_set") { + TEST_CASE("get_subgraph_outgoing_edges(DataflowGraphView, " + "std::unordered_set") { DataflowGraph g = DataflowGraph::create(); NodeAddedResult n1_added = g.add_node({}, 1); diff --git a/lib/utils/test/src/utils/graph/dataflow_graph/unordered_open_dataflow_graph.cc b/lib/utils/test/src/utils/graph/dataflow_graph/unordered_open_dataflow_graph.cc index cfc912af6b..7a3237d432 100644 --- a/lib/utils/test/src/utils/graph/dataflow_graph/unordered_open_dataflow_graph.cc +++ b/lib/utils/test/src/utils/graph/dataflow_graph/unordered_open_dataflow_graph.cc @@ -1,9 +1,11 @@ -#include "test/utils/doctest.h" #include "utils/graph/dataflow_graph/dataflow_edge_query.h" #include "utils/graph/dataflow_graph/dataflow_graph.h" #include "utils/graph/dataflow_graph/dataflow_output_query.h" #include "utils/graph/instances/unordered_set_dataflow_graph.h" #include "utils/graph/node/node_query.h" +#include + +using namespace ::FlexFlow; TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("UnorderedSetDataflowGraph") { diff --git a/lib/utils/test/src/utils/graph/digraph/algorithms/complete_bipartite_composite/get_cbc_decomposition.cc b/lib/utils/test/src/utils/graph/digraph/algorithms/complete_bipartite_composite/get_cbc_decomposition.cc index 2ebfe232b6..eca7aa6c79 100644 --- a/lib/utils/test/src/utils/graph/digraph/algorithms/complete_bipartite_composite/get_cbc_decomposition.cc +++ b/lib/utils/test/src/utils/graph/digraph/algorithms/complete_bipartite_composite/get_cbc_decomposition.cc @@ -1,5 +1,8 @@ #include "utils/graph/digraph/algorithms/complete_bipartite_composite/get_cbc_decomposition.h" +#include "utils/containers/reversed.h" +#include "utils/containers/vector_of.h" #include "utils/graph/algorithms.h" +#include "utils/graph/digraph/algorithms/transitive_reduction.h" #include "utils/graph/instances/adjacency_digraph.h" #include @@ -9,6 +12,25 @@ TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("get_cbc_decomposition") { DiGraph g = DiGraph::create(); + // used to check that the cbc decomposition result is the same regardless + // of the order in which the graph edges are processed, as this is a + // property that should hold, and violations of this property have been a + // source of bugs in the past + auto check_cbc_decomposition_is_edge_order_invariant = + [](DiGraphView const &g) { + std::unordered_set edges = get_edges(g); + + std::vector edge_order1 = vector_of(edges); + std::vector edge_order2 = reversed(edge_order1); + + std::optional result1 = + get_cbc_decomposition_with_edge_order_internal(g, edge_order1); + std::optional result2 = + get_cbc_decomposition_with_edge_order_internal(g, edge_order2); + + CHECK(result1 == result2); + }; + SUBCASE("six-node diamond graph") { std::vector n = add_nodes(g, 6); add_edges(g, @@ -32,6 +54,8 @@ TEST_SUITE(FF_TEST_SUITE) { }}; CHECK(result == correct); + + check_cbc_decomposition_is_edge_order_invariant(g); } SUBCASE("graph without any edges") { @@ -43,6 +67,27 @@ TEST_SUITE(FF_TEST_SUITE) { CompleteBipartiteCompositeDecomposition{{}}; CHECK(result == correct); + + check_cbc_decomposition_is_edge_order_invariant(g); + } + + SUBCASE("irreducible n-graph (non-cbc graph)") { + std::vector n = add_nodes(g, 4); + add_edges(g, + { + DirectedEdge{n.at(0), n.at(2)}, + DirectedEdge{n.at(1), n.at(2)}, + DirectedEdge{n.at(1), n.at(3)}, + }); + + std::optional result = + get_cbc_decomposition(g); + std::optional correct = + std::nullopt; + + CHECK(result == correct); + + check_cbc_decomposition_is_edge_order_invariant(g); } } } diff --git a/lib/utils/test/src/utils/graph/digraph/algorithms/complete_bipartite_composite/is_complete_bipartite_graph.cc b/lib/utils/test/src/utils/graph/digraph/algorithms/complete_bipartite_composite/is_complete_bipartite_graph.cc new file mode 100644 index 0000000000..17c8b8da27 --- /dev/null +++ b/lib/utils/test/src/utils/graph/digraph/algorithms/complete_bipartite_composite/is_complete_bipartite_graph.cc @@ -0,0 +1,175 @@ +#include "utils/graph/algorithms.h" +#include "utils/graph/digraph/algorithms/complete_bipartite_composite/is_complete_bipartite_digraph.h" +#include "utils/graph/digraph/digraph.h" +#include "utils/graph/instances/adjacency_digraph.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("is_complete_bipartite_digraph(UndirectedGraphView, " + "std::unordered_set)") { + DiGraph g = DiGraph::create(); + + SUBCASE("simple bipartite graph") { + std::vector n = add_nodes(g, 5); + + add_edges(g, + { + DirectedEdge{n.at(0), n.at(3)}, + DirectedEdge{n.at(0), n.at(4)}, + DirectedEdge{n.at(1), n.at(3)}, + DirectedEdge{n.at(1), n.at(4)}, + DirectedEdge{n.at(2), n.at(3)}, + DirectedEdge{n.at(2), n.at(4)}, + }); + + SUBCASE("source group") { + std::unordered_set group1 = {n.at(0), n.at(1), n.at(2)}; + + bool result = is_complete_bipartite_digraph(g, group1); + bool correct = true; + + CHECK(result == correct); + } + + SUBCASE("sink group") { + std::unordered_set group1 = {n.at(3), n.at(4)}; + + bool result = is_complete_bipartite_digraph(g, group1); + bool correct = false; + + CHECK(result == correct); + } + } + + SUBCASE("missing an edge (i.e., not complete)") { + std::vector n = add_nodes(g, 4); + + add_edges(g, + { + DirectedEdge{n.at(0), n.at(2)}, + DirectedEdge{n.at(0), n.at(3)}, + DirectedEdge{n.at(1), n.at(3)}, + }); + std::unordered_set group1 = {n.at(0), n.at(1)}; + + bool result = is_complete_bipartite_digraph(g, group1); + bool correct = false; + + CHECK(result == correct); + } + + SUBCASE("extra edge (i.e., not bipartite)") { + std::vector n = add_nodes(g, 4); + + add_edges(g, + { + DirectedEdge{n.at(0), n.at(2)}, + DirectedEdge{n.at(0), n.at(3)}, + DirectedEdge{n.at(1), n.at(2)}, + DirectedEdge{n.at(1), n.at(3)}, + DirectedEdge{n.at(2), n.at(3)}, + }); + std::unordered_set group1 = {n.at(0), n.at(1)}; + + bool result = is_complete_bipartite_digraph(g, group1); + bool correct = false; + + CHECK(result == correct); + } + + SUBCASE("flipped edge") { + std::vector n = add_nodes(g, 4); + + add_edges(g, + { + DirectedEdge{n.at(0), n.at(2)}, + DirectedEdge{n.at(0), n.at(3)}, + DirectedEdge{n.at(2), n.at(1)}, + DirectedEdge{n.at(1), n.at(3)}, + }); + std::unordered_set group1 = {n.at(0), n.at(1)}; + + bool result = is_complete_bipartite_digraph(g, group1); + bool correct = false; + + CHECK(result == correct); + } + + SUBCASE("group too small") { + std::vector n = add_nodes(g, 4); + + add_edges(g, + { + DirectedEdge{n.at(0), n.at(2)}, + DirectedEdge{n.at(0), n.at(3)}, + DirectedEdge{n.at(1), n.at(2)}, + DirectedEdge{n.at(1), n.at(3)}, + }); + std::unordered_set group1 = {n.at(0)}; + + bool result = is_complete_bipartite_digraph(g, group1); + bool correct = false; + + CHECK(result == correct); + } + } + + TEST_CASE("is_complete_bipartite_digraph(UndirectedGraphView)") { + DiGraph g = DiGraph::create(); + + SUBCASE("simple bipartite graph") { + std::vector n = add_nodes(g, 5); + + add_edges(g, + { + DirectedEdge{n.at(0), n.at(3)}, + DirectedEdge{n.at(0), n.at(4)}, + DirectedEdge{n.at(1), n.at(3)}, + DirectedEdge{n.at(1), n.at(4)}, + DirectedEdge{n.at(2), n.at(3)}, + DirectedEdge{n.at(2), n.at(4)}, + }); + + bool result = is_complete_bipartite_digraph(g); + bool correct = true; + + CHECK(result == correct); + } + + SUBCASE("missing an edge") { + std::vector n = add_nodes(g, 4); + + add_edges(g, + { + DirectedEdge{n.at(0), n.at(2)}, + DirectedEdge{n.at(0), n.at(3)}, + DirectedEdge{n.at(1), n.at(3)}, + }); + + bool result = is_complete_bipartite_digraph(g); + bool correct = false; + + CHECK(result == correct); + } + + SUBCASE("extra edge") { + std::vector n = add_nodes(g, 4); + + add_edges(g, + { + DirectedEdge{n.at(0), n.at(2)}, + DirectedEdge{n.at(0), n.at(3)}, + DirectedEdge{n.at(1), n.at(2)}, + DirectedEdge{n.at(1), n.at(3)}, + DirectedEdge{n.at(2), n.at(3)}, + }); + + bool result = is_complete_bipartite_digraph(g); + bool correct = false; + + CHECK(result == correct); + } + } +} diff --git a/lib/utils/test/src/utils/graph/digraph/algorithms/inverse_line_graph/get_inverse_line_graph.cc b/lib/utils/test/src/utils/graph/digraph/algorithms/inverse_line_graph/get_inverse_line_graph.cc index fd2f469f93..a635658755 100644 --- a/lib/utils/test/src/utils/graph/digraph/algorithms/inverse_line_graph/get_inverse_line_graph.cc +++ b/lib/utils/test/src/utils/graph/digraph/algorithms/inverse_line_graph/get_inverse_line_graph.cc @@ -4,6 +4,7 @@ #include "utils/containers/transform.h" #include "utils/graph/algorithms.h" #include "utils/graph/digraph/algorithms/get_successors.h" +#include "utils/graph/digraph/algorithms/transitive_reduction.h" #include "utils/graph/instances/adjacency_digraph.h" #include "utils/graph/multidigraph/algorithms/get_directed_edge.h" #include "utils/graph/multidigraph/algorithms/get_edge_counts.h" @@ -139,5 +140,27 @@ TEST_SUITE(FF_TEST_SUITE) { CHECK(result_bidict == correct_bidict); } } + + SUBCASE("sp n-graph (inverse line graph does not exist)") { + // Tests that the inverse line graph of the sp n-graph + // + // a-b + // \ + // c-d + // + // does not exist + std::vector n = add_nodes(g, 4); + add_edges(g, + { + DirectedEdge{n.at(0), n.at(2)}, + DirectedEdge{n.at(1), n.at(2)}, + DirectedEdge{n.at(1), n.at(3)}, + }); + + std::optional result = + get_inverse_line_graph(transitive_reduction(g)); + + CHECK_FALSE(result.has_value()); + } } } diff --git a/lib/utils/test/src/utils/graph/digraph/algorithms/is_acyclic.cc b/lib/utils/test/src/utils/graph/digraph/algorithms/is_acyclic.cc index 3ad506f40a..e675e6903f 100644 --- a/lib/utils/test/src/utils/graph/digraph/algorithms/is_acyclic.cc +++ b/lib/utils/test/src/utils/graph/digraph/algorithms/is_acyclic.cc @@ -1,4 +1,5 @@ #include "utils/graph/digraph/algorithms/is_acyclic.h" +#include "test/utils/doctest/fmt/optional.h" #include "utils/graph/algorithms.h" #include "utils/graph/instances/adjacency_digraph.h" #include diff --git a/lib/utils/test/src/utils/graph/digraph/algorithms/transitive_closure.cc b/lib/utils/test/src/utils/graph/digraph/algorithms/transitive_closure.cc new file mode 100644 index 0000000000..5f72355ed0 --- /dev/null +++ b/lib/utils/test/src/utils/graph/digraph/algorithms/transitive_closure.cc @@ -0,0 +1,50 @@ +#include "utils/graph/digraph/algorithms/transitive_closure.h" +#include "utils/graph/algorithms.h" +#include "utils/graph/instances/adjacency_digraph.h" +#include "utils/graph/node/algorithms.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("transitive_closure(DiGraphView)") { + DiGraph g = DiGraph::create(); + + SUBCASE("maximum number of new edges") { + std::vector n = add_nodes(g, 5); + + add_edges(g, + { + DirectedEdge{n.at(0), n.at(1)}, + DirectedEdge{n.at(1), n.at(2)}, + DirectedEdge{n.at(2), n.at(3)}, + DirectedEdge{n.at(3), n.at(4)}, + }); + + DiGraphView result = transitive_closure(g); + + SUBCASE("nodes") { + std::unordered_set result_nodes = get_nodes(result); + std::unordered_set correct_nodes = unordered_set_of(n); + CHECK(result_nodes == correct_nodes); + } + + SUBCASE("edges") { + std::unordered_set result_edges = get_edges(result); + std::unordered_set correct_edges = { + DirectedEdge{n.at(0), n.at(1)}, + DirectedEdge{n.at(0), n.at(2)}, + DirectedEdge{n.at(0), n.at(3)}, + DirectedEdge{n.at(0), n.at(4)}, + DirectedEdge{n.at(1), n.at(2)}, + DirectedEdge{n.at(1), n.at(3)}, + DirectedEdge{n.at(1), n.at(4)}, + DirectedEdge{n.at(2), n.at(3)}, + DirectedEdge{n.at(2), n.at(4)}, + DirectedEdge{n.at(3), n.at(4)}, + }; + CHECK(result_edges == correct_edges); + } + } + } +} diff --git a/lib/utils/test/src/utils/graph/digraph/algorithms/transitive_reduction.cc b/lib/utils/test/src/utils/graph/digraph/algorithms/transitive_reduction.cc index b8a35346f4..1f9062a8ed 100644 --- a/lib/utils/test/src/utils/graph/digraph/algorithms/transitive_reduction.cc +++ b/lib/utils/test/src/utils/graph/digraph/algorithms/transitive_reduction.cc @@ -1,4 +1,5 @@ #include "utils/graph/digraph/algorithms/transitive_reduction.h" +#include "test/utils/doctest/fmt/optional.h" #include "utils/graph/algorithms.h" #include "utils/graph/instances/adjacency_digraph.h" #include "utils/graph/node/algorithms.h" @@ -76,5 +77,66 @@ TEST_SUITE(FF_TEST_SUITE) { CHECK(result_edges == correct_edges); } } + + SUBCASE("longer paths") { + std::vector n = add_nodes(g, 5); + + add_edges(g, + { + DirectedEdge{n.at(0), n.at(1)}, + DirectedEdge{n.at(1), n.at(2)}, + DirectedEdge{n.at(2), n.at(3)}, + DirectedEdge{n.at(0), n.at(4)}, + DirectedEdge{n.at(3), n.at(4)}, + }); + + DiGraphView result = transitive_reduction(g); + + SUBCASE("nodes") { + std::unordered_set result_nodes = get_nodes(result); + std::unordered_set correct_nodes = unordered_set_of(n); + CHECK(result_nodes == correct_nodes); + } + + SUBCASE("edges") { + std::unordered_set result_edges = get_edges(result); + std::unordered_set correct_edges = { + DirectedEdge{n.at(0), n.at(1)}, + DirectedEdge{n.at(1), n.at(2)}, + DirectedEdge{n.at(2), n.at(3)}, + DirectedEdge{n.at(3), n.at(4)}, + }; + CHECK(result_edges == correct_edges); + } + } + + SUBCASE("irreducible sp n-graph") { + std::vector n = add_nodes(g, 4); + + add_edges(g, + { + DirectedEdge{n.at(0), n.at(2)}, + DirectedEdge{n.at(1), n.at(2)}, + DirectedEdge{n.at(1), n.at(3)}, + }); + + DiGraphView result = transitive_reduction(g); + + SUBCASE("nodes") { + std::unordered_set result_nodes = get_nodes(result); + std::unordered_set correct_nodes = unordered_set_of(n); + CHECK(result_nodes == correct_nodes); + } + + SUBCASE("edges") { + std::unordered_set result_edges = get_edges(result); + std::unordered_set correct_edges = { + DirectedEdge{n.at(0), n.at(2)}, + DirectedEdge{n.at(1), n.at(2)}, + DirectedEdge{n.at(1), n.at(3)}, + }; + CHECK(result_edges == correct_edges); + } + } } } diff --git a/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/fmt.cc b/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/fmt.cc new file mode 100644 index 0000000000..66b657eaaa --- /dev/null +++ b/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/fmt.cc @@ -0,0 +1,51 @@ +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/fmt.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/make.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("fmt GenericBinarySPDecompositionTree") { + SUBCASE("leaf") { + GenericBinarySPDecompositionTree input = + make_generic_binary_sp_leaf(5); + + std::string result = fmt::to_string(input); + std::string correct = ""; + + CHECK(result == correct); + } + + SUBCASE("series split") { + GenericBinarySPDecompositionTree input = + make_generic_binary_series_split(make_generic_binary_sp_leaf(5), + make_generic_binary_sp_leaf(7)); + + std::string result = fmt::to_string(input); + std::string correct = (" " + "" + ">" + ">"); + + CHECK(result == correct); + } + + SUBCASE("parallel split") { + GenericBinarySPDecompositionTree input = + make_generic_binary_parallel_split(make_generic_binary_sp_leaf(5), + make_generic_binary_sp_leaf(7)); + + std::string result = fmt::to_string(input); + std::string correct = (" " + "" + ">" + ">"); + + CHECK(result == correct); + } + } +} diff --git a/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_leaves.cc b/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_leaves.cc new file mode 100644 index 0000000000..abae9286b6 --- /dev/null +++ b/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_leaves.cc @@ -0,0 +1,86 @@ +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_leaves.h" +#include "test/utils/doctest/fmt/unordered_multiset.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/make.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("get_leaves(GenericBinarySPDecompositionTree)") { + SUBCASE("leaf") { + GenericBinarySPDecompositionTree input = + make_generic_binary_sp_leaf(5); + + std::unordered_multiset result = get_leaves(input); + std::unordered_multiset correct = {5}; + + CHECK(result == correct); + } + + SUBCASE("series split") { + SUBCASE("children are not the same") { + GenericBinarySPDecompositionTree input = + make_generic_binary_series_split(make_generic_binary_sp_leaf(5), + make_generic_binary_sp_leaf(6)); + + std::unordered_multiset result = get_leaves(input); + std::unordered_multiset correct = {5, 6}; + + CHECK(result == correct); + } + + SUBCASE("children are the same") { + GenericBinarySPDecompositionTree input = + make_generic_binary_series_split(make_generic_binary_sp_leaf(5), + make_generic_binary_sp_leaf(5)); + + std::unordered_multiset result = get_leaves(input); + std::unordered_multiset correct = {5, 5}; + + CHECK(result == correct); + } + } + + SUBCASE("parallel split") { + SUBCASE("children are not the same") { + GenericBinarySPDecompositionTree input = + make_generic_binary_parallel_split(make_generic_binary_sp_leaf(5), + make_generic_binary_sp_leaf(6)); + + std::unordered_multiset result = get_leaves(input); + std::unordered_multiset correct = {5, 6}; + + CHECK(result == correct); + } + + SUBCASE("children are the same") { + GenericBinarySPDecompositionTree input = + make_generic_binary_parallel_split(make_generic_binary_sp_leaf(5), + make_generic_binary_sp_leaf(5)); + + std::unordered_multiset result = get_leaves(input); + std::unordered_multiset correct = {5, 5}; + + CHECK(result == correct); + } + } + + SUBCASE("nested") { + GenericBinarySPDecompositionTree input = + make_generic_binary_parallel_split( + make_generic_binary_series_split( + make_generic_binary_sp_leaf(4), + make_generic_binary_series_split( + make_generic_binary_sp_leaf(2), + make_generic_binary_sp_leaf(5))), + make_generic_binary_parallel_split( + make_generic_binary_sp_leaf(4), + make_generic_binary_sp_leaf(2))); + + std::unordered_multiset result = get_leaves(input); + std::unordered_multiset correct = {2, 2, 4, 4, 5}; + + CHECK(result == correct); + } + } +} diff --git a/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_left_child.cc b/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_left_child.cc new file mode 100644 index 0000000000..92c556ad28 --- /dev/null +++ b/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_left_child.cc @@ -0,0 +1,41 @@ +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_left_child.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/fmt.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/make.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("get_left_child(GenericBinarySPDecompositionTree)") { + SUBCASE("leaf") { + GenericBinarySPDecompositionTree input = + make_generic_binary_sp_leaf(5); + + CHECK_THROWS(get_left_child(input)); + } + + SUBCASE("series split") { + GenericBinarySPDecompositionTree input = + make_generic_binary_series_split(make_generic_binary_sp_leaf(5), + make_generic_binary_sp_leaf(3)); + + GenericBinarySPDecompositionTree result = get_left_child(input); + GenericBinarySPDecompositionTree correct = + make_generic_binary_sp_leaf(5); + + CHECK(result == correct); + } + + SUBCASE("parallel split") { + GenericBinarySPDecompositionTree input = + make_generic_binary_parallel_split(make_generic_binary_sp_leaf(4), + make_generic_binary_sp_leaf(7)); + + GenericBinarySPDecompositionTree result = get_left_child(input); + GenericBinarySPDecompositionTree correct = + make_generic_binary_sp_leaf(4); + + CHECK(result == correct); + } + } +} diff --git a/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_num_tree_nodes.cc b/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_num_tree_nodes.cc new file mode 100644 index 0000000000..3de61d3313 --- /dev/null +++ b/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_num_tree_nodes.cc @@ -0,0 +1,85 @@ +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_num_tree_nodes.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/make.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("get_num_tree_nodes(GenericBinarySPDecompositionTree)") { + SUBCASE("leaf") { + GenericBinarySPDecompositionTree input = + make_generic_binary_sp_leaf(5); + + int result = get_num_tree_nodes(input); + int correct = 1; + + CHECK(result == correct); + } + + SUBCASE("series split") { + SUBCASE("children are not the same") { + GenericBinarySPDecompositionTree input = + make_generic_binary_series_split(make_generic_binary_sp_leaf(5), + make_generic_binary_sp_leaf(6)); + + int result = get_num_tree_nodes(input); + int correct = 3; + + CHECK(result == correct); + } + + SUBCASE("children are the same") { + GenericBinarySPDecompositionTree input = + make_generic_binary_series_split(make_generic_binary_sp_leaf(5), + make_generic_binary_sp_leaf(5)); + + int result = get_num_tree_nodes(input); + int correct = 3; + + CHECK(result == correct); + } + } + + SUBCASE("parallel split") { + SUBCASE("children are not the same") { + GenericBinarySPDecompositionTree input = + make_generic_binary_parallel_split(make_generic_binary_sp_leaf(5), + make_generic_binary_sp_leaf(6)); + + int result = get_num_tree_nodes(input); + int correct = 3; + + CHECK(result == correct); + } + + SUBCASE("children are the same") { + GenericBinarySPDecompositionTree input = + make_generic_binary_parallel_split(make_generic_binary_sp_leaf(5), + make_generic_binary_sp_leaf(5)); + + int result = get_num_tree_nodes(input); + int correct = 3; + + CHECK(result == correct); + } + } + + SUBCASE("nested") { + GenericBinarySPDecompositionTree input = + make_generic_binary_parallel_split( + make_generic_binary_series_split( + make_generic_binary_sp_leaf(4), + make_generic_binary_series_split( + make_generic_binary_sp_leaf(2), + make_generic_binary_sp_leaf(5))), + make_generic_binary_parallel_split( + make_generic_binary_sp_leaf(4), + make_generic_binary_sp_leaf(2))); + + int result = get_num_tree_nodes(input); + int correct = 9; + + CHECK(result == correct); + } + } +} diff --git a/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_right_child.cc b/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_right_child.cc new file mode 100644 index 0000000000..33b5d37955 --- /dev/null +++ b/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_right_child.cc @@ -0,0 +1,41 @@ +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_right_child.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/fmt.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/make.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("get_right_child(GenericBinarySPDecompositionTree)") { + SUBCASE("leaf") { + GenericBinarySPDecompositionTree input = + make_generic_binary_sp_leaf(5); + + CHECK_THROWS(get_right_child(input)); + } + + SUBCASE("series split") { + GenericBinarySPDecompositionTree input = + make_generic_binary_series_split(make_generic_binary_sp_leaf(5), + make_generic_binary_sp_leaf(3)); + + GenericBinarySPDecompositionTree result = get_right_child(input); + GenericBinarySPDecompositionTree correct = + make_generic_binary_sp_leaf(3); + + CHECK(result == correct); + } + + SUBCASE("parallel split") { + GenericBinarySPDecompositionTree input = + make_generic_binary_parallel_split(make_generic_binary_sp_leaf(4), + make_generic_binary_sp_leaf(7)); + + GenericBinarySPDecompositionTree result = get_right_child(input); + GenericBinarySPDecompositionTree correct = + make_generic_binary_sp_leaf(7); + + CHECK(result == correct); + } + } +} diff --git a/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/hash.cc b/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/hash.cc new file mode 100644 index 0000000000..e7025dbfad --- /dev/null +++ b/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/hash.cc @@ -0,0 +1,117 @@ +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/hash.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/make.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("std::hash>") { + SUBCASE("leaf") { + GenericBinarySPDecompositionTree leaf_5 = + make_generic_binary_sp_leaf(5); + size_t leaf_5_hash = get_std_hash(leaf_5); + + SUBCASE("leaves with same labels hash to the same value") { + GenericBinarySPDecompositionTree also_leaf_5 = + make_generic_binary_sp_leaf(5); + size_t also_leaf_5_hash = get_std_hash(also_leaf_5); + + CHECK(leaf_5_hash == also_leaf_5_hash); + } + + SUBCASE("leaves with different labels hash to different values") { + GenericBinarySPDecompositionTree leaf_6 = + make_generic_binary_sp_leaf(6); + size_t leaf_6_hash = get_std_hash(leaf_6); + + CHECK(leaf_5_hash != leaf_6_hash); + } + } + + SUBCASE("series split") { + GenericBinarySPDecompositionTree series_5_6 = + make_generic_binary_series_split(make_generic_binary_sp_leaf(5), + make_generic_binary_sp_leaf(6)); + size_t series_5_6_hash = get_std_hash(series_5_6); + + SUBCASE("same children lead to the same hash") { + GenericBinarySPDecompositionTree also_series_5_6 = + make_generic_binary_series_split(make_generic_binary_sp_leaf(5), + make_generic_binary_sp_leaf(6)); + size_t also_series_5_6_hash = get_std_hash(also_series_5_6); + + CHECK(series_5_6_hash == also_series_5_6_hash); + } + + SUBCASE("hash is order dependent") { + GenericBinarySPDecompositionTree series_6_5 = + make_generic_binary_series_split(make_generic_binary_sp_leaf(6), + make_generic_binary_sp_leaf(5)); + size_t series_6_5_hash = get_std_hash(series_6_5); + + CHECK(series_5_6_hash != series_6_5_hash); + } + + SUBCASE("different left child leads to different hash") { + GenericBinarySPDecompositionTree series_4_6 = + make_generic_binary_series_split(make_generic_binary_sp_leaf(4), + make_generic_binary_sp_leaf(6)); + size_t series_4_6_hash = get_std_hash(series_4_6); + + CHECK(series_5_6_hash != series_4_6_hash); + } + + SUBCASE("different right child leads to different hash") { + GenericBinarySPDecompositionTree series_5_7 = + make_generic_binary_series_split(make_generic_binary_sp_leaf(5), + make_generic_binary_sp_leaf(7)); + size_t series_5_7_hash = get_std_hash(series_5_7); + + CHECK(series_5_6_hash != series_5_7_hash); + } + } + + SUBCASE("parallel split") { + GenericBinarySPDecompositionTree parallel_5_6 = + make_generic_binary_parallel_split(make_generic_binary_sp_leaf(5), + make_generic_binary_sp_leaf(6)); + size_t parallel_5_6_hash = get_std_hash(parallel_5_6); + + SUBCASE("same children lead to the same hash") { + GenericBinarySPDecompositionTree also_parallel_5_6 = + make_generic_binary_parallel_split(make_generic_binary_sp_leaf(5), + make_generic_binary_sp_leaf(6)); + size_t also_parallel_5_6_hash = get_std_hash(also_parallel_5_6); + + CHECK(parallel_5_6_hash == also_parallel_5_6_hash); + } + + SUBCASE("hash is order dependent") { + GenericBinarySPDecompositionTree parallel_6_5 = + make_generic_binary_parallel_split(make_generic_binary_sp_leaf(6), + make_generic_binary_sp_leaf(5)); + size_t parallel_6_5_hash = get_std_hash(parallel_6_5); + + CHECK(parallel_5_6_hash != parallel_6_5_hash); + } + + SUBCASE("different left child leads to different hash") { + GenericBinarySPDecompositionTree parallel_4_6 = + make_generic_binary_parallel_split(make_generic_binary_sp_leaf(4), + make_generic_binary_sp_leaf(6)); + size_t parallel_4_6_hash = get_std_hash(parallel_4_6); + + CHECK(parallel_5_6_hash != parallel_4_6_hash); + } + + SUBCASE("different right child leads to different hash") { + GenericBinarySPDecompositionTree parallel_5_7 = + make_generic_binary_parallel_split(make_generic_binary_sp_leaf(5), + make_generic_binary_sp_leaf(7)); + size_t parallel_5_7_hash = get_std_hash(parallel_5_7); + + CHECK(parallel_5_6_hash != parallel_5_7_hash); + } + } + } +} diff --git a/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_left_associative.cc b/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_left_associative.cc new file mode 100644 index 0000000000..7a8756c6cc --- /dev/null +++ b/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_left_associative.cc @@ -0,0 +1,102 @@ +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_left_associative.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/fmt.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/make.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("is_binary_sp_tree_left_associative(" + "GenericBinarySPDecompositionTree)") { + int n1 = 1; + int n2 = 2; + int n3 = 3; + int n4 = 4; + + SUBCASE("input is actually left associative") { + SUBCASE("just node") { + GenericBinarySPDecompositionTree input = + make_generic_binary_sp_leaf(n1); + + bool result = is_binary_sp_tree_left_associative(input); + bool correct = true; + + CHECK(result == correct); + } + + SUBCASE("just series") { + GenericBinarySPDecompositionTree input = + make_generic_binary_series_split( + make_generic_binary_series_split( + make_generic_binary_sp_leaf(n1), + make_generic_binary_sp_leaf(n2)), + make_generic_binary_sp_leaf(n3)); + + bool result = is_binary_sp_tree_left_associative(input); + bool correct = true; + + CHECK(result == correct); + } + + SUBCASE("just parallel") { + GenericBinarySPDecompositionTree input = + make_generic_binary_parallel_split( + make_generic_binary_parallel_split( + make_generic_binary_sp_leaf(n1), + make_generic_binary_sp_leaf(n2)), + make_generic_binary_sp_leaf(n3)); + + bool result = is_binary_sp_tree_left_associative(input); + bool correct = true; + + CHECK(result == correct); + } + + SUBCASE("nested") { + GenericBinarySPDecompositionTree input = + make_generic_binary_series_split( + make_generic_binary_parallel_split( + make_generic_binary_sp_leaf(n1), + make_generic_binary_sp_leaf(n2)), + make_generic_binary_parallel_split( + make_generic_binary_sp_leaf(n3), + make_generic_binary_sp_leaf(n4))); + + bool result = is_binary_sp_tree_left_associative(input); + bool correct = true; + + CHECK(result == correct); + } + } + + SUBCASE("input is not left associative") { + SUBCASE("just series") { + GenericBinarySPDecompositionTree input = + make_generic_binary_series_split( + make_generic_binary_sp_leaf(n1), + make_generic_binary_series_split( + make_generic_binary_sp_leaf(n2), + make_generic_binary_sp_leaf(n3))); + + bool result = is_binary_sp_tree_left_associative(input); + bool correct = false; + + CHECK(result == correct); + } + + SUBCASE("just parallel") { + GenericBinarySPDecompositionTree input = + make_generic_binary_parallel_split( + make_generic_binary_sp_leaf(n1), + make_generic_binary_parallel_split( + make_generic_binary_sp_leaf(n2), + make_generic_binary_sp_leaf(n3))); + + bool result = is_binary_sp_tree_left_associative(input); + bool correct = false; + + CHECK(result == correct); + } + } + } +} diff --git a/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_right_associative.cc b/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_right_associative.cc new file mode 100644 index 0000000000..3cf87368ab --- /dev/null +++ b/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_right_associative.cc @@ -0,0 +1,102 @@ +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_right_associative.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/fmt.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/make.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("is_binary_sp_tree_right_associative(" + "GenericBinarySPDecompositionTree)") { + int n1 = 1; + int n2 = 2; + int n3 = 3; + int n4 = 4; + + SUBCASE("input is actually right associative") { + SUBCASE("just node") { + GenericBinarySPDecompositionTree input = + make_generic_binary_sp_leaf(n1); + + bool result = is_binary_sp_tree_right_associative(input); + bool correct = true; + + CHECK(result == correct); + } + + SUBCASE("just series") { + GenericBinarySPDecompositionTree input = + make_generic_binary_series_split( + make_generic_binary_sp_leaf(n1), + make_generic_binary_series_split( + make_generic_binary_sp_leaf(n2), + make_generic_binary_sp_leaf(n3))); + + bool result = is_binary_sp_tree_right_associative(input); + bool correct = true; + + CHECK(result == correct); + } + + SUBCASE("just parallel") { + GenericBinarySPDecompositionTree input = + make_generic_binary_parallel_split( + make_generic_binary_sp_leaf(n1), + make_generic_binary_parallel_split( + make_generic_binary_sp_leaf(n2), + make_generic_binary_sp_leaf(n3))); + + bool result = is_binary_sp_tree_right_associative(input); + bool correct = true; + + CHECK(result == correct); + } + + SUBCASE("nested") { + GenericBinarySPDecompositionTree input = + make_generic_binary_series_split( + make_generic_binary_parallel_split( + make_generic_binary_sp_leaf(n1), + make_generic_binary_sp_leaf(n2)), + make_generic_binary_parallel_split( + make_generic_binary_sp_leaf(n3), + make_generic_binary_sp_leaf(n4))); + + bool result = is_binary_sp_tree_right_associative(input); + bool correct = true; + + CHECK(result == correct); + } + } + + SUBCASE("input is not right associative") { + SUBCASE("just series") { + GenericBinarySPDecompositionTree input = + make_generic_binary_series_split( + make_generic_binary_series_split( + make_generic_binary_sp_leaf(n1), + make_generic_binary_sp_leaf(n2)), + make_generic_binary_sp_leaf(n3)); + + bool result = is_binary_sp_tree_right_associative(input); + bool correct = false; + + CHECK(result == correct); + } + + SUBCASE("just parallel") { + GenericBinarySPDecompositionTree input = + make_generic_binary_parallel_split( + make_generic_binary_parallel_split( + make_generic_binary_sp_leaf(n1), + make_generic_binary_sp_leaf(n2)), + make_generic_binary_sp_leaf(n3)); + + bool result = is_binary_sp_tree_right_associative(input); + bool correct = false; + + CHECK(result == correct); + } + } + } +} diff --git a/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/json.cc b/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/json.cc new file mode 100644 index 0000000000..cc234bacf8 --- /dev/null +++ b/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/json.cc @@ -0,0 +1,131 @@ +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/json.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/fmt.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/make.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("adl_serializer>") { + SUBCASE("leaf") { + GenericBinarySPDecompositionTree tt = make_generic_binary_sp_leaf(5); + + nlohmann::json tt_json = { + {"__type", "GenericBinarySPDecompositionTree"}, + {"type", "leaf"}, + {"value", 5}, + }; + + SUBCASE("to_json") { + nlohmann::json result = tt; + nlohmann::json correct = tt_json; + + CHECK(result == correct); + } + + SUBCASE("from_json") { + GenericBinarySPDecompositionTree result = + tt_json.get>(); + GenericBinarySPDecompositionTree correct = tt; + + CHECK(result == correct); + } + } + + SUBCASE("series split") { + GenericBinarySPDecompositionTree tt = + make_generic_binary_series_split(make_generic_binary_sp_leaf(2), + make_generic_binary_sp_leaf(5)); + + nlohmann::json tt_json = { + {"__type", "GenericBinarySPDecompositionTree"}, + {"type", "series"}, + { + "value", + { + {"__type", "GenericBinarySeriesSplit"}, + { + "left_child", + { + {"__type", "GenericBinarySPDecompositionTree"}, + {"type", "leaf"}, + {"value", 2}, + }, + }, + { + "right_child", + { + {"__type", "GenericBinarySPDecompositionTree"}, + {"type", "leaf"}, + {"value", 5}, + }, + }, + }, + }, + }; + + SUBCASE("to_json") { + nlohmann::json result = tt; + nlohmann::json correct = tt_json; + + CHECK(result == correct); + } + + SUBCASE("from_json") { + GenericBinarySPDecompositionTree result = + tt_json.get>(); + GenericBinarySPDecompositionTree correct = tt; + + CHECK(result == correct); + } + } + + SUBCASE("parallel split") { + GenericBinarySPDecompositionTree tt = + make_generic_binary_parallel_split(make_generic_binary_sp_leaf(2), + make_generic_binary_sp_leaf(5)); + + nlohmann::json tt_json = { + {"__type", "GenericBinarySPDecompositionTree"}, + {"type", "parallel"}, + { + "value", + { + {"__type", "GenericBinaryParallelSplit"}, + { + "left_child", + { + {"__type", "GenericBinarySPDecompositionTree"}, + {"type", "leaf"}, + {"value", 2}, + }, + }, + { + "right_child", + { + {"__type", "GenericBinarySPDecompositionTree"}, + {"type", "leaf"}, + {"value", 5}, + }, + }, + }, + }, + }; + + SUBCASE("to_json") { + nlohmann::json result = tt; + nlohmann::json correct = tt_json; + + CHECK(result == correct); + } + + SUBCASE("from_json") { + GenericBinarySPDecompositionTree result = + tt_json.get>(); + GenericBinarySPDecompositionTree correct = tt; + + CHECK(result == correct); + } + } + } +} diff --git a/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/transform.cc b/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/transform.cc new file mode 100644 index 0000000000..4ede4e84b5 --- /dev/null +++ b/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/transform.cc @@ -0,0 +1,28 @@ +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/transform.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/fmt.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/make.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("transform(GenericBinarySPDecompositionTree, F)") { + GenericBinarySPDecompositionTree input = + make_generic_binary_parallel_split( + make_generic_binary_series_split(make_generic_binary_sp_leaf(1), + make_generic_binary_sp_leaf(4)), + make_generic_binary_sp_leaf(2)); + + GenericBinarySPDecompositionTree result = + transform(input, [](int x) { return std::to_string(x); }); + + GenericBinarySPDecompositionTree correct = + make_generic_binary_parallel_split( + make_generic_binary_series_split( + make_generic_binary_sp_leaf(std::string{"1"}), + make_generic_binary_sp_leaf(std::string{"4"})), + make_generic_binary_sp_leaf(std::string{"2"})); + + CHECK(result == correct); + } +} diff --git a/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/left_associative_binary_sp_tree_from_nary.cc b/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/left_associative_binary_sp_tree_from_nary.cc new file mode 100644 index 0000000000..1e3217a2de --- /dev/null +++ b/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/left_associative_binary_sp_tree_from_nary.cc @@ -0,0 +1,95 @@ +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/left_associative_binary_sp_tree_from_nary.h" +#include "test/utils/doctest/fmt/unordered_multiset.h" +#include "test/utils/rapidcheck.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/nary_sp_tree_from_binary.h" +#include "utils/graph/series_parallel/series_parallel_decomposition.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("left_associative_binary_sp_tree_from_nary(" + "SeriesParallelDecomposition)") { + Node n1 = Node{1}; + Node n2 = Node{2}; + Node n3 = Node{3}; + Node n4 = Node{4}; + Node n5 = Node{5}; + Node n6 = Node{6}; + + SUBCASE("only node") { + SeriesParallelDecomposition input = SeriesParallelDecomposition{n1}; + + BinarySPDecompositionTree result = + left_associative_binary_sp_tree_from_nary(input); + BinarySPDecompositionTree correct = make_leaf_node(n1); + + CHECK(result == correct); + } + + SUBCASE("only serial") { + SeriesParallelDecomposition input = SeriesParallelDecomposition{ + SeriesSplit{n1, n2, n3}, + }; + + BinarySPDecompositionTree result = + left_associative_binary_sp_tree_from_nary(input); + + BinarySPDecompositionTree correct = make_series_split( + make_series_split(make_leaf_node(n1), make_leaf_node(n2)), + make_leaf_node(n3)); + + CHECK(result == correct); + } + + SUBCASE("only parallel") { + SeriesParallelDecomposition input = SeriesParallelDecomposition{ + ParallelSplit{n1, n2, n3}, + }; + + BinarySPDecompositionTree result = + left_associative_binary_sp_tree_from_nary(input); + + // we use multiple checks here because SerialParallelDecomposition's + // ParallelSplit is unordered, so there are multiple possible + // left-associative binary SP trees + CHECK(is_binary_sp_tree_left_associative(result)); + + std::unordered_multiset result_nodes = get_nodes(input); + std::unordered_multiset correct_nodes = {n1, n2, n3}; + + CHECK(result_nodes == correct_nodes); + } + + SUBCASE("nested") { + SeriesParallelDecomposition input = SeriesParallelDecomposition{ + ParallelSplit{ + n1, + SeriesSplit{ + n2, + n3, + n3, + n5, + }, + SeriesSplit{ + n6, + n4, + }, + n5, + }, + }; + + BinarySPDecompositionTree result = + left_associative_binary_sp_tree_from_nary(input); + + CHECK(is_binary_sp_tree_left_associative(result)); + + std::unordered_multiset result_nodes = get_nodes(input); + std::unordered_multiset correct_nodes = { + n1, n2, n3, n3, n5, n6, n4, n5}; + + CHECK(result_nodes == correct_nodes); + } + } +} diff --git a/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/nary_sp_tree_from_binary.cc b/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/nary_sp_tree_from_binary.cc new file mode 100644 index 0000000000..0befbde5cc --- /dev/null +++ b/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/nary_sp_tree_from_binary.cc @@ -0,0 +1,132 @@ +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/nary_sp_tree_from_binary.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.h" +#include "utils/graph/series_parallel/series_parallel_decomposition.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("nary_sp_tree_from_binary(BinarySPDecompositionTree)") { + Node n1 = Node{1}; + Node n2 = Node{2}; + Node n3 = Node{3}; + Node n4 = Node{4}; + Node n5 = Node{5}; + Node n6 = Node{6}; + + SUBCASE("leaf") { + BinarySPDecompositionTree input = make_leaf_node(n1); + + SeriesParallelDecomposition result = nary_sp_tree_from_binary(input); + SeriesParallelDecomposition correct = SeriesParallelDecomposition{n1}; + + CHECK(result == correct); + } + + SUBCASE("left associative series") { + BinarySPDecompositionTree input = make_series_split( + make_series_split(make_leaf_node(n2), make_leaf_node(n1)), + make_leaf_node(n3)); + + SeriesParallelDecomposition result = nary_sp_tree_from_binary(input); + SeriesParallelDecomposition correct = + SeriesParallelDecomposition{SeriesSplit{n2, n1, n3}}; + + CHECK(result == correct); + } + + SUBCASE("right associative series") { + BinarySPDecompositionTree input = make_series_split( + make_leaf_node(n2), + make_series_split(make_leaf_node(n1), make_leaf_node(n3))); + + SeriesParallelDecomposition result = nary_sp_tree_from_binary(input); + SeriesParallelDecomposition correct = + SeriesParallelDecomposition{SeriesSplit{n2, n1, n3}}; + + CHECK(result == correct); + } + + SUBCASE("series with duplicate children") { + BinarySPDecompositionTree input = + make_series_split(make_leaf_node(n1), make_leaf_node(n1)); + + SeriesParallelDecomposition result = nary_sp_tree_from_binary(input); + SeriesParallelDecomposition correct = + SeriesParallelDecomposition{SeriesSplit{n1, n1}}; + + CHECK(get_nodes(result).size() == 2); + CHECK(result == correct); + } + + SUBCASE("left associative parallel") { + BinarySPDecompositionTree input = make_parallel_split( + make_parallel_split(make_leaf_node(n2), make_leaf_node(n1)), + make_leaf_node(n3)); + + SeriesParallelDecomposition result = nary_sp_tree_from_binary(input); + SeriesParallelDecomposition correct = + SeriesParallelDecomposition{ParallelSplit{n2, n1, n3}}; + + CHECK(result == correct); + } + + SUBCASE("right associative parallel") { + BinarySPDecompositionTree input = make_parallel_split( + make_leaf_node(n2), + make_parallel_split(make_leaf_node(n1), make_leaf_node(n3))); + + SeriesParallelDecomposition result = nary_sp_tree_from_binary(input); + SeriesParallelDecomposition correct = + SeriesParallelDecomposition{ParallelSplit{n2, n1, n3}}; + + CHECK(result == correct); + } + + SUBCASE("parallel with duplicate children") { + BinarySPDecompositionTree input = + make_parallel_split(make_leaf_node(n1), make_leaf_node(n1)); + + SeriesParallelDecomposition result = nary_sp_tree_from_binary(input); + SeriesParallelDecomposition correct = + SeriesParallelDecomposition{ParallelSplit{n1, n1}}; + + CHECK(get_nodes(result).size() == 2); + CHECK(result == correct); + } + + SUBCASE("nested") { + BinarySPDecompositionTree input = make_parallel_split( + make_parallel_split( + make_parallel_split( + make_leaf_node(n1), + make_series_split( + make_series_split(make_series_split(make_leaf_node(n2), + make_leaf_node(n3)), + make_leaf_node(n3)), + make_leaf_node(n5))), + make_series_split(make_leaf_node(n6), make_leaf_node(n4))), + make_leaf_node(n5)); + + SeriesParallelDecomposition result = nary_sp_tree_from_binary(input); + SeriesParallelDecomposition correct = SeriesParallelDecomposition{ + ParallelSplit{ + n1, + SeriesSplit{ + n2, + n3, + n3, + n5, + }, + SeriesSplit{ + n6, + n4, + }, + n5, + }, + }; + + CHECK(result == correct); + } + } +} diff --git a/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/right_associative_binary_sp_tree_from_nary.cc b/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/right_associative_binary_sp_tree_from_nary.cc new file mode 100644 index 0000000000..db1b440481 --- /dev/null +++ b/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/right_associative_binary_sp_tree_from_nary.cc @@ -0,0 +1,93 @@ +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/right_associative_binary_sp_tree_from_nary.h" +#include "test/utils/doctest/fmt/unordered_multiset.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.h" +#include "utils/graph/series_parallel/series_parallel_decomposition.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("right_associative_binary_sp_tree_from_nary(" + "SeriesParallelDecomposition)") { + Node n1 = Node{1}; + Node n2 = Node{2}; + Node n3 = Node{3}; + Node n4 = Node{4}; + Node n5 = Node{5}; + Node n6 = Node{6}; + + SUBCASE("only node") { + SeriesParallelDecomposition input = SeriesParallelDecomposition{n1}; + + BinarySPDecompositionTree result = + right_associative_binary_sp_tree_from_nary(input); + BinarySPDecompositionTree correct = make_leaf_node(n1); + + CHECK(result == correct); + } + + SUBCASE("only serial") { + SeriesParallelDecomposition input = SeriesParallelDecomposition{ + SeriesSplit{n1, n2, n3}, + }; + + BinarySPDecompositionTree result = + right_associative_binary_sp_tree_from_nary(input); + + BinarySPDecompositionTree correct = make_series_split( + make_leaf_node(n1), + make_series_split(make_leaf_node(n2), make_leaf_node(n3))); + + CHECK(result == correct); + } + + SUBCASE("only parallel") { + SeriesParallelDecomposition input = SeriesParallelDecomposition{ + ParallelSplit{n1, n2, n3}, + }; + + BinarySPDecompositionTree result = + right_associative_binary_sp_tree_from_nary(input); + + // we use multiple checks here because SerialParallelDecomposition's + // ParallelSplit is unordered, so there are multiple possible + // right-associative binary SP trees + CHECK(is_binary_sp_tree_right_associative(result)); + + std::unordered_multiset result_nodes = get_nodes(input); + std::unordered_multiset correct_nodes = {n1, n2, n3}; + + CHECK(result_nodes == correct_nodes); + } + + SUBCASE("nested") { + SeriesParallelDecomposition input = SeriesParallelDecomposition{ + ParallelSplit{ + n1, + SeriesSplit{ + n2, + n3, + n3, + n5, + }, + SeriesSplit{ + n6, + n4, + }, + n5, + }, + }; + + BinarySPDecompositionTree result = + right_associative_binary_sp_tree_from_nary(input); + + CHECK(is_binary_sp_tree_right_associative(result)); + + std::unordered_multiset result_nodes = get_nodes(input); + std::unordered_multiset correct_nodes = { + n1, n2, n3, n3, n5, n6, n4, n5}; + + CHECK(result_nodes == correct_nodes); + } + } +} diff --git a/lib/utils/test/src/utils/graph/serial_parallel/get_serial_parallel_decomposition.cc b/lib/utils/test/src/utils/graph/series_parallel/get_series_parallel_decomposition.cc similarity index 50% rename from lib/utils/test/src/utils/graph/serial_parallel/get_serial_parallel_decomposition.cc rename to lib/utils/test/src/utils/graph/series_parallel/get_series_parallel_decomposition.cc index 04d82bf1d8..45f796c824 100644 --- a/lib/utils/test/src/utils/graph/serial_parallel/get_serial_parallel_decomposition.cc +++ b/lib/utils/test/src/utils/graph/series_parallel/get_series_parallel_decomposition.cc @@ -1,4 +1,4 @@ -#include "utils/graph/serial_parallel/get_serial_parallel_decomposition.h" +#include "utils/graph/series_parallel/get_series_parallel_decomposition.h" #include "utils/graph/algorithms.h" #include "utils/graph/instances/adjacency_digraph.h" #include @@ -6,47 +6,47 @@ using namespace ::FlexFlow; TEST_SUITE(FF_TEST_SUITE) { - TEST_CASE("get_serial_parallel_decomposition (base case)") { + TEST_CASE("get_series_parallel_decomposition (base case)") { DiGraph g = DiGraph::create(); Node n = g.add_node(); - std::optional result = - get_serial_parallel_decomposition(g); - std::optional correct = - SerialParallelDecomposition{n}; + std::optional result = + get_series_parallel_decomposition(g); + std::optional correct = + SeriesParallelDecomposition{n}; CHECK(result == correct); } - TEST_CASE("get_serial_parallel_decomposition (parallel)") { + TEST_CASE("get_series_parallel_decomposition (parallel)") { DiGraph g = DiGraph::create(); std::vector n = add_nodes(g, 2); - std::optional result = - get_serial_parallel_decomposition(g); - std::optional correct = - SerialParallelDecomposition{ParallelSplit{ + std::optional result = + get_series_parallel_decomposition(g); + std::optional correct = + SeriesParallelDecomposition{ParallelSplit{ n.at(0), n.at(1), }}; CHECK(result == correct); } - TEST_CASE("get_serial_parallel_decomposition (serial)") { + TEST_CASE("get_series_parallel_decomposition (serial)") { DiGraph g = DiGraph::create(); std::vector n = add_nodes(g, 2); g.add_edge(DirectedEdge{n.at(0), n.at(1)}); - std::optional result = - get_serial_parallel_decomposition(g); - std::optional correct = - SerialParallelDecomposition{SerialSplit{ + std::optional result = + get_series_parallel_decomposition(g); + std::optional correct = + SeriesParallelDecomposition{SeriesSplit{ n.at(0), n.at(1), }}; CHECK(result == correct); } - TEST_CASE("get_serial_parallel_decomposition (composite)") { + TEST_CASE("get_series_parallel_decomposition (composite)") { DiGraph g = DiGraph::create(); std::vector n = add_nodes(g, 3); add_edges(g, @@ -55,11 +55,11 @@ TEST_SUITE(FF_TEST_SUITE) { DirectedEdge{n.at(0), n.at(2)}, }); - std::optional result = - get_serial_parallel_decomposition(g); - std::optional correct = - SerialParallelDecomposition{ - SerialSplit{ + std::optional result = + get_series_parallel_decomposition(g); + std::optional correct = + SeriesParallelDecomposition{ + SeriesSplit{ n.at(0), ParallelSplit{ n.at(1), @@ -70,7 +70,7 @@ TEST_SUITE(FF_TEST_SUITE) { CHECK(result == correct); } - TEST_CASE("get_serial_parallel_decomposition (diamond graph)") { + TEST_CASE("get_series_parallel_decomposition (diamond graph)") { DiGraph g = DiGraph::create(); std::vector n = add_nodes(g, 6); @@ -85,15 +85,15 @@ TEST_SUITE(FF_TEST_SUITE) { DirectedEdge{n.at(4), n.at(5)}, }); - std::optional correct = - SerialParallelDecomposition{SerialSplit{ + std::optional correct = + SeriesParallelDecomposition{SeriesSplit{ n.at(0), ParallelSplit{ - SerialSplit{ + SeriesSplit{ n.at(1), n.at(3), }, - SerialSplit{ + SeriesSplit{ n.at(2), n.at(4), }, @@ -101,13 +101,13 @@ TEST_SUITE(FF_TEST_SUITE) { n.at(5), }}; - std::optional result = - get_serial_parallel_decomposition(g); + std::optional result = + get_series_parallel_decomposition(g); CHECK(result == correct); } - TEST_CASE("get_serial_parallel_decomposition (all-to-all connection)") { + TEST_CASE("get_series_parallel_decomposition (all-to-all connection)") { DiGraph g = DiGraph::create(); std::vector n = add_nodes(g, 4); @@ -120,9 +120,9 @@ TEST_SUITE(FF_TEST_SUITE) { DirectedEdge{n.at(1), n.at(3)}, }); - std::optional correct = - SerialParallelDecomposition{ - SerialSplit{ + std::optional correct = + SeriesParallelDecomposition{ + SeriesSplit{ ParallelSplit{ n.at(0), n.at(1), @@ -134,13 +134,13 @@ TEST_SUITE(FF_TEST_SUITE) { }, }; - std::optional result = - get_serial_parallel_decomposition(g); + std::optional result = + get_series_parallel_decomposition(g); CHECK(result == correct); } - TEST_CASE("get_serial_parallel_decomposition (non-sp graph)") { + TEST_CASE("get_series_parallel_decomposition (non-sp graph)") { DiGraph g = DiGraph::create(); std::vector n = add_nodes(g, 4); @@ -153,9 +153,39 @@ TEST_SUITE(FF_TEST_SUITE) { DirectedEdge{n.at(1), n.at(3)}, }); - std::optional correct = std::nullopt; - std::optional result = - get_serial_parallel_decomposition(g); + std::optional correct = std::nullopt; + std::optional result = + get_series_parallel_decomposition(g); + + CHECK(result == correct); + } + + TEST_CASE( + "get_series_parallel_decomposition (requires transitive reduction)") { + DiGraph g = DiGraph::create(); + + std::vector n = add_nodes(g, 4); + + add_edges(g, + { + DirectedEdge{n.at(0), n.at(1)}, + DirectedEdge{n.at(0), n.at(2)}, + DirectedEdge{n.at(1), n.at(2)}, + DirectedEdge{n.at(1), n.at(3)}, + DirectedEdge{n.at(2), n.at(3)}, + }); + + std::optional correct = + SeriesParallelDecomposition{ + SeriesSplit{ + n.at(0), + n.at(1), + n.at(2), + n.at(3), + }, + }; + std::optional result = + get_series_parallel_decomposition(g); CHECK(result == correct); } diff --git a/lib/utils/test/src/utils/graph/serial_parallel/intermediate_sp_decomposition_tree.cc b/lib/utils/test/src/utils/graph/series_parallel/intermediate_sp_decomposition_tree.cc similarity index 83% rename from lib/utils/test/src/utils/graph/serial_parallel/intermediate_sp_decomposition_tree.cc rename to lib/utils/test/src/utils/graph/series_parallel/intermediate_sp_decomposition_tree.cc index 4560f95ff7..3a486c7094 100644 --- a/lib/utils/test/src/utils/graph/serial_parallel/intermediate_sp_decomposition_tree.cc +++ b/lib/utils/test/src/utils/graph/series_parallel/intermediate_sp_decomposition_tree.cc @@ -1,4 +1,4 @@ -#include "utils/graph/serial_parallel/intermediate_sp_decomposition_tree.h" +#include "utils/graph/series_parallel/intermediate_sp_decomposition_tree.h" #include "utils/fmt/variant.h" #include @@ -8,11 +8,11 @@ TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("flatten_ast") { std::variant input = IntermediateSpDecompositionTree{ - SplitType::SERIAL, + SplitType::SERIES, { Node{1}, IntermediateSpDecompositionTree{ - SplitType::SERIAL, + SplitType::SERIES, { Node{2}, Node{3}, @@ -25,7 +25,7 @@ TEST_SUITE(FF_TEST_SUITE) { flatten_ast(input); std::variant correct = IntermediateSpDecompositionTree{ - SplitType::SERIAL, + SplitType::SERIES, { Node{1}, Node{2}, diff --git a/lib/utils/test/src/utils/graph/serial_parallel/parallel_reduction.cc b/lib/utils/test/src/utils/graph/series_parallel/parallel_reduction.cc similarity index 99% rename from lib/utils/test/src/utils/graph/serial_parallel/parallel_reduction.cc rename to lib/utils/test/src/utils/graph/series_parallel/parallel_reduction.cc index 8259d256d3..a62f528bcf 100644 --- a/lib/utils/test/src/utils/graph/serial_parallel/parallel_reduction.cc +++ b/lib/utils/test/src/utils/graph/series_parallel/parallel_reduction.cc @@ -1,4 +1,4 @@ -#include "utils/graph/serial_parallel/parallel_reduction.h" +#include "utils/graph/series_parallel/parallel_reduction.h" #include "utils/graph/instances/adjacency_multidigraph.h" #include "utils/graph/multidigraph/algorithms/add_edges.h" #include "utils/graph/multidigraph/algorithms/add_nodes.h" diff --git a/lib/utils/test/src/utils/graph/serial_parallel/serial_parallel_decomposition.cc b/lib/utils/test/src/utils/graph/series_parallel/series_parallel_decomposition.cc similarity index 66% rename from lib/utils/test/src/utils/graph/serial_parallel/serial_parallel_decomposition.cc rename to lib/utils/test/src/utils/graph/series_parallel/series_parallel_decomposition.cc index 7cf17c3fee..f5766c9fdd 100644 --- a/lib/utils/test/src/utils/graph/serial_parallel/serial_parallel_decomposition.cc +++ b/lib/utils/test/src/utils/graph/series_parallel/series_parallel_decomposition.cc @@ -1,5 +1,5 @@ -#include "utils/graph/serial_parallel/serial_parallel_decomposition.h" -#include "utils/fmt/unordered_set.h" +#include "utils/graph/series_parallel/series_parallel_decomposition.h" +#include "test/utils/doctest/fmt/unordered_multiset.h" #include using namespace ::FlexFlow; @@ -7,20 +7,20 @@ using namespace ::FlexFlow; TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("to_final_ast (base case)") { std::variant input = Node{1}; - SerialParallelDecomposition result = to_final_ast(input); - SerialParallelDecomposition correct = SerialParallelDecomposition{Node{1}}; + SeriesParallelDecomposition result = to_final_ast(input); + SeriesParallelDecomposition correct = SeriesParallelDecomposition{Node{1}}; CHECK(result == correct); } TEST_CASE("to_final_ast (serial)") { std::variant input = IntermediateSpDecompositionTree{ - SplitType::SERIAL, + SplitType::SERIES, {Node{1}, Node{2}}, }; - SerialParallelDecomposition result = to_final_ast(input); - SerialParallelDecomposition correct = SerialParallelDecomposition{ - SerialSplit{{ + SeriesParallelDecomposition result = to_final_ast(input); + SeriesParallelDecomposition correct = SeriesParallelDecomposition{ + SeriesSplit{{ Node{1}, Node{2}, }}, @@ -30,11 +30,11 @@ TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("to_final_ast (composite)") { std::variant input = IntermediateSpDecompositionTree{ - SplitType::SERIAL, + SplitType::SERIES, { Node{0}, IntermediateSpDecompositionTree{ - SplitType::SERIAL, + SplitType::SERIES, { Node{1}, IntermediateSpDecompositionTree{ @@ -55,9 +55,9 @@ TEST_SUITE(FF_TEST_SUITE) { Node{5}, }}; - SerialParallelDecomposition result = to_final_ast(input); - SerialParallelDecomposition correct = - SerialParallelDecomposition{SerialSplit{{ + SeriesParallelDecomposition result = to_final_ast(input); + SeriesParallelDecomposition correct = + SeriesParallelDecomposition{SeriesSplit{{ Node{0}, Node{1}, ParallelSplit{{ @@ -70,55 +70,55 @@ TEST_SUITE(FF_TEST_SUITE) { CHECK(result == correct); } - TEST_CASE("get_nodes(SerialParallelDecomposition)") { - SerialParallelDecomposition input = - SerialParallelDecomposition{SerialSplit{{ + TEST_CASE("get_nodes(SeriesParallelDecomposition)") { + SeriesParallelDecomposition input = + SeriesParallelDecomposition{SeriesSplit{{ ParallelSplit{{ Node{1}, Node{2}, }}, - Node{3}, + Node{2}, ParallelSplit{{ Node{4}, Node{5}, }}, }}}; - std::unordered_set result = get_nodes(input); - std::unordered_set correct = { + std::unordered_multiset result = get_nodes(input); + std::unordered_multiset correct = { Node{1}, Node{2}, - Node{3}, + Node{2}, Node{4}, Node{5}, }; CHECK(result == correct); } - TEST_CASE("get_nodes(SerialSplit)") { + TEST_CASE("get_nodes(SeriesSplit)") { ParallelSplit input = ParallelSplit{{ Node{1}, - SerialSplit{{ + SeriesSplit{{ Node{2}, ParallelSplit{{ Node{3}, Node{4}, }}, }}, - SerialSplit{{ - Node{5}, + SeriesSplit{{ + Node{1}, Node{6}, }}, Node{7}, }}; - std::unordered_set result = get_nodes(input); - std::unordered_set correct = { + std::unordered_multiset result = get_nodes(input); + std::unordered_multiset correct = { Node{1}, Node{2}, Node{3}, Node{4}, - Node{5}, + Node{1}, Node{6}, Node{7}, }; @@ -129,9 +129,9 @@ TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("get_nodes(ParallelSplit)") { ParallelSplit input = ParallelSplit{{ Node{1}, - SerialSplit{{ + SeriesSplit{{ Node{2}, - Node{3}, + Node{4}, ParallelSplit{{ Node{4}, Node{5}, @@ -139,11 +139,11 @@ TEST_SUITE(FF_TEST_SUITE) { }}, }}; - std::unordered_set result = get_nodes(input); - std::unordered_set correct = { + std::unordered_multiset result = get_nodes(input); + std::unordered_multiset correct = { Node{1}, Node{2}, - Node{3}, + Node{4}, Node{4}, Node{5}, }; @@ -153,8 +153,8 @@ TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("get_nodes(Node)") { Node input = Node{5}; - std::unordered_set result = get_nodes(input); - std::unordered_set correct = {input}; + std::unordered_multiset result = get_nodes(input); + std::unordered_multiset correct = {input}; CHECK(result == correct); } } diff --git a/lib/utils/test/src/utils/graph/serial_parallel/series_reduction.cc b/lib/utils/test/src/utils/graph/series_parallel/series_reduction.cc similarity index 99% rename from lib/utils/test/src/utils/graph/serial_parallel/series_reduction.cc rename to lib/utils/test/src/utils/graph/series_parallel/series_reduction.cc index e4d53b4136..c6b45ec6ce 100644 --- a/lib/utils/test/src/utils/graph/serial_parallel/series_reduction.cc +++ b/lib/utils/test/src/utils/graph/series_parallel/series_reduction.cc @@ -1,4 +1,4 @@ -#include "utils/graph/serial_parallel/series_reduction.h" +#include "utils/graph/series_parallel/series_reduction.h" #include "utils/containers/set_minus.h" #include "utils/graph/instances/adjacency_multidigraph.h" #include "utils/graph/multidigraph/algorithms/add_edges.h" diff --git a/lib/utils/test/src/utils/hash/multiset.cc b/lib/utils/test/src/utils/hash/multiset.cc new file mode 100644 index 0000000000..5c2e01fda8 --- /dev/null +++ b/lib/utils/test/src/utils/hash/multiset.cc @@ -0,0 +1,34 @@ +#include "utils/hash/multiset.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("std::hash>") { + std::multiset input = {1, 2, 2, 1, 5}; + size_t input_hash = get_std_hash(input); + + SUBCASE("same values have the same hash") { + std::multiset also_input = {2, 1, 2, 5, 1}; + size_t also_input_hash = get_std_hash(input); + + CHECK(input_hash == also_input_hash); + } + + SUBCASE("different values have different hashes") { + SUBCASE("different number of duplicates") { + std::multiset other = {1, 2, 2, 1, 5, 5}; + size_t other_hash = get_std_hash(other); + + CHECK(input_hash != other_hash); + } + + SUBCASE("different elements") { + std::multiset other = {1, 2, 2, 1, 6}; + size_t other_hash = get_std_hash(other); + + CHECK(input_hash != other_hash); + } + } + } +} diff --git a/lib/utils/test/src/utils/hash/unordered_multiset.cc b/lib/utils/test/src/utils/hash/unordered_multiset.cc new file mode 100644 index 0000000000..6c730fad3c --- /dev/null +++ b/lib/utils/test/src/utils/hash/unordered_multiset.cc @@ -0,0 +1,34 @@ +#include "utils/hash/unordered_multiset.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("std::hash>") { + std::unordered_multiset input = {1, 2, 2, 1, 5}; + size_t input_hash = get_std_hash(input); + + SUBCASE("same values have the same hash") { + std::unordered_multiset also_input = {2, 1, 2, 5, 1}; + size_t also_input_hash = get_std_hash(input); + + CHECK(input_hash == also_input_hash); + } + + SUBCASE("different values have different hashes") { + SUBCASE("different number of duplicates") { + std::unordered_multiset other = {1, 2, 2, 1, 5, 5}; + size_t other_hash = get_std_hash(other); + + CHECK(input_hash != other_hash); + } + + SUBCASE("different elements") { + std::unordered_multiset other = {1, 2, 2, 1, 6}; + size_t other_hash = get_std_hash(other); + + CHECK(input_hash != other_hash); + } + } + } +} diff --git a/lib/utils/test/src/utils/json/optional.cc b/lib/utils/test/src/utils/json/optional.cc new file mode 100644 index 0000000000..61f5868c53 --- /dev/null +++ b/lib/utils/test/src/utils/json/optional.cc @@ -0,0 +1,49 @@ +#include "utils/json/optional.h" +#include "test/utils/doctest/fmt/optional.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("adl_serializer>") { + SUBCASE("to_json") { + SUBCASE("has value") { + std::optional input = 5; + + nlohmann::json result = input; + nlohmann::json correct = 5; + + CHECK(result == correct); + } + + SUBCASE("has nullopt") { + std::optional input = std::nullopt; + + nlohmann::json result = input; + nlohmann::json correct = nullptr; + + CHECK(result == correct); + } + } + + SUBCASE("from_json") { + SUBCASE("has value") { + nlohmann::json input = 5; + + std::optional result = input; + std::optional correct = 5; + + CHECK(result == correct); + } + + SUBCASE("has nullopt") { + nlohmann::json input = nullptr; + + std::optional result = input.get>(); + std::optional correct = std::nullopt; + + CHECK(result == correct); + } + } + } +} diff --git a/lib/utils/test/src/utils/optional.cc b/lib/utils/test/src/utils/rapidcheck/optional.cc similarity index 67% rename from lib/utils/test/src/utils/optional.cc rename to lib/utils/test/src/utils/rapidcheck/optional.cc index 16c9e964cb..96b17a5400 100644 --- a/lib/utils/test/src/utils/optional.cc +++ b/lib/utils/test/src/utils/rapidcheck/optional.cc @@ -1,7 +1,8 @@ -#include "utils/optional.h" -#include "test/utils/doctest.h" +#include "utils/rapidcheck/optional.h" #include "test/utils/rapidcheck.h" -#include +#include + +using namespace ::FlexFlow; TEST_SUITE(FF_TEST_SUITE) { TEST_CASE_TEMPLATE(