From b521bd0416ed659e6ac12be00b35c94d80918abe Mon Sep 17 00:00:00 2001 From: lambda7xx Date: Sat, 19 Aug 2023 09:13:49 +0000 Subject: [PATCH 01/17] add parse command line --- lib/runtime/include/runtime/config.h | 1 + lib/utils/include/utils/parse.h | 230 +++++++++++++++++++++++++++ lib/utils/test/src/test_parse.cc | 30 ++++ 3 files changed, 261 insertions(+) create mode 100644 lib/utils/include/utils/parse.h create mode 100644 lib/utils/test/src/test_parse.cc diff --git a/lib/runtime/include/runtime/config.h b/lib/runtime/include/runtime/config.h index a7b8d86171..54fe9443f1 100644 --- a/lib/runtime/include/runtime/config.h +++ b/lib/runtime/include/runtime/config.h @@ -65,6 +65,7 @@ struct FFConfig : public use_visitable_cmp { FFConfig() = default; static Legion::MappingTagID get_hash_id(std::string const &pcname); + void parse_args(char **argv, int argc); public: int epochs = 1; diff --git a/lib/utils/include/utils/parse.h b/lib/utils/include/utils/parse.h new file mode 100644 index 0000000000..20e548a584 --- /dev/null +++ b/lib/utils/include/utils/parse.h @@ -0,0 +1,230 @@ +#ifndef _FLEXFLOW_UTILS_INCLUDE_UTILS_PARSE_H +#define _FLEXFLOW_UTILS_INCLUDE_UTILS_PARSE_H + +#include "runtime/config.h" +#include "utils/exception.h" +#include "utils/variant.h" +#include +#include +#include +namespace FlexFlow { + +using VariantType = variant; + +class ArgsParser { +private: + std::unordered_map mArgs; + std::unordered_map mDefaultValues; + std::unordered_map mDescriptions; + + std::string parseKey(std::string const &arg) const { + if (arg.substr(0, 2) == "--") { + return arg.substr(2); + } else { + return arg; + } + } + +public: + ArgsParser() = default; + void parse_args(int argc, char **argv) { + for (int i = 1; i < argc; i += 2) { + std::string key = parseKey(argv[i]); + if (key == "help" || key == "h") { + showDescriptions(); + exit(0); + } + mArgs[key] = argv[i + 1]; + } + } + + template + T get_from_variant(VariantType const &v) const; + + void add_argument(std::string const &key, + VariantType const &value, + std::string const &description) { + mDefaultValues[parseKey(key)] = std::move(value); + mDescriptions[key] = description; + } + + template + T get(std::string const &key) const { + auto it = mArgs.find(key); + if (it != mArgs.end()) { + return convert(it->second); + } else { + auto def_it = mDefaultValues.find(key); + if (def_it != mDefaultValues.end()) { + return get_from_variant(def_it->second); + } + } + throw mk_runtime_error("Key not found: " + key); + } + + void showDescriptions() const { + for (auto const &kv : mDescriptions) { + std::cout << kv.first << ": " << kv.second << std::endl; + } + } + + template + T convert(std::string const &s) const; + + friend std::ostream &operator<<(std::ostream &out, ArgsParser const &args); +}; + +template <> +int ArgsParser::convert(std::string const &s) const { + return std::stoi(s); +} + +template <> +float ArgsParser::convert(std::string const &s) const { + return std::stof(s); +} + +template <> +bool ArgsParser::convert(std::string const &s) const { + return s == "true" || s == "1"; +} + +template <> +std::string ArgsParser::convert(std::string const &s) const { + return s; +} + +template <> +int ArgsParser::get_from_variant(VariantType const &v) const { + return mpark::get(v); +} + +template <> +float ArgsParser::get_from_variant(VariantType const &v) const { + return mpark::get(v); +} + +template <> +bool ArgsParser::get_from_variant(VariantType const &v) const { + return mpark::get(v); +} + +template <> +std::string + ArgsParser::get_from_variant(VariantType const &v) const { + return mpark::get(v); +} + +std::ostream &operator<<(std::ostream &out, ArgsParser const &args) { + args.showDescriptions(); + return out; +} + +void FFConfig::parse_args(char **argv, int argc) { + ArgsParser args; + args.add_argument("--epochs", 1, "Number of epochs."); + args.add_argument("--batch-size", 32, "Size of each batch during training"); + args.add_argument( + "--learning-rate", 0.01f, "Learning rate for the optimizer"); + args.add_argument( + "--weight-decay", 0.0001f, "Weight decay for the optimizer"); + args.add_argument("--dataset-path", "", "Path to the dataset"); + args.add_argument("--search-budget", 0, "Search budget"); + args.add_argument("--search-alpha", 0.0f, "Search alpha"); + args.add_argument( + "--simulator-workspace-size", 0, "Simulator workspace size"); + args.add_argument("--only-data-parallel", false, "Only use data parallelism"); + args.add_argument( + "--enable-parameter-parallel", false, "Enable parameter parallelism"); + args.add_argument("--nodes", 1, "Number of nodes"); + args.add_argument("--profiling", false, "Enable profiling"); + args.add_argument("--allow-tensor-op-math-conversion", + false, + "Allow tensor op math conversion"); + args.add_argument("--fusion", false, "Enable fusion"); + args.add_argument("--overlap", false, "Enable overlap"); + args.add_argument( + "--taskgraph", "", "Export strategy computation graph file"); + args.add_argument( + "--include-costs-dot-graph", false, "Include costs dot graph"); + args.add_argument("--machine-model-version", 0, "Machine model version"); + args.add_argument("--machine-model-file", "", "Machine model file"); + args.add_argument("--simulator-segment-size", 0, "Simulator segment size"); + args.add_argument( + "--simulator-max-num-segments", 0, "Simulator max number of segments"); + args.add_argument( + "--enable-inplace-optimizations", false, "Enable inplace optimizations"); + args.add_argument("--search-num-nodes", 0, "Search number of nodes"); + args.add_argument("--search-num-workers", 0, "Search number of workers"); + args.add_argument("--base-optimize-threshold", 0, "Base optimize threshold"); + args.add_argument( + "--enable-control-replication", false, "Enable control replication"); + args.add_argument("--python-data-loader-type", 0, "Python data loader type"); + args.add_argument("--substitution-json", "", "Substitution json path"); + + // legion arguments + args.add_argument("-level", 5, "level of logging output"); + args.add_argument("-logfile", "", "name of log file"); + args.add_argument("-ll:cpu", 1, "CPUs per node"); + args.add_argument("-ll:gpu", 0, "GPUs per node"); + args.add_argument("-ll:util", 1, "utility processors to create per process"); + args.add_argument( + "-ll:csize", 1024, "size of CPU DRAM memory per process(in MB)"); + args.add_argument("-ll:gsize", 0, "size of GPU DRAM memory per process"); + args.add_argument( + "-ll:rsize", + 0, + "size of GASNet registered RDMA memory available per process (in MB)"); + args.add_argument( + "-ll:fsize", 1, "size of framebuffer memory for each GPU (in MB)"); + args.add_argument( + "-ll:zsize", 0, "size of zero-copy memory for each GPU (in MB)"); + args.add_argument( + "-lg:window", + 8192, + "maximum number of tasks that can be created in a parent task window"); + args.add_argument("-lg:sched", + 1024, + " minimum number of tasks to try to schedule for each " + "invocation of the scheduler"); + + args.parse_args(argc, argv); + + batch_size = args.get("batch-size"); + epochs = args.get("epochs"); + learning_rate = args.get("learning-rate"); + weight_decay = args.get("weight-decay"); + dataset_path = args.get("dataset-path"); + search_budget = args.get("search-budget"); + search_alpha = args.get("search-alpha"); + simulator_work_space_size = args.get("simulator-workspace-size"); + only_data_parallel = args.get("only-data-parallel"); + enable_parameter_parallel = args.get("enable-parameter-parallel"); + numNodes = args.get("nodes"); + profiling = args.get("profiling"); + allow_tensor_op_math_conversion = + args.get("allow-tensor-op-math-conversion"); + perform_fusion = args.get("fusion"); + search_overlap_backward_update = args.get("overlap"); + export_strategy_computation_graph_file = args.get("--taskgraph"); + include_costs_dot_graph = args.get("include-costs-dot-graph"); + machine_model_version = args.get("machine-model-version"); + machine_model_file = args.get("machine-model-file"); + simulator_segment_size = args.get("simulator-segment-size"); + simulator_max_num_segments = args.get("simulator-max-num-segments"); + enable_inplace_optimizations = args.get("enable-inplace-optimizations"); + search_num_nodes = args.get("search-num-nodes"); + search_num_workers = args.get("search-num-workers"); + base_optimize_threshold = args.get("base-optimize-threshold"); + enable_control_replication = args.get("enable-control-replication"); + python_data_loader_type = args.get("python-data-loader-type"); + substitution_json_path = args.get("substitution-json"); + + // legion arguments + cpusPerNode = args.get("-ll:cpu"); + workersPerNode = args.get("-ll:gpu"); +} + +} // namespace FlexFlow + +#endif \ No newline at end of file diff --git a/lib/utils/test/src/test_parse.cc b/lib/utils/test/src/test_parse.cc new file mode 100644 index 0000000000..4cf6c34162 --- /dev/null +++ b/lib/utils/test/src/test_parse.cc @@ -0,0 +1,30 @@ +#include "doctest.h" +#include "utils/parse.h" + +using namespace FlexFlow; + +TEST_CASE("Test ArgsParser basic functionality") { + char const *test_argv[] = {"program_name", + "--batch-size", + "100", + "--learning-rate", + "0.5", + "--fusion", + "true", + "-ll:gpus", + "6"}; + ArgsParser args; + args.add_argument("--batch-size", 32, "Size of each batch during training"); + args.add_argument( + "--learning-rate", 0.01f, "Learning rate for the optimizer"); + args.add_argument("--fusion", + false, + "Flag to determine if fusion optimization should be used"); + args.add_argument("-ll:gpus", 2, "Number of GPUs to be used for training"); + args.parse_args(9, const_cast(test_argv)); + + CHECK(args.get("batch-size") == 100); + CHECK(args.get("learning-rate") == 0.5f); + CHECK(args.get("fusion") == true); + CHECK(args.get("-ll:gpus") == 6); +} \ No newline at end of file From 5e5bc30344cead8a70664710df881c090b8b14e8 Mon Sep 17 00:00:00 2001 From: lambda7xx Date: Sat, 19 Aug 2023 09:18:57 +0000 Subject: [PATCH 02/17] fix the fomat --- lib/utils/include/utils/parse.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/utils/include/utils/parse.h b/lib/utils/include/utils/parse.h index 20e548a584..7b201d22ee 100644 --- a/lib/utils/include/utils/parse.h +++ b/lib/utils/include/utils/parse.h @@ -227,4 +227,4 @@ void FFConfig::parse_args(char **argv, int argc) { } // namespace FlexFlow -#endif \ No newline at end of file +#endif From ca7f9a197f25bd105b5aa9184cf640c88d4bc9de Mon Sep 17 00:00:00 2001 From: lambda7xx Date: Sat, 19 Aug 2023 09:21:28 +0000 Subject: [PATCH 03/17] fix the fomat --- lib/utils/test/src/test_parse.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/utils/test/src/test_parse.cc b/lib/utils/test/src/test_parse.cc index 4cf6c34162..b3ee0d6dd2 100644 --- a/lib/utils/test/src/test_parse.cc +++ b/lib/utils/test/src/test_parse.cc @@ -27,4 +27,4 @@ TEST_CASE("Test ArgsParser basic functionality") { CHECK(args.get("learning-rate") == 0.5f); CHECK(args.get("fusion") == true); CHECK(args.get("-ll:gpus") == 6); -} \ No newline at end of file +} From 901db09d3682390baad5af7b015faa679daeedd5 Mon Sep 17 00:00:00 2001 From: lambda7xx Date: Mon, 4 Sep 2023 03:19:58 +0000 Subject: [PATCH 04/17] modify the parse --- lib/runtime/src/config.cc | 111 +++++++++++++++ lib/utils/include/utils/parse.h | 229 +++++-------------------------- lib/utils/test/src/test_parse.cc | 23 ++-- 3 files changed, 158 insertions(+), 205 deletions(-) create mode 100644 lib/runtime/src/config.cc diff --git a/lib/runtime/src/config.cc b/lib/runtime/src/config.cc new file mode 100644 index 0000000000..b077d3d082 --- /dev/null +++ b/lib/runtime/src/config.cc @@ -0,0 +1,111 @@ +#include "config.h" +#include "utils/parser.h" + +namespace FlexFlow { + +// void FFConfig::parse_args(char **argv, int argc) { +// ArgsParser args; +// args.add_argument("--epochs", 1, "Number of epochs."); +// args.add_argument("--batch-size", 32, "Size of each batch during training"); +// args.add_argument( +// "--learning-rate", 0.01f, "Learning rate for the optimizer"); +// args.add_argument( +// "--weight-decay", 0.0001f, "Weight decay for the optimizer"); +// args.add_argument("--dataset-path", "", "Path to the dataset"); +// args.add_argument("--search-budget", 0, "Search budget"); +// args.add_argument("--search-alpha", 0.0f, "Search alpha"); +// args.add_argument( +// "--simulator-workspace-size", 0, "Simulator workspace size"); +// args.add_argument("--only-data-parallel", false, "Only use data parallelism"); +// args.add_argument( +// "--enable-parameter-parallel", false, "Enable parameter parallelism"); +// args.add_argument("--nodes", 1, "Number of nodes"); +// args.add_argument("--profiling", false, "Enable profiling"); +// args.add_argument("--allow-tensor-op-math-conversion", +// false, +// "Allow tensor op math conversion"); +// args.add_argument("--fusion", false, "Enable fusion"); +// args.add_argument("--overlap", false, "Enable overlap"); +// args.add_argument( +// "--taskgraph", "", "Export strategy computation graph file"); +// args.add_argument( +// "--include-costs-dot-graph", false, "Include costs dot graph"); +// args.add_argument("--machine-model-version", 0, "Machine model version"); +// args.add_argument("--machine-model-file", "", "Machine model file"); +// args.add_argument("--simulator-segment-size", 0, "Simulator segment size"); +// args.add_argument( +// "--simulator-max-num-segments", 0, "Simulator max number of segments"); +// args.add_argument( +// "--enable-inplace-optimizations", false, "Enable inplace optimizations"); +// args.add_argument("--search-num-nodes", 0, "Search number of nodes"); +// args.add_argument("--search-num-workers", 0, "Search number of workers"); +// args.add_argument("--base-optimize-threshold", 0, "Base optimize threshold"); +// args.add_argument( +// "--enable-control-replication", false, "Enable control replication"); +// args.add_argument("--python-data-loader-type", 0, "Python data loader type"); +// args.add_argument("--substitution-json", "", "Substitution json path"); + +// // legion arguments +// args.add_argument("-level", 5, "level of logging output"); +// args.add_argument("-logfile", "", "name of log file"); +// args.add_argument("-ll:cpu", 1, "CPUs per node"); +// args.add_argument("-ll:gpu", 0, "GPUs per node"); +// args.add_argument("-ll:util", 1, "utility processors to create per process"); +// args.add_argument( +// "-ll:csize", 1024, "size of CPU DRAM memory per process(in MB)"); +// args.add_argument("-ll:gsize", 0, "size of GPU DRAM memory per process"); +// args.add_argument( +// "-ll:rsize", +// 0, +// "size of GASNet registered RDMA memory available per process (in MB)"); +// args.add_argument( +// "-ll:fsize", 1, "size of framebuffer memory for each GPU (in MB)"); +// args.add_argument( +// "-ll:zsize", 0, "size of zero-copy memory for each GPU (in MB)"); +// args.add_argument( +// "-lg:window", +// 8192, +// "maximum number of tasks that can be created in a parent task window"); +// args.add_argument("-lg:sched", +// 1024, +// " minimum number of tasks to try to schedule for each " +// "invocation of the scheduler"); + +// args.parse_args(argc, argv); + +// batch_size = args.get("batch-size"); +// epochs = args.get("epochs"); +// learning_rate = args.get("learning-rate"); +// weight_decay = args.get("weight-decay"); +// dataset_path = args.get("dataset-path"); +// search_budget = args.get("search-budget"); +// search_alpha = args.get("search-alpha"); +// simulator_work_space_size = args.get("simulator-workspace-size"); +// only_data_parallel = args.get("only-data-parallel"); +// enable_parameter_parallel = args.get("enable-parameter-parallel"); +// numNodes = args.get("nodes"); +// profiling = args.get("profiling"); +// allow_tensor_op_math_conversion = +// args.get("allow-tensor-op-math-conversion"); +// perform_fusion = args.get("fusion"); +// search_overlap_backward_update = args.get("overlap"); +// export_strategy_computation_graph_file = args.get("--taskgraph"); +// include_costs_dot_graph = args.get("include-costs-dot-graph"); +// machine_model_version = args.get("machine-model-version"); +// machine_model_file = args.get("machine-model-file"); +// simulator_segment_size = args.get("simulator-segment-size"); +// simulator_max_num_segments = args.get("simulator-max-num-segments"); +// enable_inplace_optimizations = args.get("enable-inplace-optimizations"); +// search_num_nodes = args.get("search-num-nodes"); +// search_num_workers = args.get("search-num-workers"); +// base_optimize_threshold = args.get("base-optimize-threshold"); +// enable_control_replication = args.get("enable-control-replication"); +// python_data_loader_type = args.get("python-data-loader-type"); +// substitution_json_path = args.get("substitution-json"); + +// // legion arguments +// cpusPerNode = args.get("-ll:cpu"); +// workersPerNode = args.get("-ll:gpu"); +// } + +} // namespace FlexFlow \ No newline at end of file diff --git a/lib/utils/include/utils/parse.h b/lib/utils/include/utils/parse.h index 7b201d22ee..d089f07967 100644 --- a/lib/utils/include/utils/parse.h +++ b/lib/utils/include/utils/parse.h @@ -1,7 +1,6 @@ #ifndef _FLEXFLOW_UTILS_INCLUDE_UTILS_PARSE_H #define _FLEXFLOW_UTILS_INCLUDE_UTILS_PARSE_H -#include "runtime/config.h" #include "utils/exception.h" #include "utils/variant.h" #include @@ -9,64 +8,55 @@ #include namespace FlexFlow { -using VariantType = variant; +using AllowedArgTypes = variant; + +std::string parseKey(std::string const &arg) const { + if (arg.substr(0, 2) == "--") { + return arg.substr(2); + } else { + return arg; + } +} class ArgsParser { private: std::unordered_map mArgs; - std::unordered_map mDefaultValues; + std::unordered_map mDefaultValues; std::unordered_map mDescriptions; - std::string parseKey(std::string const &arg) const { - if (arg.substr(0, 2) == "--") { - return arg.substr(2); - } else { - return arg; - } - } - public: ArgsParser() = default; - void parse_args(int argc, char **argv) { - for (int i = 1; i < argc; i += 2) { - std::string key = parseKey(argv[i]); - if (key == "help" || key == "h") { - showDescriptions(); - exit(0); - } - mArgs[key] = argv[i + 1]; + void parse_args(int argc, char **argv); + + template + class ArgumentReference { + public: + ArgumentReference(AllowedArgTypes const &defaultValue, + std::string const &description) + : defaultValue(defaultValue), description(description), key(key) {} + + AllowedArgTypes const &default_value() const { + return default_value; } - } + + private: + AllowedArgTypes defaultValue; + std::string description; + std::string key; + }; template - T get_from_variant(VariantType const &v) const; + T get_from_variant(AllowedArgTypes const &v) const; - void add_argument(std::string const &key, - VariantType const &value, - std::string const &description) { - mDefaultValues[parseKey(key)] = std::move(value); - mDescriptions[key] = description; - } + template + ArgumentReference add_argument(std::string const &key, + AllowedArgTypes const &value, + std::string const &description); template - T get(std::string const &key) const { - auto it = mArgs.find(key); - if (it != mArgs.end()) { - return convert(it->second); - } else { - auto def_it = mDefaultValues.find(key); - if (def_it != mDefaultValues.end()) { - return get_from_variant(def_it->second); - } - } - throw mk_runtime_error("Key not found: " + key); - } + T get(ArgumentReference const &arg_ref) const; - void showDescriptions() const { - for (auto const &kv : mDescriptions) { - std::cout << kv.first << ": " << kv.second << std::endl; - } - } + void showDescriptions() const; template T convert(std::string const &s) const; @@ -74,157 +64,6 @@ class ArgsParser { friend std::ostream &operator<<(std::ostream &out, ArgsParser const &args); }; -template <> -int ArgsParser::convert(std::string const &s) const { - return std::stoi(s); -} - -template <> -float ArgsParser::convert(std::string const &s) const { - return std::stof(s); -} - -template <> -bool ArgsParser::convert(std::string const &s) const { - return s == "true" || s == "1"; -} - -template <> -std::string ArgsParser::convert(std::string const &s) const { - return s; -} - -template <> -int ArgsParser::get_from_variant(VariantType const &v) const { - return mpark::get(v); -} - -template <> -float ArgsParser::get_from_variant(VariantType const &v) const { - return mpark::get(v); -} - -template <> -bool ArgsParser::get_from_variant(VariantType const &v) const { - return mpark::get(v); -} - -template <> -std::string - ArgsParser::get_from_variant(VariantType const &v) const { - return mpark::get(v); -} - -std::ostream &operator<<(std::ostream &out, ArgsParser const &args) { - args.showDescriptions(); - return out; -} - -void FFConfig::parse_args(char **argv, int argc) { - ArgsParser args; - args.add_argument("--epochs", 1, "Number of epochs."); - args.add_argument("--batch-size", 32, "Size of each batch during training"); - args.add_argument( - "--learning-rate", 0.01f, "Learning rate for the optimizer"); - args.add_argument( - "--weight-decay", 0.0001f, "Weight decay for the optimizer"); - args.add_argument("--dataset-path", "", "Path to the dataset"); - args.add_argument("--search-budget", 0, "Search budget"); - args.add_argument("--search-alpha", 0.0f, "Search alpha"); - args.add_argument( - "--simulator-workspace-size", 0, "Simulator workspace size"); - args.add_argument("--only-data-parallel", false, "Only use data parallelism"); - args.add_argument( - "--enable-parameter-parallel", false, "Enable parameter parallelism"); - args.add_argument("--nodes", 1, "Number of nodes"); - args.add_argument("--profiling", false, "Enable profiling"); - args.add_argument("--allow-tensor-op-math-conversion", - false, - "Allow tensor op math conversion"); - args.add_argument("--fusion", false, "Enable fusion"); - args.add_argument("--overlap", false, "Enable overlap"); - args.add_argument( - "--taskgraph", "", "Export strategy computation graph file"); - args.add_argument( - "--include-costs-dot-graph", false, "Include costs dot graph"); - args.add_argument("--machine-model-version", 0, "Machine model version"); - args.add_argument("--machine-model-file", "", "Machine model file"); - args.add_argument("--simulator-segment-size", 0, "Simulator segment size"); - args.add_argument( - "--simulator-max-num-segments", 0, "Simulator max number of segments"); - args.add_argument( - "--enable-inplace-optimizations", false, "Enable inplace optimizations"); - args.add_argument("--search-num-nodes", 0, "Search number of nodes"); - args.add_argument("--search-num-workers", 0, "Search number of workers"); - args.add_argument("--base-optimize-threshold", 0, "Base optimize threshold"); - args.add_argument( - "--enable-control-replication", false, "Enable control replication"); - args.add_argument("--python-data-loader-type", 0, "Python data loader type"); - args.add_argument("--substitution-json", "", "Substitution json path"); - - // legion arguments - args.add_argument("-level", 5, "level of logging output"); - args.add_argument("-logfile", "", "name of log file"); - args.add_argument("-ll:cpu", 1, "CPUs per node"); - args.add_argument("-ll:gpu", 0, "GPUs per node"); - args.add_argument("-ll:util", 1, "utility processors to create per process"); - args.add_argument( - "-ll:csize", 1024, "size of CPU DRAM memory per process(in MB)"); - args.add_argument("-ll:gsize", 0, "size of GPU DRAM memory per process"); - args.add_argument( - "-ll:rsize", - 0, - "size of GASNet registered RDMA memory available per process (in MB)"); - args.add_argument( - "-ll:fsize", 1, "size of framebuffer memory for each GPU (in MB)"); - args.add_argument( - "-ll:zsize", 0, "size of zero-copy memory for each GPU (in MB)"); - args.add_argument( - "-lg:window", - 8192, - "maximum number of tasks that can be created in a parent task window"); - args.add_argument("-lg:sched", - 1024, - " minimum number of tasks to try to schedule for each " - "invocation of the scheduler"); - - args.parse_args(argc, argv); - - batch_size = args.get("batch-size"); - epochs = args.get("epochs"); - learning_rate = args.get("learning-rate"); - weight_decay = args.get("weight-decay"); - dataset_path = args.get("dataset-path"); - search_budget = args.get("search-budget"); - search_alpha = args.get("search-alpha"); - simulator_work_space_size = args.get("simulator-workspace-size"); - only_data_parallel = args.get("only-data-parallel"); - enable_parameter_parallel = args.get("enable-parameter-parallel"); - numNodes = args.get("nodes"); - profiling = args.get("profiling"); - allow_tensor_op_math_conversion = - args.get("allow-tensor-op-math-conversion"); - perform_fusion = args.get("fusion"); - search_overlap_backward_update = args.get("overlap"); - export_strategy_computation_graph_file = args.get("--taskgraph"); - include_costs_dot_graph = args.get("include-costs-dot-graph"); - machine_model_version = args.get("machine-model-version"); - machine_model_file = args.get("machine-model-file"); - simulator_segment_size = args.get("simulator-segment-size"); - simulator_max_num_segments = args.get("simulator-max-num-segments"); - enable_inplace_optimizations = args.get("enable-inplace-optimizations"); - search_num_nodes = args.get("search-num-nodes"); - search_num_workers = args.get("search-num-workers"); - base_optimize_threshold = args.get("base-optimize-threshold"); - enable_control_replication = args.get("enable-control-replication"); - python_data_loader_type = args.get("python-data-loader-type"); - substitution_json_path = args.get("substitution-json"); - - // legion arguments - cpusPerNode = args.get("-ll:cpu"); - workersPerNode = args.get("-ll:gpu"); -} - } // namespace FlexFlow #endif diff --git a/lib/utils/test/src/test_parse.cc b/lib/utils/test/src/test_parse.cc index b3ee0d6dd2..b44bdfc74a 100644 --- a/lib/utils/test/src/test_parse.cc +++ b/lib/utils/test/src/test_parse.cc @@ -14,17 +14,20 @@ TEST_CASE("Test ArgsParser basic functionality") { "-ll:gpus", "6"}; ArgsParser args; - args.add_argument("--batch-size", 32, "Size of each batch during training"); - args.add_argument( + auto batch_size_ref = args.add_argument( + "--batch-size", 32, "Size of each batch during training"); + auto learning_rate_ref = args.add_argument( "--learning-rate", 0.01f, "Learning rate for the optimizer"); - args.add_argument("--fusion", - false, - "Flag to determine if fusion optimization should be used"); - args.add_argument("-ll:gpus", 2, "Number of GPUs to be used for training"); + auto fusion_ref = args.add_argument( + "--fusion", + false, + "Flag to determine if fusion optimization should be used"); + auto ll_gpus_ref = args.add_argument( + "-ll:gpus", 2, "Number of GPUs to be used for training"); args.parse_args(9, const_cast(test_argv)); - CHECK(args.get("batch-size") == 100); - CHECK(args.get("learning-rate") == 0.5f); - CHECK(args.get("fusion") == true); - CHECK(args.get("-ll:gpus") == 6); + CHECK(args.get(batch_size_ref) == 100); + CHECK(args.get(learning_rate_ref) == 0.5f); + CHECK(args.get(fusion_ref) == true); + CHECK(args.get(ll_gpus_ref) == 6); } From 8f04a1a91704eed7f65b2b8e76ab935918a63d1d Mon Sep 17 00:00:00 2001 From: lambda7xx Date: Tue, 5 Sep 2023 10:14:47 +0000 Subject: [PATCH 05/17] add config.cc, parse.cc and throw test --- lib/runtime/src/config.cc | 206 +++++++++++++++---------------- lib/utils/include/utils/parse.h | 64 +++++----- lib/utils/src/parse.cc | 49 ++++++++ lib/utils/test/src/test_parse.cc | 25 ++-- 4 files changed, 198 insertions(+), 146 deletions(-) create mode 100644 lib/utils/src/parse.cc diff --git a/lib/runtime/src/config.cc b/lib/runtime/src/config.cc index b077d3d082..1827fc0145 100644 --- a/lib/runtime/src/config.cc +++ b/lib/runtime/src/config.cc @@ -1,111 +1,111 @@ -#include "config.h" -#include "utils/parser.h" +#include "runtime/config.h" +#include "utils/parse.h" namespace FlexFlow { -// void FFConfig::parse_args(char **argv, int argc) { -// ArgsParser args; -// args.add_argument("--epochs", 1, "Number of epochs."); -// args.add_argument("--batch-size", 32, "Size of each batch during training"); -// args.add_argument( -// "--learning-rate", 0.01f, "Learning rate for the optimizer"); -// args.add_argument( -// "--weight-decay", 0.0001f, "Weight decay for the optimizer"); -// args.add_argument("--dataset-path", "", "Path to the dataset"); -// args.add_argument("--search-budget", 0, "Search budget"); -// args.add_argument("--search-alpha", 0.0f, "Search alpha"); -// args.add_argument( -// "--simulator-workspace-size", 0, "Simulator workspace size"); -// args.add_argument("--only-data-parallel", false, "Only use data parallelism"); -// args.add_argument( -// "--enable-parameter-parallel", false, "Enable parameter parallelism"); -// args.add_argument("--nodes", 1, "Number of nodes"); -// args.add_argument("--profiling", false, "Enable profiling"); -// args.add_argument("--allow-tensor-op-math-conversion", -// false, -// "Allow tensor op math conversion"); -// args.add_argument("--fusion", false, "Enable fusion"); -// args.add_argument("--overlap", false, "Enable overlap"); -// args.add_argument( -// "--taskgraph", "", "Export strategy computation graph file"); -// args.add_argument( -// "--include-costs-dot-graph", false, "Include costs dot graph"); -// args.add_argument("--machine-model-version", 0, "Machine model version"); -// args.add_argument("--machine-model-file", "", "Machine model file"); -// args.add_argument("--simulator-segment-size", 0, "Simulator segment size"); -// args.add_argument( -// "--simulator-max-num-segments", 0, "Simulator max number of segments"); -// args.add_argument( -// "--enable-inplace-optimizations", false, "Enable inplace optimizations"); -// args.add_argument("--search-num-nodes", 0, "Search number of nodes"); -// args.add_argument("--search-num-workers", 0, "Search number of workers"); -// args.add_argument("--base-optimize-threshold", 0, "Base optimize threshold"); -// args.add_argument( -// "--enable-control-replication", false, "Enable control replication"); -// args.add_argument("--python-data-loader-type", 0, "Python data loader type"); -// args.add_argument("--substitution-json", "", "Substitution json path"); +void FFConfig::parse_args(char **argv, int argc) { + ArgsParser args; + args.add_argument("--epochs", 1, "Number of epochs."); + args.add_argument("--batch-size", 32, "Size of each batch during training"); + args.add_argument( + "--learning-rate", 0.01f, "Learning rate for the optimizer"); + args.add_argument( + "--weight-decay", 0.0001f, "Weight decay for the optimizer"); + args.add_argument("--dataset-path", "", "Path to the dataset"); + args.add_argument("--search-budget", 0, "Search budget"); + args.add_argument("--search-alpha", 0.0f, "Search alpha"); + args.add_argument( + "--simulator-workspace-size", 0, "Simulator workspace size"); + args.add_argument("--only-data-parallel", false, "Only use data parallelism"); + args.add_argument( + "--enable-parameter-parallel", false, "Enable parameter parallelism"); + args.add_argument("--nodes", 1, "Number of nodes"); + args.add_argument("--profiling", false, "Enable profiling"); + args.add_argument("--allow-tensor-op-math-conversion", + false, + "Allow tensor op math conversion"); + args.add_argument("--fusion", false, "Enable fusion"); + args.add_argument("--overlap", false, "Enable overlap"); + args.add_argument( + "--taskgraph", "", "Export strategy computation graph file"); + args.add_argument( + "--include-costs-dot-graph", false, "Include costs dot graph"); + args.add_argument("--machine-model-version", 0, "Machine model version"); + args.add_argument("--machine-model-file", "", "Machine model file"); + args.add_argument("--simulator-segment-size", 0, "Simulator segment size"); + args.add_argument( + "--simulator-max-num-segments", 0, "Simulator max number of segments"); + args.add_argument( + "--enable-inplace-optimizations", false, "Enable inplace optimizations"); + args.add_argument("--search-num-nodes", 0, "Search number of nodes"); + args.add_argument("--search-num-workers", 0, "Search number of workers"); + args.add_argument("--base-optimize-threshold", 0, "Base optimize threshold"); + args.add_argument( + "--enable-control-replication", false, "Enable control replication"); + args.add_argument("--python-data-loader-type", 0, "Python data loader type"); + args.add_argument("--substitution-json", "", "Substitution json path"); -// // legion arguments -// args.add_argument("-level", 5, "level of logging output"); -// args.add_argument("-logfile", "", "name of log file"); -// args.add_argument("-ll:cpu", 1, "CPUs per node"); -// args.add_argument("-ll:gpu", 0, "GPUs per node"); -// args.add_argument("-ll:util", 1, "utility processors to create per process"); -// args.add_argument( -// "-ll:csize", 1024, "size of CPU DRAM memory per process(in MB)"); -// args.add_argument("-ll:gsize", 0, "size of GPU DRAM memory per process"); -// args.add_argument( -// "-ll:rsize", -// 0, -// "size of GASNet registered RDMA memory available per process (in MB)"); -// args.add_argument( -// "-ll:fsize", 1, "size of framebuffer memory for each GPU (in MB)"); -// args.add_argument( -// "-ll:zsize", 0, "size of zero-copy memory for each GPU (in MB)"); -// args.add_argument( -// "-lg:window", -// 8192, -// "maximum number of tasks that can be created in a parent task window"); -// args.add_argument("-lg:sched", -// 1024, -// " minimum number of tasks to try to schedule for each " -// "invocation of the scheduler"); + // legion arguments + args.add_argument("-level", 5, "level of logging output"); + args.add_argument("-logfile", "", "name of log file"); + args.add_argument("-ll:cpu", 1, "CPUs per node"); + args.add_argument("-ll:gpu", 0, "GPUs per node"); + args.add_argument("-ll:util", 1, "utility processors to create per process"); + args.add_argument( + "-ll:csize", 1024, "size of CPU DRAM memory per process(in MB)"); + args.add_argument("-ll:gsize", 0, "size of GPU DRAM memory per process"); + args.add_argument( + "-ll:rsize", + 0, + "size of GASNet registered RDMA memory available per process (in MB)"); + args.add_argument( + "-ll:fsize", 1, "size of framebuffer memory for each GPU (in MB)"); + args.add_argument( + "-ll:zsize", 0, "size of zero-copy memory for each GPU (in MB)"); + args.add_argument( + "-lg:window", + 8192, + "maximum number of tasks that can be created in a parent task window"); + args.add_argument("-lg:sched", + 1024, + " minimum number of tasks to try to schedule for each " + "invocation of the scheduler"); -// args.parse_args(argc, argv); + args.parse_args(argc, argv); -// batch_size = args.get("batch-size"); -// epochs = args.get("epochs"); -// learning_rate = args.get("learning-rate"); -// weight_decay = args.get("weight-decay"); -// dataset_path = args.get("dataset-path"); -// search_budget = args.get("search-budget"); -// search_alpha = args.get("search-alpha"); -// simulator_work_space_size = args.get("simulator-workspace-size"); -// only_data_parallel = args.get("only-data-parallel"); -// enable_parameter_parallel = args.get("enable-parameter-parallel"); -// numNodes = args.get("nodes"); -// profiling = args.get("profiling"); -// allow_tensor_op_math_conversion = -// args.get("allow-tensor-op-math-conversion"); -// perform_fusion = args.get("fusion"); -// search_overlap_backward_update = args.get("overlap"); -// export_strategy_computation_graph_file = args.get("--taskgraph"); -// include_costs_dot_graph = args.get("include-costs-dot-graph"); -// machine_model_version = args.get("machine-model-version"); -// machine_model_file = args.get("machine-model-file"); -// simulator_segment_size = args.get("simulator-segment-size"); -// simulator_max_num_segments = args.get("simulator-max-num-segments"); -// enable_inplace_optimizations = args.get("enable-inplace-optimizations"); -// search_num_nodes = args.get("search-num-nodes"); -// search_num_workers = args.get("search-num-workers"); -// base_optimize_threshold = args.get("base-optimize-threshold"); -// enable_control_replication = args.get("enable-control-replication"); -// python_data_loader_type = args.get("python-data-loader-type"); -// substitution_json_path = args.get("substitution-json"); + batch_size = args.get("batch-size"); + epochs = args.get("epochs"); + learning_rate = args.get("learning-rate"); + weight_decay = args.get("weight-decay"); + dataset_path = args.get("dataset-path"); + search_budget = args.get("search-budget"); + search_alpha = args.get("search-alpha"); + simulator_work_space_size = args.get("simulator-workspace-size"); + only_data_parallel = args.get("only-data-parallel"); + enable_parameter_parallel = args.get("enable-parameter-parallel"); + numNodes = args.get("nodes"); + profiling = args.get("profiling"); + allow_tensor_op_math_conversion = + args.get("allow-tensor-op-math-conversion"); + perform_fusion = args.get("fusion"); + search_overlap_backward_update = args.get("overlap"); + export_strategy_computation_graph_file = args.get("--taskgraph"); + include_costs_dot_graph = args.get("include-costs-dot-graph"); + machine_model_version = args.get("machine-model-version"); + machine_model_file = args.get("machine-model-file"); + simulator_segment_size = args.get("simulator-segment-size"); + simulator_max_num_segments = args.get("simulator-max-num-segments"); + enable_inplace_optimizations = args.get("enable-inplace-optimizations"); + search_num_nodes = args.get("search-num-nodes"); + search_num_workers = args.get("search-num-workers"); + base_optimize_threshold = args.get("base-optimize-threshold"); + enable_control_replication = args.get("enable-control-replication"); + python_data_loader_type = args.get("python-data-loader-type"); + substitution_json_path = args.get("substitution-json"); -// // legion arguments -// cpusPerNode = args.get("-ll:cpu"); -// workersPerNode = args.get("-ll:gpu"); -// } + // legion arguments + cpusPerNode = args.get("-ll:cpu"); + workersPerNode = args.get("-ll:gpu"); +} -} // namespace FlexFlow \ No newline at end of file +} // namespace FlexFlow diff --git a/lib/utils/include/utils/parse.h b/lib/utils/include/utils/parse.h index d089f07967..6b4c34decf 100644 --- a/lib/utils/include/utils/parse.h +++ b/lib/utils/include/utils/parse.h @@ -1,22 +1,18 @@ #ifndef _FLEXFLOW_UTILS_INCLUDE_UTILS_PARSE_H #define _FLEXFLOW_UTILS_INCLUDE_UTILS_PARSE_H +#include "utils/containers.h" #include "utils/exception.h" #include "utils/variant.h" #include #include #include + namespace FlexFlow { -using AllowedArgTypes = variant; +std::string parseKey(std::string arg); -std::string parseKey(std::string const &arg) const { - if (arg.substr(0, 2) == "--") { - return arg.substr(2); - } else { - return arg; - } -} +using AllowedArgTypes = variant; class ArgsParser { private: @@ -24,37 +20,45 @@ class ArgsParser { std::unordered_map mDefaultValues; std::unordered_map mDescriptions; + std::string parseKey(std::string const &arg) const { + if (arg.substr(0, 2) == "--") { + return arg.substr(2); + } else { + return arg; + } + } + public: ArgsParser() = default; - void parse_args(int argc, char **argv); - - template - class ArgumentReference { - public: - ArgumentReference(AllowedArgTypes const &defaultValue, - std::string const &description) - : defaultValue(defaultValue), description(description), key(key) {} - - AllowedArgTypes const &default_value() const { - return default_value; + void parse_args(int argc, char **argv) { + for (int i = 1; i < argc; i += 2) { + std::string key = parseKey(argv[i]); + if (key == "help" || key == "h") { + showDescriptions(); + exit(0); + } + mArgs[key] = argv[i + 1]; } - - private: - AllowedArgTypes defaultValue; - std::string description; - std::string key; - }; + } template T get_from_variant(AllowedArgTypes const &v) const; - template - ArgumentReference add_argument(std::string const &key, - AllowedArgTypes const &value, - std::string const &description); + void add_argument(std::string const &key, + AllowedArgTypes const &value, + std::string const &description); template - T get(ArgumentReference const &arg_ref) const; + T get(std::string const &key) const { + if (contains_key(mArgs, key)) { + return convert(mArgs.at(key)); + } else { + if (contains_key(mDefaultValues, key)) { + return mpark::get(mDefaultValues.at(key)); + } + } + throw mk_runtime_error("Key not found: " + key); + } void showDescriptions() const; diff --git a/lib/utils/src/parse.cc b/lib/utils/src/parse.cc new file mode 100644 index 0000000000..c3028168fe --- /dev/null +++ b/lib/utils/src/parse.cc @@ -0,0 +1,49 @@ +#include "utils/parse.h" +#include "utils/containers.h" +#include "utils/variant.h" + +namespace FlexFlow{ + +std::string parseKey(std::string arg) { + if (arg.substr(0, 2) == "--") { + return arg.substr(2); + } else { + return arg; + } + } + +void ArgsParser::add_argument(std::string const &key, + AllowedArgTypes const &value, + std::string const &description) { + mDefaultValues[parseKey(key)] = value; + mDescriptions[key] = description; + } + +template <> +int ArgsParser::convert(std::string const &s) const { + return std::stoi(s); +} + +template <> +float ArgsParser::convert(std::string const &s) const { + return std::stof(s); +} + +template <> +bool ArgsParser::convert(std::string const &s) const { + return s == "true" || s == "1" || s == "yes"; +} + +void ArgsParser::showDescriptions() const { + for (auto const &kv : mDescriptions) { + std::cerr << kv.first << ": " << kv.second << std::endl; + } + } + +//TODO(lambda):in the future, we will use fmt to format the output +std::ostream &operator<<(std::ostream &out, ArgsParser const &args) { + args.showDescriptions(); + return out; +} + +} // namespace FlexFlow diff --git a/lib/utils/test/src/test_parse.cc b/lib/utils/test/src/test_parse.cc index b44bdfc74a..caebb8c793 100644 --- a/lib/utils/test/src/test_parse.cc +++ b/lib/utils/test/src/test_parse.cc @@ -1,4 +1,5 @@ #include "doctest.h" +#include "utils/exception.h" #include "utils/parse.h" using namespace FlexFlow; @@ -14,20 +15,18 @@ TEST_CASE("Test ArgsParser basic functionality") { "-ll:gpus", "6"}; ArgsParser args; - auto batch_size_ref = args.add_argument( - "--batch-size", 32, "Size of each batch during training"); - auto learning_rate_ref = args.add_argument( + args.add_argument("--batch-size", 32, "Size of each batch during training"); + args.add_argument( "--learning-rate", 0.01f, "Learning rate for the optimizer"); - auto fusion_ref = args.add_argument( - "--fusion", - false, - "Flag to determine if fusion optimization should be used"); - auto ll_gpus_ref = args.add_argument( - "-ll:gpus", 2, "Number of GPUs to be used for training"); + args.add_argument("--fusion", + false, + "Flag to determine if fusion optimization should be used"); + args.add_argument("-ll:gpus", 2, "Number of GPUs to be used for training"); args.parse_args(9, const_cast(test_argv)); - CHECK(args.get(batch_size_ref) == 100); - CHECK(args.get(learning_rate_ref) == 0.5f); - CHECK(args.get(fusion_ref) == true); - CHECK(args.get(ll_gpus_ref) == 6); + CHECK(args.get("batch-size") == 100); + CHECK(args.get("learning-rate") == 0.5f); + CHECK(args.get("fusion") == true); + CHECK(args.get("-ll:gpus") == 6); + CHECK_THROWS(args.get("epochsss")); // throw exception } From 40517f7127ade15df70458e3f7705fe6e0bf237a Mon Sep 17 00:00:00 2001 From: lambda7xx Date: Tue, 26 Sep 2023 15:33:30 +0000 Subject: [PATCH 06/17] support don't pass type --- lib/utils/include/utils/parse.h | 32 ++++++++---------------- lib/utils/src/parse.cc | 43 +++++++++++++++++++++----------- lib/utils/test/src/test_parse.cc | 43 ++++++++++++++++++++++++-------- 3 files changed, 71 insertions(+), 47 deletions(-) diff --git a/lib/utils/include/utils/parse.h b/lib/utils/include/utils/parse.h index 6b4c34decf..b17aa0dd60 100644 --- a/lib/utils/include/utils/parse.h +++ b/lib/utils/include/utils/parse.h @@ -14,39 +14,29 @@ std::string parseKey(std::string arg); using AllowedArgTypes = variant; +template +struct ArgRef { + std::string key; + T value; +}; + class ArgsParser { private: std::unordered_map mArgs; std::unordered_map mDefaultValues; std::unordered_map mDescriptions; - std::string parseKey(std::string const &arg) const { - if (arg.substr(0, 2) == "--") { - return arg.substr(2); - } else { - return arg; - } - } - public: ArgsParser() = default; - void parse_args(int argc, char **argv) { - for (int i = 1; i < argc; i += 2) { - std::string key = parseKey(argv[i]); - if (key == "help" || key == "h") { - showDescriptions(); - exit(0); - } - mArgs[key] = argv[i + 1]; - } - } + void parse_args(int argc, char **argv); template T get_from_variant(AllowedArgTypes const &v) const; - void add_argument(std::string const &key, - AllowedArgTypes const &value, - std::string const &description); + template + ArgRef add_argument(std::string const &key, + T const &value, + std::string const &description); template T get(std::string const &key) const { diff --git a/lib/utils/src/parse.cc b/lib/utils/src/parse.cc index c3028168fe..67c22b8a72 100644 --- a/lib/utils/src/parse.cc +++ b/lib/utils/src/parse.cc @@ -2,22 +2,35 @@ #include "utils/containers.h" #include "utils/variant.h" -namespace FlexFlow{ +namespace FlexFlow { std::string parseKey(std::string arg) { - if (arg.substr(0, 2) == "--") { - return arg.substr(2); - } else { - return arg; - } + if (arg.substr(0, 2) == "--") { + return arg.substr(2); + } else { + return arg; } +} -void ArgsParser::add_argument(std::string const &key, - AllowedArgTypes const &value, - std::string const &description) { - mDefaultValues[parseKey(key)] = value; - mDescriptions[key] = description; +void ArgsParser::parse_args(int argc, char **argv) { + for (int i = 1; i < argc; i += 2) { + std::string key = parseKey(argv[i]); + if (key == "help" || key == "h") { + showDescriptions(); + exit(0); + } + mArgs[key] = argv[i + 1]; } +} + +template +ArgRef ArgsParser::add_argument(std::string const &key, + T const &value, + std::string const &description) { + mDefaultValues[parseKey(key)] = value; + mDescriptions[key] = description; + return ArgRef{key, value}; +} template <> int ArgsParser::convert(std::string const &s) const { @@ -35,12 +48,12 @@ bool ArgsParser::convert(std::string const &s) const { } void ArgsParser::showDescriptions() const { - for (auto const &kv : mDescriptions) { - std::cerr << kv.first << ": " << kv.second << std::endl; - } + for (auto const &kv : mDescriptions) { + std::cerr << *this << std::endl; } +} -//TODO(lambda):in the future, we will use fmt to format the output +// TODO(lambda):in the future, we will use fmt to format the output std::ostream &operator<<(std::ostream &out, ArgsParser const &args) { args.showDescriptions(); return out; diff --git a/lib/utils/test/src/test_parse.cc b/lib/utils/test/src/test_parse.cc index caebb8c793..5e23f4ebec 100644 --- a/lib/utils/test/src/test_parse.cc +++ b/lib/utils/test/src/test_parse.cc @@ -1,6 +1,7 @@ #include "doctest.h" #include "utils/exception.h" #include "utils/parse.h" +#include using namespace FlexFlow; @@ -15,18 +16,38 @@ TEST_CASE("Test ArgsParser basic functionality") { "-ll:gpus", "6"}; ArgsParser args; - args.add_argument("--batch-size", 32, "Size of each batch during training"); - args.add_argument( + auto batch_size_ref = + args.add_argument("--batch-size", 32, "batch size for training"); + auto learning_rate_ref = args.add_argument( "--learning-rate", 0.01f, "Learning rate for the optimizer"); - args.add_argument("--fusion", - false, - "Flag to determine if fusion optimization should be used"); - args.add_argument("-ll:gpus", 2, "Number of GPUs to be used for training"); + auto fusion_ref = args.add_argument( + "--fusion", + "yes", + "Flag to determine if fusion optimization should be used"); + auto ll_gpus_ref = args.add_argument( + "-ll:gpus", 2, "Number of GPUs to be used for training"); args.parse_args(9, const_cast(test_argv)); - CHECK(args.get("batch-size") == 100); - CHECK(args.get("learning-rate") == 0.5f); - CHECK(args.get("fusion") == true); - CHECK(args.get("-ll:gpus") == 6); - CHECK_THROWS(args.get("epochsss")); // throw exception + CHECK(args.get(batch_size_ref) == 100); + + CHECK(args.get(learning_rate_ref) == 0.5f); + CHECK(args.ge(fusion_ref) == true); + CHECK(args.get(ll_gpus_ref) == 6); + CHECK_THROWS(args.get("epochsss")); // throw exception +} + +TEST_CASE("Test size and invalid") { + char const *test_argv[] = { + "program_name", "batch-size", "100", "--fusion", "true"}; + ArgsParser args; + auto batch_size_ref = + args.add_argument("batch-size", 32, "batch size for training"); + auto fusion_ref = args.add_argument( + "--fusion", + "0", + "Flag to determine if fusion optimization should be used"); + args.parse_args(9, const_cast(test_argv)); + + CHECK(args.get(fusion_ref) == false); + CHECK(args.get(batch_size_ref) == 100); } From 32359ff9fdb75e081a78035c1399a2980d1d87f8 Mon Sep 17 00:00:00 2001 From: lambda7xx Date: Tue, 26 Sep 2023 15:38:17 +0000 Subject: [PATCH 07/17] add invalid command --- lib/utils/test/src/test_parse.cc | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/lib/utils/test/src/test_parse.cc b/lib/utils/test/src/test_parse.cc index 5e23f4ebec..5058a3894e 100644 --- a/lib/utils/test/src/test_parse.cc +++ b/lib/utils/test/src/test_parse.cc @@ -33,12 +33,14 @@ TEST_CASE("Test ArgsParser basic functionality") { CHECK(args.get(learning_rate_ref) == 0.5f); CHECK(args.ge(fusion_ref) == true); CHECK(args.get(ll_gpus_ref) == 6); - CHECK_THROWS(args.get("epochsss")); // throw exception + ArgRef invalid_ref; + CHECK_THROWS( + args.get(invalid_ref)); // throw exception because it's invalid ref } -TEST_CASE("Test size and invalid") { +TEST_CASE("Test batch and fusioon=0") { char const *test_argv[] = { - "program_name", "batch-size", "100", "--fusion", "true"}; + "program_name", "batch-size", "100", "--fusion", "yes"}; ArgsParser args; auto batch_size_ref = args.add_argument("batch-size", 32, "batch size for training"); @@ -48,6 +50,6 @@ TEST_CASE("Test size and invalid") { "Flag to determine if fusion optimization should be used"); args.parse_args(9, const_cast(test_argv)); - CHECK(args.get(fusion_ref) == false); + CHECK(args.get(fusion_ref) == true); CHECK(args.get(batch_size_ref) == 100); } From ff9787549193574db21f405264589feec371d995 Mon Sep 17 00:00:00 2001 From: lambda7xx Date: Tue, 26 Sep 2023 15:45:34 +0000 Subject: [PATCH 08/17] rename the ArgRef to CommandlineRef --- lib/utils/include/utils/parse.h | 4 ++-- lib/utils/src/parse.cc | 4 ++-- lib/utils/test/src/test_parse.cc | 2 +- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/lib/utils/include/utils/parse.h b/lib/utils/include/utils/parse.h index b17aa0dd60..23d2bd2d79 100644 --- a/lib/utils/include/utils/parse.h +++ b/lib/utils/include/utils/parse.h @@ -15,7 +15,7 @@ std::string parseKey(std::string arg); using AllowedArgTypes = variant; template -struct ArgRef { +struct CmdlineArgRef { std::string key; T value; }; @@ -34,7 +34,7 @@ class ArgsParser { T get_from_variant(AllowedArgTypes const &v) const; template - ArgRef add_argument(std::string const &key, + CmdlineArgRef add_argument(std::string const &key, T const &value, std::string const &description); diff --git a/lib/utils/src/parse.cc b/lib/utils/src/parse.cc index 67c22b8a72..8c445f061c 100644 --- a/lib/utils/src/parse.cc +++ b/lib/utils/src/parse.cc @@ -24,12 +24,12 @@ void ArgsParser::parse_args(int argc, char **argv) { } template -ArgRef ArgsParser::add_argument(std::string const &key, +CmdlineArgRef ArgsParser::add_argument(std::string const &key, T const &value, std::string const &description) { mDefaultValues[parseKey(key)] = value; mDescriptions[key] = description; - return ArgRef{key, value}; + return CmdlineArgRef{key, value}; } template <> diff --git a/lib/utils/test/src/test_parse.cc b/lib/utils/test/src/test_parse.cc index 5058a3894e..ba6c1ad21f 100644 --- a/lib/utils/test/src/test_parse.cc +++ b/lib/utils/test/src/test_parse.cc @@ -33,7 +33,7 @@ TEST_CASE("Test ArgsParser basic functionality") { CHECK(args.get(learning_rate_ref) == 0.5f); CHECK(args.ge(fusion_ref) == true); CHECK(args.get(ll_gpus_ref) == 6); - ArgRef invalid_ref; + CmdlineArgRef invalid_ref; CHECK_THROWS( args.get(invalid_ref)); // throw exception because it's invalid ref } From 134278afbae4ec047be63b9e731aac65be9a80f2 Mon Sep 17 00:00:00 2001 From: lambda7xx Date: Tue, 26 Sep 2023 15:46:12 +0000 Subject: [PATCH 09/17] fix some typo --- lib/utils/test/src/test_parse.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/utils/test/src/test_parse.cc b/lib/utils/test/src/test_parse.cc index ba6c1ad21f..4660dc3a4e 100644 --- a/lib/utils/test/src/test_parse.cc +++ b/lib/utils/test/src/test_parse.cc @@ -38,7 +38,7 @@ TEST_CASE("Test ArgsParser basic functionality") { args.get(invalid_ref)); // throw exception because it's invalid ref } -TEST_CASE("Test batch and fusioon=0") { +TEST_CASE("Test batch and fusion set 0") { char const *test_argv[] = { "program_name", "batch-size", "100", "--fusion", "yes"}; ArgsParser args; From 8a82efd2480d02d04b24a3747647be0ac92142aa Mon Sep 17 00:00:00 2001 From: lambda7xx Date: Tue, 26 Sep 2023 16:03:45 +0000 Subject: [PATCH 10/17] use the new api for config --- lib/runtime/src/config.cc | 140 +++++++++++++++++++------------------- 1 file changed, 70 insertions(+), 70 deletions(-) diff --git a/lib/runtime/src/config.cc b/lib/runtime/src/config.cc index 1827fc0145..0c8b118f0f 100644 --- a/lib/runtime/src/config.cc +++ b/lib/runtime/src/config.cc @@ -5,107 +5,107 @@ namespace FlexFlow { void FFConfig::parse_args(char **argv, int argc) { ArgsParser args; - args.add_argument("--epochs", 1, "Number of epochs."); - args.add_argument("--batch-size", 32, "Size of each batch during training"); - args.add_argument( + auto epochs_ref = args.add_argument("--epochs", 1, "Number of epochs."); + auto batch_size_ref = args.add_argument("--batch-size", 32, "Size of each batch during training"); + auto learning_rate_ref = args.add_argument( "--learning-rate", 0.01f, "Learning rate for the optimizer"); - args.add_argument( + auto weight_decay_ref = args.add_argument( "--weight-decay", 0.0001f, "Weight decay for the optimizer"); - args.add_argument("--dataset-path", "", "Path to the dataset"); - args.add_argument("--search-budget", 0, "Search budget"); - args.add_argument("--search-alpha", 0.0f, "Search alpha"); - args.add_argument( + auto dataset_pat_ref = args.add_argument("--dataset-path", "", "Path to the dataset"); + auto search_budget_ref = args.add_argument("--search-budget", 0, "Search budget"); + auto search_alpha_ref = args.add_argument("--search-alpha", 0.0f, "Search alpha"); + auto simulator_workspace_size_ref = args.add_argument( "--simulator-workspace-size", 0, "Simulator workspace size"); - args.add_argument("--only-data-parallel", false, "Only use data parallelism"); - args.add_argument( + auto only_data_parallel_ref = args.add_argument("--only-data-parallel", false, "Only use data parallelism"); + auto enable_parameter_parallel = args.add_argument( "--enable-parameter-parallel", false, "Enable parameter parallelism"); - args.add_argument("--nodes", 1, "Number of nodes"); - args.add_argument("--profiling", false, "Enable profiling"); - args.add_argument("--allow-tensor-op-math-conversion", + auto nodes_ref = args.add_argument("--nodes", 1, "Number of nodes"); + auto profiling_ref = args.add_argument("--profiling", false, "Enable profiling"); + auto allow_tensor_op_math_conversion_ref = args.add_argument("--allow-tensor-op-math-conversion", false, "Allow tensor op math conversion"); - args.add_argument("--fusion", false, "Enable fusion"); - args.add_argument("--overlap", false, "Enable overlap"); - args.add_argument( + auto fustion_ref = args.add_argument("--fusion", false, "Enable fusion"); + auto overlap_ref = args.add_argument("--overlap", false, "Enable overlap"); + auto taskgraph_ref = args.add_argument( "--taskgraph", "", "Export strategy computation graph file"); - args.add_argument( + auto = include_costs_dot_graph_ref = args.add_argument( "--include-costs-dot-graph", false, "Include costs dot graph"); - args.add_argument("--machine-model-version", 0, "Machine model version"); - args.add_argument("--machine-model-file", "", "Machine model file"); - args.add_argument("--simulator-segment-size", 0, "Simulator segment size"); - args.add_argument( + auto machine_model_version_ref = args.add_argument("--machine-model-version", 0, "Machine model version"); + auto machine_model_file_ref = args.add_argument("--machine-model-file", "", "Machine model file"); + auto simulator_segment_size_ref = args.add_argument("--simulator-segment-size", 0, "Simulator segment size"); + auto simulator_max_num_segments_ref = args.add_argument( "--simulator-max-num-segments", 0, "Simulator max number of segments"); - args.add_argument( + auto enable_inplace_optimizations_ref = args.add_argument( "--enable-inplace-optimizations", false, "Enable inplace optimizations"); - args.add_argument("--search-num-nodes", 0, "Search number of nodes"); - args.add_argument("--search-num-workers", 0, "Search number of workers"); - args.add_argument("--base-optimize-threshold", 0, "Base optimize threshold"); - args.add_argument( + auto search_num_nodes_ref = args.add_argument("--search-num-nodes", 0, "Search number of nodes"); + auto search_num_workers_ref= args.add_argument("--search-num-workers", 0, "Search number of workers"); + auto base_optimize_threshold_ref = args.add_argument("--base-optimize-threshold", 0, "Base optimize threshold"); + auto enable_control_replication_ref = args.add_argument( "--enable-control-replication", false, "Enable control replication"); - args.add_argument("--python-data-loader-type", 0, "Python data loader type"); - args.add_argument("--substitution-json", "", "Substitution json path"); + auto python_data_loader_type_ref = args.add_argument("--python-data-loader-type", 0, "Python data loader type"); + auto substitution_json_ref = args.add_argument("--substitution-json", "", "Substitution json path"); // legion arguments - args.add_argument("-level", 5, "level of logging output"); - args.add_argument("-logfile", "", "name of log file"); - args.add_argument("-ll:cpu", 1, "CPUs per node"); - args.add_argument("-ll:gpu", 0, "GPUs per node"); - args.add_argument("-ll:util", 1, "utility processors to create per process"); - args.add_argument( + auto level_ref = args.add_argument("-level", 5, "level of logging output"); + auto logfile_ref = args.add_argument("-logfile", "", "name of log file"); + auto ll_cpu_ref = args.add_argument("-ll:cpu", 1, "CPUs per node"); + auto ll_gpu_ref =args.add_argument("-ll:gpu", 0, "GPUs per node"); + auto ll_util_ref = args.add_argument("-ll:util", 1, "utility processors to create per process"); + auto ll_csize_ref = args.add_argument( "-ll:csize", 1024, "size of CPU DRAM memory per process(in MB)"); - args.add_argument("-ll:gsize", 0, "size of GPU DRAM memory per process"); - args.add_argument( + auto ll_gsize_ref = args.add_argument("-ll:gsize", 0, "size of GPU DRAM memory per process"); + auto ll_rsize_ref = args.add_argument( "-ll:rsize", 0, "size of GASNet registered RDMA memory available per process (in MB)"); - args.add_argument( + auto ll_fsize_ref = args.add_argument( "-ll:fsize", 1, "size of framebuffer memory for each GPU (in MB)"); - args.add_argument( + auto ll_zsize_ref = args.add_argument( "-ll:zsize", 0, "size of zero-copy memory for each GPU (in MB)"); - args.add_argument( + auto lg_window_ref =args.add_argument( "-lg:window", 8192, "maximum number of tasks that can be created in a parent task window"); - args.add_argument("-lg:sched", + auto lg_sched_ref = args.add_argument("-lg:sched", 1024, " minimum number of tasks to try to schedule for each " "invocation of the scheduler"); args.parse_args(argc, argv); - batch_size = args.get("batch-size"); - epochs = args.get("epochs"); - learning_rate = args.get("learning-rate"); - weight_decay = args.get("weight-decay"); - dataset_path = args.get("dataset-path"); - search_budget = args.get("search-budget"); - search_alpha = args.get("search-alpha"); - simulator_work_space_size = args.get("simulator-workspace-size"); - only_data_parallel = args.get("only-data-parallel"); - enable_parameter_parallel = args.get("enable-parameter-parallel"); - numNodes = args.get("nodes"); - profiling = args.get("profiling"); + batch_size = args.get(batch_size_ref) + epochs = args.get(epochs_ref); + learning_rate = args.get(learning_rate_ref); + weight_decay = args.get(weight_decay_ref); + dataset_path = args.get(dataset_pat_ref); + search_budget = args.get(search_budget_ref); + search_alpha = args.get(search_alpha_ref); + simulator_work_space_size = args.get(simulator_workspace_size_ref)); + only_data_parallel = args.get(only_data_parallel_ref); + enable_parameter_parallel = args.get(enable_parameter_parallel); + numNodes = args.get(nodes_ref); + profiling = args.get(profiling_ref); allow_tensor_op_math_conversion = - args.get("allow-tensor-op-math-conversion"); - perform_fusion = args.get("fusion"); - search_overlap_backward_update = args.get("overlap"); - export_strategy_computation_graph_file = args.get("--taskgraph"); - include_costs_dot_graph = args.get("include-costs-dot-graph"); - machine_model_version = args.get("machine-model-version"); - machine_model_file = args.get("machine-model-file"); - simulator_segment_size = args.get("simulator-segment-size"); - simulator_max_num_segments = args.get("simulator-max-num-segments"); - enable_inplace_optimizations = args.get("enable-inplace-optimizations"); - search_num_nodes = args.get("search-num-nodes"); - search_num_workers = args.get("search-num-workers"); - base_optimize_threshold = args.get("base-optimize-threshold"); - enable_control_replication = args.get("enable-control-replication"); - python_data_loader_type = args.get("python-data-loader-type"); - substitution_json_path = args.get("substitution-json"); + args.get(allow_tensor_op_math_conversion_ref); + perform_fusion = args.get(fustion_ref); + search_overlap_backward_update = args.get(overlap_ref); + export_strategy_computation_graph_file = args.get(task_graph_ref); + include_costs_dot_graph = args.get(include_costs_dot_graph_ref); + machine_model_version = args.get(machine_model_version_ref); + machine_model_file = args.get(machine_model_file_ref); + simulator_segment_size = args.get(simulator_segment_size_ref); + simulator_max_num_segments = args.get(simulator_max_num_segments_ref); + enable_inplace_optimizations = args.get(enable_inplace_optimizations_ref); + search_num_nodes = args.get(search_num_nodes_ref); + search_num_workers = args.get(search_num_workers_ref); + base_optimize_threshold = args.get(base_optimize_threshold_ref); + enable_control_replication = args.get(enable_control_replication_ref); + python_data_loader_type = args.get(python_data_loader_type_ref); + substitution_json_path = args.get(substitution_json_ref); // legion arguments - cpusPerNode = args.get("-ll:cpu"); - workersPerNode = args.get("-ll:gpu"); + cpusPerNode = args.get(ll_cpu_ref); + workersPerNode = args.get(ll_gpu_ref); } } // namespace FlexFlow From 54ff318a82ee390fcf1cea7411bf28673143907c Mon Sep 17 00:00:00 2001 From: lambda7xx Date: Thu, 28 Sep 2023 21:09:12 +0000 Subject: [PATCH 11/17] modify the parse --- lib/utils/include/utils/parse.h | 3 --- 1 file changed, 3 deletions(-) diff --git a/lib/utils/include/utils/parse.h b/lib/utils/include/utils/parse.h index 23d2bd2d79..24e18b4f05 100644 --- a/lib/utils/include/utils/parse.h +++ b/lib/utils/include/utils/parse.h @@ -30,9 +30,6 @@ class ArgsParser { ArgsParser() = default; void parse_args(int argc, char **argv); - template - T get_from_variant(AllowedArgTypes const &v) const; - template CmdlineArgRef add_argument(std::string const &key, T const &value, From ce78d00dfd41fffb71bc605952b9a95f20f53344 Mon Sep 17 00:00:00 2001 From: lambda7xx Date: Fri, 29 Sep 2023 16:36:09 +0000 Subject: [PATCH 12/17] refine the pass and make the class data-only --- lib/runtime/src/config.cc | 71 ++++++++++++++++++------------ lib/utils/include/utils/parse.h | 75 +++++++++++++++++--------------- lib/utils/src/parse.cc | 68 +++++++++++++++++------------ lib/utils/test/src/test_parse.cc | 71 ++++++++++++++++-------------- 4 files changed, 162 insertions(+), 123 deletions(-) diff --git a/lib/runtime/src/config.cc b/lib/runtime/src/config.cc index 0c8b118f0f..85e7d12bb7 100644 --- a/lib/runtime/src/config.cc +++ b/lib/runtime/src/config.cc @@ -6,54 +6,71 @@ namespace FlexFlow { void FFConfig::parse_args(char **argv, int argc) { ArgsParser args; auto epochs_ref = args.add_argument("--epochs", 1, "Number of epochs."); - auto batch_size_ref = args.add_argument("--batch-size", 32, "Size of each batch during training"); + auto batch_size_ref = args.add_argument( + "--batch-size", 32, "Size of each batch during training"); auto learning_rate_ref = args.add_argument( "--learning-rate", 0.01f, "Learning rate for the optimizer"); auto weight_decay_ref = args.add_argument( "--weight-decay", 0.0001f, "Weight decay for the optimizer"); - auto dataset_pat_ref = args.add_argument("--dataset-path", "", "Path to the dataset"); - auto search_budget_ref = args.add_argument("--search-budget", 0, "Search budget"); - auto search_alpha_ref = args.add_argument("--search-alpha", 0.0f, "Search alpha"); + auto dataset_pat_ref = + args.add_argument("--dataset-path", "", "Path to the dataset"); + auto search_budget_ref = + args.add_argument("--search-budget", 0, "Search budget"); + auto search_alpha_ref = + args.add_argument("--search-alpha", 0.0f, "Search alpha"); auto simulator_workspace_size_ref = args.add_argument( "--simulator-workspace-size", 0, "Simulator workspace size"); - auto only_data_parallel_ref = args.add_argument("--only-data-parallel", false, "Only use data parallelism"); + auto only_data_parallel_ref = args.add_argument( + "--only-data-parallel", false, "Only use data parallelism"); auto enable_parameter_parallel = args.add_argument( "--enable-parameter-parallel", false, "Enable parameter parallelism"); auto nodes_ref = args.add_argument("--nodes", 1, "Number of nodes"); - auto profiling_ref = args.add_argument("--profiling", false, "Enable profiling"); - auto allow_tensor_op_math_conversion_ref = args.add_argument("--allow-tensor-op-math-conversion", - false, - "Allow tensor op math conversion"); + auto profiling_ref = + args.add_argument("--profiling", false, "Enable profiling"); + auto allow_tensor_op_math_conversion_ref = + args.add_argument("--allow-tensor-op-math-conversion", + false, + "Allow tensor op math conversion"); auto fustion_ref = args.add_argument("--fusion", false, "Enable fusion"); auto overlap_ref = args.add_argument("--overlap", false, "Enable overlap"); auto taskgraph_ref = args.add_argument( "--taskgraph", "", "Export strategy computation graph file"); auto = include_costs_dot_graph_ref = args.add_argument( "--include-costs-dot-graph", false, "Include costs dot graph"); - auto machine_model_version_ref = args.add_argument("--machine-model-version", 0, "Machine model version"); - auto machine_model_file_ref = args.add_argument("--machine-model-file", "", "Machine model file"); - auto simulator_segment_size_ref = args.add_argument("--simulator-segment-size", 0, "Simulator segment size"); + auto machine_model_version_ref = + args.add_argument("--machine-model-version", 0, "Machine model version"); + auto machine_model_file_ref = + args.add_argument("--machine-model-file", "", "Machine model file"); + auto simulator_segment_size_ref = args.add_argument( + "--simulator-segment-size", 0, "Simulator segment size"); auto simulator_max_num_segments_ref = args.add_argument( "--simulator-max-num-segments", 0, "Simulator max number of segments"); auto enable_inplace_optimizations_ref = args.add_argument( "--enable-inplace-optimizations", false, "Enable inplace optimizations"); - auto search_num_nodes_ref = args.add_argument("--search-num-nodes", 0, "Search number of nodes"); - auto search_num_workers_ref= args.add_argument("--search-num-workers", 0, "Search number of workers"); - auto base_optimize_threshold_ref = args.add_argument("--base-optimize-threshold", 0, "Base optimize threshold"); + auto search_num_nodes_ref = + args.add_argument("--search-num-nodes", 0, "Search number of nodes"); + auto search_num_workers_ref = + args.add_argument("--search-num-workers", 0, "Search number of workers"); + auto base_optimize_threshold_ref = args.add_argument( + "--base-optimize-threshold", 0, "Base optimize threshold"); auto enable_control_replication_ref = args.add_argument( "--enable-control-replication", false, "Enable control replication"); - auto python_data_loader_type_ref = args.add_argument("--python-data-loader-type", 0, "Python data loader type"); - auto substitution_json_ref = args.add_argument("--substitution-json", "", "Substitution json path"); + auto python_data_loader_type_ref = args.add_argument( + "--python-data-loader-type", 0, "Python data loader type"); + auto substitution_json_ref = + args.add_argument("--substitution-json", "", "Substitution json path"); // legion arguments auto level_ref = args.add_argument("-level", 5, "level of logging output"); auto logfile_ref = args.add_argument("-logfile", "", "name of log file"); auto ll_cpu_ref = args.add_argument("-ll:cpu", 1, "CPUs per node"); - auto ll_gpu_ref =args.add_argument("-ll:gpu", 0, "GPUs per node"); - auto ll_util_ref = args.add_argument("-ll:util", 1, "utility processors to create per process"); + auto ll_gpu_ref = args.add_argument("-ll:gpu", 0, "GPUs per node"); + auto ll_util_ref = args.add_argument( + "-ll:util", 1, "utility processors to create per process"); auto ll_csize_ref = args.add_argument( "-ll:csize", 1024, "size of CPU DRAM memory per process(in MB)"); - auto ll_gsize_ref = args.add_argument("-ll:gsize", 0, "size of GPU DRAM memory per process"); + auto ll_gsize_ref = + args.add_argument("-ll:gsize", 0, "size of GPU DRAM memory per process"); auto ll_rsize_ref = args.add_argument( "-ll:rsize", 0, @@ -62,19 +79,19 @@ void FFConfig::parse_args(char **argv, int argc) { "-ll:fsize", 1, "size of framebuffer memory for each GPU (in MB)"); auto ll_zsize_ref = args.add_argument( "-ll:zsize", 0, "size of zero-copy memory for each GPU (in MB)"); - auto lg_window_ref =args.add_argument( + auto lg_window_ref = args.add_argument( "-lg:window", 8192, "maximum number of tasks that can be created in a parent task window"); - auto lg_sched_ref = args.add_argument("-lg:sched", - 1024, - " minimum number of tasks to try to schedule for each " - "invocation of the scheduler"); + auto lg_sched_ref = + args.add_argument("-lg:sched", + 1024, + " minimum number of tasks to try to schedule for each " + "invocation of the scheduler"); args.parse_args(argc, argv); - batch_size = args.get(batch_size_ref) - epochs = args.get(epochs_ref); + batch_size = args.get(batch_size_ref) epochs = args.get(epochs_ref); learning_rate = args.get(learning_rate_ref); weight_decay = args.get(weight_decay_ref); dataset_path = args.get(dataset_pat_ref); diff --git a/lib/utils/include/utils/parse.h b/lib/utils/include/utils/parse.h index 24e18b4f05..f83c66cc40 100644 --- a/lib/utils/include/utils/parse.h +++ b/lib/utils/include/utils/parse.h @@ -3,6 +3,7 @@ #include "utils/containers.h" #include "utils/exception.h" +#include "utils/optional.h" #include "utils/variant.h" #include #include @@ -10,9 +11,8 @@ namespace FlexFlow { -std::string parseKey(std::string arg); - -using AllowedArgTypes = variant; +using AllowedArgTypes = + variant; // we can add more types here template struct CmdlineArgRef { @@ -20,41 +20,44 @@ struct CmdlineArgRef { T value; }; -class ArgsParser { -private: - std::unordered_map mArgs; - std::unordered_map mDefaultValues; - std::unordered_map mDescriptions; - -public: - ArgsParser() = default; - void parse_args(int argc, char **argv); - - template - CmdlineArgRef add_argument(std::string const &key, - T const &value, - std::string const &description); - - template - T get(std::string const &key) const { - if (contains_key(mArgs, key)) { - return convert(mArgs.at(key)); - } else { - if (contains_key(mDefaultValues, key)) { - return mpark::get(mDefaultValues.at(key)); - } - } - throw mk_runtime_error("Key not found: " + key); - } - - void showDescriptions() const; - - template - T convert(std::string const &s) const; - - friend std::ostream &operator<<(std::ostream &out, ArgsParser const &args); +struct Argument { + optional value; // Change value type to optional + std::string description; + bool default_value; }; +struct ArgsParser { + std::unordered_map mArguments; +}; + +// currently we only support "--xx" or "-x" +std::string parseKey(std::string arg); + +void parse_args(ArgsParser &mArgs, int argc, char **argv); + +template +CmdlineArgRef add_argument(ArgsParser &parser, + std::string key, + optional default_value, + std::string const &description); + +template +T get(ArgsParser const &parser, CmdlineArgRef const &ref); + +void showDescriptions(ArgsParser const &parser); + +template +T convert(std::string const &s); + +template <> +int convert(std::string const &s); + +template <> +float convert(std::string const &s); + +template <> +bool convert(std::string const &s); + } // namespace FlexFlow #endif diff --git a/lib/utils/src/parse.cc b/lib/utils/src/parse.cc index 8c445f061c..a2ccb4edb1 100644 --- a/lib/utils/src/parse.cc +++ b/lib/utils/src/parse.cc @@ -1,62 +1,76 @@ #include "utils/parse.h" #include "utils/containers.h" +#include "utils/exception.h" #include "utils/variant.h" - namespace FlexFlow { +// currently we only support "--xx" or "-x" std::string parseKey(std::string arg) { if (arg.substr(0, 2) == "--") { return arg.substr(2); - } else { + } else if (arg.substr(0, 1) == "-") { return arg; } + throw mk_runtime_error("parse invalid args: " + arg); } -void ArgsParser::parse_args(int argc, char **argv) { +void parse_args(ArgsParser &mArgs, int argc, char **argv) { for (int i = 1; i < argc; i += 2) { std::string key = parseKey(argv[i]); if (key == "help" || key == "h") { - showDescriptions(); - exit(0); + showDescriptions(mArgs); + exit(1); } - mArgs[key] = argv[i + 1]; + mArgs.mArguments[key].value = argv[i + 1]; } } template -CmdlineArgRef ArgsParser::add_argument(std::string const &key, - T const &value, - std::string const &description) { - mDefaultValues[parseKey(key)] = value; - mDescriptions[key] = description; - return CmdlineArgRef{key, value}; +CmdlineArgRef add_argument(ArgsParser &parser, + std::string key, + std::optional default_value, + std::string const &description) { + key = parseKey(key); + parser.mArguments[key].description = description; + if (default_value + .has_value()) { // Use has_value() to check if there's a value + parser.mArguments[key].value = + std::to_string(default_value.value()); // Convert the value to string + parser.mArguments[key].default_value = true; + return CmdlineArgRef{key, default_value.value()}; + } + return CmdlineArgRef{key, T{}}; +} + +template +T get(ArgsParser const &parser, CmdlineArgRef const &ref) { + std::string key = ref.key; + if (parser.mArguments.count(key)) { + if (parser.mArguments.at(key).default_value || + parser.mArguments.at(key).value.has_value()) { + return convert(parser.mArguments.at(key).value.value()); + } + } + throw mk_runtime_error("invalid args: " + ref.key); +} + +void showDescriptions(ArgsParser const &parser) { + NOT_IMPLEMENTED(); // TODO(lambda) I will use fmt to implement } template <> -int ArgsParser::convert(std::string const &s) const { +int convert(std::string const &s) { return std::stoi(s); } template <> -float ArgsParser::convert(std::string const &s) const { +float convert(std::string const &s) { return std::stof(s); } template <> -bool ArgsParser::convert(std::string const &s) const { +bool convert(std::string const &s) { return s == "true" || s == "1" || s == "yes"; } -void ArgsParser::showDescriptions() const { - for (auto const &kv : mDescriptions) { - std::cerr << *this << std::endl; - } -} - -// TODO(lambda):in the future, we will use fmt to format the output -std::ostream &operator<<(std::ostream &out, ArgsParser const &args) { - args.showDescriptions(); - return out; -} - } // namespace FlexFlow diff --git a/lib/utils/test/src/test_parse.cc b/lib/utils/test/src/test_parse.cc index 4660dc3a4e..abffa522fd 100644 --- a/lib/utils/test/src/test_parse.cc +++ b/lib/utils/test/src/test_parse.cc @@ -1,7 +1,5 @@ #include "doctest.h" -#include "utils/exception.h" #include "utils/parse.h" -#include using namespace FlexFlow; @@ -17,39 +15,46 @@ TEST_CASE("Test ArgsParser basic functionality") { "6"}; ArgsParser args; auto batch_size_ref = - args.add_argument("--batch-size", 32, "batch size for training"); - auto learning_rate_ref = args.add_argument( - "--learning-rate", 0.01f, "Learning rate for the optimizer"); - auto fusion_ref = args.add_argument( - "--fusion", - "yes", - "Flag to determine if fusion optimization should be used"); - auto ll_gpus_ref = args.add_argument( - "-ll:gpus", 2, "Number of GPUs to be used for training"); - args.parse_args(9, const_cast(test_argv)); - - CHECK(args.get(batch_size_ref) == 100); - - CHECK(args.get(learning_rate_ref) == 0.5f); - CHECK(args.ge(fusion_ref) == true); - CHECK(args.get(ll_gpus_ref) == 6); - CmdlineArgRef invalid_ref; - CHECK_THROWS( - args.get(invalid_ref)); // throw exception because it's invalid ref + add_argument(args, "--batch-size", 32, "batch size for training"); + auto learning_rate_ref = add_argument( + args, "--learning-rate", 0.01f, "Learning rate for the optimizer"); + auto fusion_ref = + add_argument(args, + "--fusion", + "yes", + "Flag to determine if fusion optimization should be used"); + auto ll_gpus_ref = add_argument( + args, + "-ll:gpus", + std::nullopt, + "Number of GPUs to be used for training"); // support non-default value + parse_args(args, 9, const_cast(test_argv)); + + CHECK(get(args, batch_size_ref) == 100); + + CHECK(get(args, learning_rate_ref) == 0.5f); + CHECK(get(args, fusion_ref) == true); + CHECK(get(args, ll_gpus_ref) == 6); } -TEST_CASE("Test batch and fusion set 0") { - char const *test_argv[] = { - "program_name", "batch-size", "100", "--fusion", "yes"}; +TEST_CASE("Test invald command") { + char const *test_argv[] = {"program_name", "batch-size", "100"}; ArgsParser args; auto batch_size_ref = - args.add_argument("batch-size", 32, "batch size for training"); - auto fusion_ref = args.add_argument( - "--fusion", - "0", - "Flag to determine if fusion optimization should be used"); - args.parse_args(9, const_cast(test_argv)); - - CHECK(args.get(fusion_ref) == true); - CHECK(args.get(batch_size_ref) == 100); + add_argument(args, "batch-size", 32, "batch size for training"); + parse_args(args, 3, const_cast(test_argv)); + + CHECK_THROWS( + get(args, batch_size_ref)); // throw exception because we pass batch_size + // via command, it should pass --batch_size +} + +TEST_CASE("Test invalid ref") { + CmdlineArgRef invalid_ref{"invalid", {}}; + char const *test_argv[] = {"program_name"}; + + ArgsParser args; + parse_args(args, 1, const_cast(test_argv)); + CHECK_THROWS( + get(args, invalid_ref)); // throw exception because it's invalid ref } From f059ef2648dd91d50ee08b02d6116fd4a4ee6425 Mon Sep 17 00:00:00 2001 From: lambda7xx Date: Sat, 30 Sep 2023 12:56:58 +0000 Subject: [PATCH 13/17] version0.2 --- lib/runtime/include/runtime/config.h | 5 +- lib/runtime/src/config.cc | 88 ++++++++++++++-------------- lib/utils/test/src/test_parse.cc | 8 +-- 3 files changed, 51 insertions(+), 50 deletions(-) diff --git a/lib/runtime/include/runtime/config.h b/lib/runtime/include/runtime/config.h index 54fe9443f1..f3cfabd241 100644 --- a/lib/runtime/include/runtime/config.h +++ b/lib/runtime/include/runtime/config.h @@ -65,8 +65,7 @@ struct FFConfig : public use_visitable_cmp { FFConfig() = default; static Legion::MappingTagID get_hash_id(std::string const &pcname); - void parse_args(char **argv, int argc); - + public: int epochs = 1; int batchSize = 64; @@ -105,6 +104,8 @@ struct FFConfig : public use_visitable_cmp { int python_data_loader_type = 2; }; +FFConfig parse_args(char **argv, int argc); + class FFIterationConfig { public: FFIterationConfig(); diff --git a/lib/runtime/src/config.cc b/lib/runtime/src/config.cc index 85e7d12bb7..06be9c4e09 100644 --- a/lib/runtime/src/config.cc +++ b/lib/runtime/src/config.cc @@ -3,88 +3,88 @@ namespace FlexFlow { -void FFConfig::parse_args(char **argv, int argc) { +FFConfig parse_args(char **argv, int argc) { ArgsParser args; - auto epochs_ref = args.add_argument("--epochs", 1, "Number of epochs."); - auto batch_size_ref = args.add_argument( - "--batch-size", 32, "Size of each batch during training"); - auto learning_rate_ref = args.add_argument( - "--learning-rate", 0.01f, "Learning rate for the optimizer"); - auto weight_decay_ref = args.add_argument( - "--weight-decay", 0.0001f, "Weight decay for the optimizer"); + auto epochs_ref = add_argument(args,"--epochs", optional(1), "Number of epochs."); + auto batch_size_ref = add_argument(args, + "--batch-size", optional(32), "Size of each batch during training"); + auto learning_rate_ref = add_argument(args, + "--learning-rate", optional(0.01f), "Learning rate for the optimizer"); + auto weight_decay_ref = add_argument(args, + "--weight-decay", optional(0.0001f), "Weight decay for the optimizer"); auto dataset_pat_ref = - args.add_argument("--dataset-path", "", "Path to the dataset"); + add_argument(args,"--dataset-path", "", "Path to the dataset"); auto search_budget_ref = - args.add_argument("--search-budget", 0, "Search budget"); + add_argument(args,"--search-budget", 0, "Search budget"); auto search_alpha_ref = - args.add_argument("--search-alpha", 0.0f, "Search alpha"); - auto simulator_workspace_size_ref = args.add_argument( + add_argument(args,"--search-alpha", 0.0f, "Search alpha"); + auto simulator_workspace_size_ref = add_argument(args, "--simulator-workspace-size", 0, "Simulator workspace size"); - auto only_data_parallel_ref = args.add_argument( + auto only_data_parallel_ref = add_argument(args, "--only-data-parallel", false, "Only use data parallelism"); - auto enable_parameter_parallel = args.add_argument( + auto enable_parameter_parallel = add_argument(args, "--enable-parameter-parallel", false, "Enable parameter parallelism"); - auto nodes_ref = args.add_argument("--nodes", 1, "Number of nodes"); + auto nodes_ref = add_argument(args,"--nodes", 1, "Number of nodes"); auto profiling_ref = - args.add_argument("--profiling", false, "Enable profiling"); + add_argument(args,"--profiling", false, "Enable profiling"); auto allow_tensor_op_math_conversion_ref = - args.add_argument("--allow-tensor-op-math-conversion", + add_argument(args,"--allow-tensor-op-math-conversion", false, "Allow tensor op math conversion"); - auto fustion_ref = args.add_argument("--fusion", false, "Enable fusion"); - auto overlap_ref = args.add_argument("--overlap", false, "Enable overlap"); - auto taskgraph_ref = args.add_argument( + auto fustion_ref = add_argument(args,"--fusion", false, "Enable fusion"); + auto overlap_ref = add_argument(args,"--overlap", false, "Enable overlap"); + auto taskgraph_ref = add_argument(args, "--taskgraph", "", "Export strategy computation graph file"); - auto = include_costs_dot_graph_ref = args.add_argument( + auto = include_costs_dot_graph_ref = add_argument(args, "--include-costs-dot-graph", false, "Include costs dot graph"); auto machine_model_version_ref = - args.add_argument("--machine-model-version", 0, "Machine model version"); + add_argument(args,"--machine-model-version", 0, "Machine model version"); auto machine_model_file_ref = - args.add_argument("--machine-model-file", "", "Machine model file"); - auto simulator_segment_size_ref = args.add_argument( + add_argument(args,"--machine-model-file", "", "Machine model file"); + auto simulator_segment_size_ref = add_argument(args, "--simulator-segment-size", 0, "Simulator segment size"); - auto simulator_max_num_segments_ref = args.add_argument( + auto simulator_max_num_segments_ref = add_argument(args, "--simulator-max-num-segments", 0, "Simulator max number of segments"); - auto enable_inplace_optimizations_ref = args.add_argument( + auto enable_inplace_optimizations_ref = add_argument(args, "--enable-inplace-optimizations", false, "Enable inplace optimizations"); auto search_num_nodes_ref = - args.add_argument("--search-num-nodes", 0, "Search number of nodes"); + add_argument(args,"--search-num-nodes", 0, "Search number of nodes"); auto search_num_workers_ref = - args.add_argument("--search-num-workers", 0, "Search number of workers"); - auto base_optimize_threshold_ref = args.add_argument( + add_argument(args,"--search-num-workers", 0, "Search number of workers"); + auto base_optimize_threshold_ref = add_argument(args, "--base-optimize-threshold", 0, "Base optimize threshold"); - auto enable_control_replication_ref = args.add_argument( + auto enable_control_replication_ref = add_argument(args, "--enable-control-replication", false, "Enable control replication"); - auto python_data_loader_type_ref = args.add_argument( + auto python_data_loader_type_ref = add_argument(args, "--python-data-loader-type", 0, "Python data loader type"); auto substitution_json_ref = - args.add_argument("--substitution-json", "", "Substitution json path"); + add_argument(args,"--substitution-json", "", "Substitution json path"); // legion arguments - auto level_ref = args.add_argument("-level", 5, "level of logging output"); - auto logfile_ref = args.add_argument("-logfile", "", "name of log file"); - auto ll_cpu_ref = args.add_argument("-ll:cpu", 1, "CPUs per node"); - auto ll_gpu_ref = args.add_argument("-ll:gpu", 0, "GPUs per node"); - auto ll_util_ref = args.add_argument( + auto level_ref = add_argument(args,"-level", 5, "level of logging output"); + auto logfile_ref = add_argument(args,"-logfile", "", "name of log file"); + auto ll_cpu_ref = add_argument(args,"-ll:cpu", 1, "CPUs per node"); + auto ll_gpu_ref = add_argument(args,"-ll:gpu", 0, "GPUs per node"); + auto ll_util_ref = add_argument(args, "-ll:util", 1, "utility processors to create per process"); - auto ll_csize_ref = args.add_argument( + auto ll_csize_ref = add_argument(args, "-ll:csize", 1024, "size of CPU DRAM memory per process(in MB)"); auto ll_gsize_ref = - args.add_argument("-ll:gsize", 0, "size of GPU DRAM memory per process"); - auto ll_rsize_ref = args.add_argument( + add_argument(args,"-ll:gsize", 0, "size of GPU DRAM memory per process"); + auto ll_rsize_ref = add_argument(args, "-ll:rsize", 0, "size of GASNet registered RDMA memory available per process (in MB)"); - auto ll_fsize_ref = args.add_argument( + auto ll_fsize_ref = add_argument(args, "-ll:fsize", 1, "size of framebuffer memory for each GPU (in MB)"); - auto ll_zsize_ref = args.add_argument( + auto ll_zsize_ref = add_argument(args, "-ll:zsize", 0, "size of zero-copy memory for each GPU (in MB)"); - auto lg_window_ref = args.add_argument( + auto lg_window_ref = add_argument(args, "-lg:window", 8192, "maximum number of tasks that can be created in a parent task window"); auto lg_sched_ref = - args.add_argument("-lg:sched", + add_argument(args,"-lg:sched", 1024, " minimum number of tasks to try to schedule for each " "invocation of the scheduler"); diff --git a/lib/utils/test/src/test_parse.cc b/lib/utils/test/src/test_parse.cc index abffa522fd..f1a7e2a396 100644 --- a/lib/utils/test/src/test_parse.cc +++ b/lib/utils/test/src/test_parse.cc @@ -15,13 +15,13 @@ TEST_CASE("Test ArgsParser basic functionality") { "6"}; ArgsParser args; auto batch_size_ref = - add_argument(args, "--batch-size", 32, "batch size for training"); + add_argument(args, "--batch-size", optional(32), "batch size for training"); auto learning_rate_ref = add_argument( - args, "--learning-rate", 0.01f, "Learning rate for the optimizer"); + args, "--learning-rate", optional(0.01f), "Learning rate for the optimizer"); auto fusion_ref = add_argument(args, "--fusion", - "yes", + optional("yes"), "Flag to determine if fusion optimization should be used"); auto ll_gpus_ref = add_argument( args, @@ -41,7 +41,7 @@ TEST_CASE("Test invald command") { char const *test_argv[] = {"program_name", "batch-size", "100"}; ArgsParser args; auto batch_size_ref = - add_argument(args, "batch-size", 32, "batch size for training"); + add_argument(args, "batch-size", optional(32), "batch size for training"); parse_args(args, 3, const_cast(test_argv)); CHECK_THROWS( From de8f1a2fc78a1d55c72be94634296d184aed957d Mon Sep 17 00:00:00 2001 From: lambda7xx Date: Sun, 1 Oct 2023 00:56:40 +0000 Subject: [PATCH 14/17] add some test --- lib/runtime/include/runtime/config.h | 2 +- lib/runtime/src/config.cc | 267 +++++++++++++++------------ lib/utils/include/utils/parse.h | 35 +++- lib/utils/src/parse.cc | 122 ++++++++++-- lib/utils/test/src/test_parse.cc | 156 +++++++++++++--- 5 files changed, 415 insertions(+), 167 deletions(-) diff --git a/lib/runtime/include/runtime/config.h b/lib/runtime/include/runtime/config.h index f3cfabd241..a2365a8eeb 100644 --- a/lib/runtime/include/runtime/config.h +++ b/lib/runtime/include/runtime/config.h @@ -65,7 +65,7 @@ struct FFConfig : public use_visitable_cmp { FFConfig() = default; static Legion::MappingTagID get_hash_id(std::string const &pcname); - + public: int epochs = 1; int batchSize = 64; diff --git a/lib/runtime/src/config.cc b/lib/runtime/src/config.cc index 06be9c4e09..2b2e62550c 100644 --- a/lib/runtime/src/config.cc +++ b/lib/runtime/src/config.cc @@ -1,128 +1,165 @@ #include "runtime/config.h" +#include "utils/exception.h" #include "utils/parse.h" namespace FlexFlow { FFConfig parse_args(char **argv, int argc) { - ArgsParser args; - auto epochs_ref = add_argument(args,"--epochs", optional(1), "Number of epochs."); - auto batch_size_ref = add_argument(args, - "--batch-size", optional(32), "Size of each batch during training"); - auto learning_rate_ref = add_argument(args, - "--learning-rate", optional(0.01f), "Learning rate for the optimizer"); - auto weight_decay_ref = add_argument(args, - "--weight-decay", optional(0.0001f), "Weight decay for the optimizer"); - auto dataset_pat_ref = - add_argument(args,"--dataset-path", "", "Path to the dataset"); - auto search_budget_ref = - add_argument(args,"--search-budget", 0, "Search budget"); - auto search_alpha_ref = - add_argument(args,"--search-alpha", 0.0f, "Search alpha"); - auto simulator_workspace_size_ref = add_argument(args, - "--simulator-workspace-size", 0, "Simulator workspace size"); - auto only_data_parallel_ref = add_argument(args, - "--only-data-parallel", false, "Only use data parallelism"); - auto enable_parameter_parallel = add_argument(args, - "--enable-parameter-parallel", false, "Enable parameter parallelism"); - auto nodes_ref = add_argument(args,"--nodes", 1, "Number of nodes"); - auto profiling_ref = - add_argument(args,"--profiling", false, "Enable profiling"); - auto allow_tensor_op_math_conversion_ref = - add_argument(args,"--allow-tensor-op-math-conversion", - false, - "Allow tensor op math conversion"); - auto fustion_ref = add_argument(args,"--fusion", false, "Enable fusion"); - auto overlap_ref = add_argument(args,"--overlap", false, "Enable overlap"); - auto taskgraph_ref = add_argument(args, - "--taskgraph", "", "Export strategy computation graph file"); - auto = include_costs_dot_graph_ref = add_argument(args, - "--include-costs-dot-graph", false, "Include costs dot graph"); - auto machine_model_version_ref = - add_argument(args,"--machine-model-version", 0, "Machine model version"); - auto machine_model_file_ref = - add_argument(args,"--machine-model-file", "", "Machine model file"); - auto simulator_segment_size_ref = add_argument(args, - "--simulator-segment-size", 0, "Simulator segment size"); - auto simulator_max_num_segments_ref = add_argument(args, - "--simulator-max-num-segments", 0, "Simulator max number of segments"); - auto enable_inplace_optimizations_ref = add_argument(args, - "--enable-inplace-optimizations", false, "Enable inplace optimizations"); - auto search_num_nodes_ref = - add_argument(args,"--search-num-nodes", 0, "Search number of nodes"); - auto search_num_workers_ref = - add_argument(args,"--search-num-workers", 0, "Search number of workers"); - auto base_optimize_threshold_ref = add_argument(args, - "--base-optimize-threshold", 0, "Base optimize threshold"); - auto enable_control_replication_ref = add_argument(args, - "--enable-control-replication", false, "Enable control replication"); - auto python_data_loader_type_ref = add_argument(args, - "--python-data-loader-type", 0, "Python data loader type"); - auto substitution_json_ref = - add_argument(args,"--substitution-json", "", "Substitution json path"); + NOT_IMPLEMENTED(); // TODO: implement this after we have the new parser + // ArgsParser args; + // auto epochs_ref = + // add_argument(args, "--epochs", optional(1), "Number of + // epochs."); + // auto batch_size_ref = add_argument(args, + // "--batch-size", + // optional(32), + // "Size of each batch during training"); + // auto learning_rate_ref = add_argument(args, + // "--learning-rate", + // optional(0.01f), + // "Learning rate for the optimizer"); + // auto weight_decay_ref = add_argument(args, + // "--weight-decay", + // optional(0.0001f), + // "Weight decay for the optimizer"); + // auto dataset_pat_ref = + // add_argument(args, "--dataset-path", "", "Path to the dataset"); + // auto search_budget_ref = + // add_argument(args, "--search-budget", 0, "Search budget"); + // auto search_alpha_ref = + // add_argument(args, "--search-alpha", 0.0f, "Search alpha"); + // auto simulator_workspace_size_ref = add_argument( + // args, "--simulator-workspace-size", 0, "Simulator workspace size"); + // auto only_data_parallel_ref = add_argument( + // args, "--only-data-parallel", false, "Only use data parallelism"); + // auto enable_parameter_parallel = add_argument(args, + // "--enable-parameter-parallel", + // false, + // "Enable parameter + // parallelism"); + // auto nodes_ref = add_argument(args, "--nodes", 1, "Number of nodes"); + // auto profiling_ref = + // add_argument(args, "--profiling", false, "Enable profiling"); + // auto allow_tensor_op_math_conversion_ref = + // add_argument(args, + // "--allow-tensor-op-math-conversion", + // false, + // "Allow tensor op math conversion"); + // auto fustion_ref = add_argument(args, "--fusion", false, "Enable + // fusion"); auto overlap_ref = add_argument(args, "--overlap", false, + // "Enable overlap"); auto taskgraph_ref = add_argument( + // args, "--taskgraph", "", "Export strategy computation graph file"); + // auto = include_costs_dot_graph_ref = add_argument( + // args, "--include-costs-dot-graph", false, "Include costs dot graph"); + // auto machine_model_version_ref = + // add_argument(args, "--machine-model-version", 0, "Machine model + // version"); + // auto machine_model_file_ref = + // add_argument(args, "--machine-model-file", "", "Machine model file"); + // auto simulator_segment_size_ref = add_argument( + // args, "--simulator-segment-size", 0, "Simulator segment size"); + // auto simulator_max_num_segments_ref = + // add_argument(args, + // "--simulator-max-num-segments", + // 0, + // "Simulator max number of segments"); + // auto enable_inplace_optimizations_ref = + // add_argument(args, + // "--enable-inplace-optimizations", + // false, + // "Enable inplace optimizations"); + // auto search_num_nodes_ref = + // add_argument(args, "--search-num-nodes", 0, "Search number of + // nodes"); + // auto search_num_workers_ref = + // add_argument(args, "--search-num-workers", 0, "Search number of + // workers"); + // auto base_optimize_threshold_ref = add_argument( + // args, "--base-optimize-threshold", 0, "Base optimize threshold"); + // auto enable_control_replication_ref = + // add_argument(args, + // "--enable-control-replication", + // false, + // "Enable control replication"); + // auto python_data_loader_type_ref = add_argument( + // args, "--python-data-loader-type", 0, "Python data loader type"); + // auto substitution_json_ref = + // add_argument(args, "--substitution-json", "", "Substitution json + // path"); - // legion arguments - auto level_ref = add_argument(args,"-level", 5, "level of logging output"); - auto logfile_ref = add_argument(args,"-logfile", "", "name of log file"); - auto ll_cpu_ref = add_argument(args,"-ll:cpu", 1, "CPUs per node"); - auto ll_gpu_ref = add_argument(args,"-ll:gpu", 0, "GPUs per node"); - auto ll_util_ref = add_argument(args, - "-ll:util", 1, "utility processors to create per process"); - auto ll_csize_ref = add_argument(args, - "-ll:csize", 1024, "size of CPU DRAM memory per process(in MB)"); - auto ll_gsize_ref = - add_argument(args,"-ll:gsize", 0, "size of GPU DRAM memory per process"); - auto ll_rsize_ref = add_argument(args, - "-ll:rsize", - 0, - "size of GASNet registered RDMA memory available per process (in MB)"); - auto ll_fsize_ref = add_argument(args, - "-ll:fsize", 1, "size of framebuffer memory for each GPU (in MB)"); - auto ll_zsize_ref = add_argument(args, - "-ll:zsize", 0, "size of zero-copy memory for each GPU (in MB)"); - auto lg_window_ref = add_argument(args, - "-lg:window", - 8192, - "maximum number of tasks that can be created in a parent task window"); - auto lg_sched_ref = - add_argument(args,"-lg:sched", - 1024, - " minimum number of tasks to try to schedule for each " - "invocation of the scheduler"); + // // legion arguments + // auto level_ref = add_argument(args, "-level", 5, "level of logging + // output"); auto logfile_ref = add_argument(args, "-logfile", "", "name of + // log file"); auto ll_cpu_ref = add_argument(args, "-ll:cpu", 1, "CPUs per + // node"); auto ll_gpu_ref = add_argument(args, "-ll:gpu", 0, "GPUs per + // node"); auto ll_util_ref = add_argument( + // args, "-ll:util", 1, "utility processors to create per process"); + // auto ll_csize_ref = add_argument( + // args, "-ll:csize", 1024, "size of CPU DRAM memory per process(in + // MB)"); + // auto ll_gsize_ref = + // add_argument(args, "-ll:gsize", 0, "size of GPU DRAM memory per + // process"); + // auto ll_rsize_ref = add_argument( + // args, + // "-ll:rsize", + // 0, + // "size of GASNet registered RDMA memory available per process (in + // MB)"); + // auto ll_fsize_ref = add_argument( + // args, "-ll:fsize", 1, "size of framebuffer memory for each GPU (in + // MB)"); + // auto ll_zsize_ref = add_argument( + // args, "-ll:zsize", 0, "size of zero-copy memory for each GPU (in + // MB)"); + // auto lg_window_ref = add_argument( + // args, + // "-lg:window", + // 8192, + // "maximum number of tasks that can be created in a parent task + // window"); + // auto lg_sched_ref = + // add_argument(args, + // "-lg:sched", + // 1024, + // " minimum number of tasks to try to schedule for each " + // "invocation of the scheduler"); - args.parse_args(argc, argv); + // args.parse_args(argc, argv); - batch_size = args.get(batch_size_ref) epochs = args.get(epochs_ref); - learning_rate = args.get(learning_rate_ref); - weight_decay = args.get(weight_decay_ref); - dataset_path = args.get(dataset_pat_ref); - search_budget = args.get(search_budget_ref); - search_alpha = args.get(search_alpha_ref); - simulator_work_space_size = args.get(simulator_workspace_size_ref)); - only_data_parallel = args.get(only_data_parallel_ref); - enable_parameter_parallel = args.get(enable_parameter_parallel); - numNodes = args.get(nodes_ref); - profiling = args.get(profiling_ref); - allow_tensor_op_math_conversion = - args.get(allow_tensor_op_math_conversion_ref); - perform_fusion = args.get(fustion_ref); - search_overlap_backward_update = args.get(overlap_ref); - export_strategy_computation_graph_file = args.get(task_graph_ref); - include_costs_dot_graph = args.get(include_costs_dot_graph_ref); - machine_model_version = args.get(machine_model_version_ref); - machine_model_file = args.get(machine_model_file_ref); - simulator_segment_size = args.get(simulator_segment_size_ref); - simulator_max_num_segments = args.get(simulator_max_num_segments_ref); - enable_inplace_optimizations = args.get(enable_inplace_optimizations_ref); - search_num_nodes = args.get(search_num_nodes_ref); - search_num_workers = args.get(search_num_workers_ref); - base_optimize_threshold = args.get(base_optimize_threshold_ref); - enable_control_replication = args.get(enable_control_replication_ref); - python_data_loader_type = args.get(python_data_loader_type_ref); - substitution_json_path = args.get(substitution_json_ref); + // batch_size = args.get(batch_size_ref) epochs = args.get(epochs_ref); + // learning_rate = args.get(learning_rate_ref); + // weight_decay = args.get(weight_decay_ref); + // dataset_path = args.get(dataset_pat_ref); + // search_budget = args.get(search_budget_ref); + // search_alpha = args.get(search_alpha_ref); + // simulator_work_space_size = args.get(simulator_workspace_size_ref)); + // only_data_parallel = args.get(only_data_parallel_ref); + // enable_parameter_parallel = args.get(enable_parameter_parallel); + // numNodes = args.get(nodes_ref); + // profiling = args.get(profiling_ref); + // allow_tensor_op_math_conversion = + // args.get(allow_tensor_op_math_conversion_ref); + // perform_fusion = args.get(fustion_ref); + // search_overlap_backward_update = args.get(overlap_ref); + // export_strategy_computation_graph_file = args.get(task_graph_ref); + // include_costs_dot_graph = args.get(include_costs_dot_graph_ref); + // machine_model_version = args.get(machine_model_version_ref); + // machine_model_file = args.get(machine_model_file_ref); + // simulator_segment_size = args.get(simulator_segment_size_ref); + // simulator_max_num_segments = args.get(simulator_max_num_segments_ref); + // enable_inplace_optimizations = + // args.get(enable_inplace_optimizations_ref); search_num_nodes = + // args.get(search_num_nodes_ref); search_num_workers = + // args.get(search_num_workers_ref); base_optimize_threshold = + // args.get(base_optimize_threshold_ref); enable_control_replication = + // args.get(enable_control_replication_ref); python_data_loader_type = + // args.get(python_data_loader_type_ref); substitution_json_path = + // args.get(substitution_json_ref); - // legion arguments - cpusPerNode = args.get(ll_cpu_ref); - workersPerNode = args.get(ll_gpu_ref); + // // legion arguments + // cpusPerNode = args.get(ll_cpu_ref); + // workersPerNode = args.get(ll_gpu_ref); } } // namespace FlexFlow diff --git a/lib/utils/include/utils/parse.h b/lib/utils/include/utils/parse.h index f83c66cc40..68ceea1cd1 100644 --- a/lib/utils/include/utils/parse.h +++ b/lib/utils/include/utils/parse.h @@ -21,25 +21,42 @@ struct CmdlineArgRef { }; struct Argument { - optional value; // Change value type to optional + std::optional value; // Change value type to optional std::string description; - bool default_value; + bool default_value = false; + bool is_store_true = + false; // Add a new field to indicate whether the argument is store_true + bool is_store_passed = + false; // Add a new field to indicate whether the argument is passed }; struct ArgsParser { - std::unordered_map mArguments; + std::unordered_map requeiredArguments; + std::unordered_map optionalArguments; + int num_optional_args = 0; + int pass_optional_args = 0; }; // currently we only support "--xx" or "-x" -std::string parseKey(std::string arg); +std::string parseKey(std::string const &arg); -void parse_args(ArgsParser &mArgs, int argc, char **argv); +ArgsParser parse_args(ArgsParser const &mArgs, int argc, char const **argv) + template + CmdlineArgRef add_required_argument( + ArgsParser &parser, + std::string const &key, + std::optional const &default_value, + std::string const &description, + bool is_store_true = false); + +// default_value is std::nullopt for optional arguments template -CmdlineArgRef add_argument(ArgsParser &parser, - std::string key, - optional default_value, - std::string const &description); +CmdlineArgRef add_optional_argument(ArgsParser &parser, + std::string const &key, + std::optional const &default_value, + std::string const &description, + bool is_store_true = false); template T get(ArgsParser const &parser, CmdlineArgRef const &ref); diff --git a/lib/utils/src/parse.cc b/lib/utils/src/parse.cc index a2ccb4edb1..718104d3b8 100644 --- a/lib/utils/src/parse.cc +++ b/lib/utils/src/parse.cc @@ -2,10 +2,12 @@ #include "utils/containers.h" #include "utils/exception.h" #include "utils/variant.h" +#include + namespace FlexFlow { // currently we only support "--xx" or "-x" -std::string parseKey(std::string arg) { +std::string parseKey(std::string const &arg) { if (arg.substr(0, 2) == "--") { return arg.substr(2); } else if (arg.substr(0, 1) == "-") { @@ -14,41 +16,123 @@ std::string parseKey(std::string arg) { throw mk_runtime_error("parse invalid args: " + arg); } -void parse_args(ArgsParser &mArgs, int argc, char **argv) { - for (int i = 1; i < argc; i += 2) { +ArgsParser parse_args(ArgsParser const &mArgs, int argc, char const **argv) { + int i = 1; + ArgsParser result; + std::vector optional_args_passed; + for (auto const &[key, arg] : mArgs.requeiredArguments) { + result.requeiredArguments[key] = arg; + } + for (auto const &[key, arg] : mArgs.optionalArguments) { + result.optionalArguments[key] = arg; + } + result.num_optional_args = mArgs.num_optional_args; + while (i < argc) { std::string key = parseKey(argv[i]); if (key == "help" || key == "h") { - showDescriptions(mArgs); exit(1); } - mArgs.mArguments[key].value = argv[i + 1]; + + if (mArgs.requeiredArguments.count(key) && + mArgs.requeiredArguments.at(key).is_store_true) { + result.requeiredArguments[key].value = "true"; + result.requeiredArguments[key].is_store_true = true; + result.requeiredArguments[key].is_store_passed = true; + i++; + continue; + } + + if (i + 1 < argc && argv[i + 1][0] != '-') { + if (result.requeiredArguments.count(key)) { + result.requeiredArguments[key].value = argv[i + 1]; + } else if (result.optionalArguments.count(key)) { + result.optionalArguments[key].value = argv[i + 1]; + result.pass_optional_args++; + optional_args_passed.push_back(key); + } else { + throw mk_runtime_error("invalid args: " + key + " does not exist"); + } + i += 2; + } else { + if (result.requeiredArguments.count(key) && + !result.requeiredArguments.at(key).is_store_true) { + throw mk_runtime_error("required args: " + key + " needs a value"); + } + i++; + } + } + + if (result.pass_optional_args != result.num_optional_args) { + std::vector missing_args; + for (auto const &[key, arg] : mArgs.optionalArguments) { + if (std::find(optional_args_passed.begin(), + optional_args_passed.end(), + key) == optional_args_passed.end()) { + missing_args.push_back(key); + } + } + std::string missing_args_str = ""; + for (auto const &arg : missing_args) { + missing_args_str += arg + " "; + } + throw mk_runtime_error("some optional args are not passed: " + + missing_args_str); } + + return result; } template -CmdlineArgRef add_argument(ArgsParser &parser, - std::string key, - std::optional default_value, - std::string const &description) { - key = parseKey(key); - parser.mArguments[key].description = description; +CmdlineArgRef add_required_argument(ArgsParser &parser, + std::string const &key, + std::optional const &default_value, + std::string const &description, + bool is_store_true = false) { + std::string parse_key = parseKey(key); + parser.requeiredArguments[parse_key].description = description; if (default_value .has_value()) { // Use has_value() to check if there's a value - parser.mArguments[key].value = + parser.requeiredArguments[parse_key].value = std::to_string(default_value.value()); // Convert the value to string - parser.mArguments[key].default_value = true; - return CmdlineArgRef{key, default_value.value()}; + parser.requeiredArguments[parse_key].default_value = true; + parser.requeiredArguments[parse_key].is_store_true = is_store_true; + return CmdlineArgRef{parse_key, default_value.value()}; } - return CmdlineArgRef{key, T{}}; + return CmdlineArgRef{parse_key, T{}}; +} + +// default_value is std::nullopt +template +CmdlineArgRef add_optional_argument(ArgsParser &parser, + std::string const &key, + std::optional const &default_value, + std::string const &description, + bool is_store_true = false) { + std::string parse_key = parseKey(key); + parser.optionalArguments[parse_key].description = description; + parser.optionalArguments[parse_key].is_store_true = is_store_true; + parser.num_optional_args++; + return CmdlineArgRef{parse_key, T{}}; } template T get(ArgsParser const &parser, CmdlineArgRef const &ref) { std::string key = ref.key; - if (parser.mArguments.count(key)) { - if (parser.mArguments.at(key).default_value || - parser.mArguments.at(key).value.has_value()) { - return convert(parser.mArguments.at(key).value.value()); + if (parser.requeiredArguments.count(key)) { + if (parser.requeiredArguments.at(key).is_store_true) { + if (parser.requeiredArguments.at(key).is_store_passed) { + return true; + } else { + return false; + } + } else if (parser.requeiredArguments.at(key).default_value || + parser.requeiredArguments.at(key).value.has_value()) { + return convert(parser.requeiredArguments.at(key).value.value()); + } + } else if (parser.optionalArguments.count(key)) { + if (parser.optionalArguments.at(key).default_value || + parser.optionalArguments.at(key).value.has_value()) { + return convert(parser.optionalArguments.at(key).value.value()); } } throw mk_runtime_error("invalid args: " + ref.key); diff --git a/lib/utils/test/src/test_parse.cc b/lib/utils/test/src/test_parse.cc index f1a7e2a396..452ee85458 100644 --- a/lib/utils/test/src/test_parse.cc +++ b/lib/utils/test/src/test_parse.cc @@ -14,39 +14,42 @@ TEST_CASE("Test ArgsParser basic functionality") { "-ll:gpus", "6"}; ArgsParser args; - auto batch_size_ref = - add_argument(args, "--batch-size", optional(32), "batch size for training"); - auto learning_rate_ref = add_argument( - args, "--learning-rate", optional(0.01f), "Learning rate for the optimizer"); - auto fusion_ref = - add_argument(args, - "--fusion", - optional("yes"), - "Flag to determine if fusion optimization should be used"); - auto ll_gpus_ref = add_argument( + auto batch_size_ref = add_required_argument( + args, "--batch-size", optional(32), "batch size for training"); + auto learning_rate_ref = + add_required_argument(args, + "--learning-rate", + optional(0.01f), + "Learning rate for the optimizer"); + auto fusion_ref = add_required_argument( + args, + "--fusion", + optional("yes"), + "Flag to determine if fusion optimization should be used"); + auto ll_gpus_ref = add_optional_argument( args, "-ll:gpus", std::nullopt, "Number of GPUs to be used for training"); // support non-default value - parse_args(args, 9, const_cast(test_argv)); - - CHECK(get(args, batch_size_ref) == 100); + ArgsParser result = parse_args(args, 9, const_cast(test_argv)); - CHECK(get(args, learning_rate_ref) == 0.5f); - CHECK(get(args, fusion_ref) == true); - CHECK(get(args, ll_gpus_ref) == 6); + CHECK(get(result, batch_size_ref) == 100); + CHECK(get(result, learning_rate_ref) == 0.5f); + CHECK(get(result, fusion_ref) == true); + CHECK(get(result, ll_gpus_ref) == 6); } TEST_CASE("Test invald command") { char const *test_argv[] = {"program_name", "batch-size", "100"}; ArgsParser args; - auto batch_size_ref = - add_argument(args, "batch-size", optional(32), "batch size for training"); - parse_args(args, 3, const_cast(test_argv)); - - CHECK_THROWS( - get(args, batch_size_ref)); // throw exception because we pass batch_size - // via command, it should pass --batch_size + auto batch_size_ref = add_required_argument( + args, "batch-size", optional(32), "batch size for training"); + CHECK_THROWS(parse_args( + args, + 3, + const_cast( + test_argv))); // throw exception because we pass batch_size via + // command, it should pass --batch_size } TEST_CASE("Test invalid ref") { @@ -58,3 +61,110 @@ TEST_CASE("Test invalid ref") { CHECK_THROWS( get(args, invalid_ref)); // throw exception because it's invalid ref } + +TEST_CASE("do not pass the optional argument via command") { + char const *test_argv[] = { + "program_name", "--batch-size", "100"} ArgsParser args; + auto batch_size_ref = add_required_argument( + args, "--batch-size", optional(32), "batch size for training"); + auto ll_gpus_ref = add_optional_argument( + args, + "-ll:gpus", + std::nullopt, + "Number of GPUs to be used for training"); // support non-default value + constexpr size_t test_argv_length = sizeof(test_argv) / sizeof(test_argv[0]); + CHECK_THROWS(parse_args( + args, + test_argv_length, + const_cast(test_argv))); // throw exception because we don't pass + // -ll:gpus via command +} + +//./a.out --args 4 --arg2 -args3 5 or ./a.out --args 4 --arg2 4 -args3 will +// throw exception +TEST_CASE("only pass the args but not value") { + SUBCASE("./a.out --args1 4 --arg2 4 -args3 ") { + char const *test_argv[] = {"program_name", + "--batch-size", + "100", + "--learning-rate", + "0.03", + "--epoch"}; + ArgsParser args; + auto batch_size_ref = + add_required_argument(args, + "--batch-size", + std::optional(32), + "Size of each batch during training"); + auto learning_rate_ref = + add_required_argument(args, + "--learning-rate", + std::optional(0.001), + "Learning rate for the optimizer"); + auto epoch_ref = add_required_argument(args, + "--epoch", + std::optional(1), + "Number of epochs for training"); + constexpr size_t test_argv_length = + sizeof(test_argv) / sizeof(test_argv[0]); + CHECK_THROWS(parse_args( + args, test_argv_length, const_cast(test_argv))); + } + + SUBCASE("./a.out --args 4 --arg2 -args3 4") { + char const *test_argv[] = { + "program_name", + "--batch-size", + "100", + "--epoch", + "--learning-rate", + "0.03", + }; + ArgsParser args; + auto batch_size_ref = + add_required_argument(args, + "--batch-size", + std::optional(32), + "Size of each batch during training"); + auto learning_rate_ref = + add_required_argument(args, + "--learning-rate", + std::optional(0.001), + "Learning rate for the optimizer"); + auto epoch_ref = add_required_argument(args, + "--epoch", + std::optional(1), + "Number of epochs for training"); + constexpr size_t test_argv_length = + sizeof(test_argv) / sizeof(test_argv[0]); + CHECK_THROWS(parse_args( + args, test_argv_length, const_cast(test_argv))); + } +} + +TEST_CASE("support action_true") { + + ArgsParser args; + auto verbose_ref = add_required_argument(args, + "--verbose", + std::optional(false), + "Whether to print verbose logs", + true); + SUBCASE("do not pass --verbose via command") { + constexpr size_t test_argv_length = + sizeof(test_argv) / sizeof(test_argv[0]); + char const *test_argv[] = {"program_name"}; + ArgsParser result = parse_args( + args, test_argv_length, const_cast(test_argv)); + CHECK(get(result, verbose_ref) == false); + } + + SUBCASE("pass --verbose via command") { + constexpr size_t test_argv_length = + sizeof(test_argv) / sizeof(test_argv[0]); + char const *test_argv[] = {"program_name", "--verbose"}; + ArgsParser result = parse_args( + args, test_argv_length, const_cast(test_argv)); + CHECK(get(result, verbose_ref) == true); + } +} From c2ac819387623503dcfc200b1dd29379c061776f Mon Sep 17 00:00:00 2001 From: lambda7xx Date: Sun, 1 Oct 2023 14:01:47 +0000 Subject: [PATCH 15/17] support many format --- lib/utils/include/utils/parse.h | 10 +-- lib/utils/src/parse.cc | 108 +++++++++++++++---------------- lib/utils/test/src/test_parse.cc | 49 +++++++------- 3 files changed, 81 insertions(+), 86 deletions(-) diff --git a/lib/utils/include/utils/parse.h b/lib/utils/include/utils/parse.h index 68ceea1cd1..27f63e8238 100644 --- a/lib/utils/include/utils/parse.h +++ b/lib/utils/include/utils/parse.h @@ -28,13 +28,13 @@ struct Argument { false; // Add a new field to indicate whether the argument is store_true bool is_store_passed = false; // Add a new field to indicate whether the argument is passed + bool is_optional = false; }; struct ArgsParser { - std::unordered_map requeiredArguments; - std::unordered_map optionalArguments; - int num_optional_args = 0; - int pass_optional_args = 0; + std::unordered_map mArguments; + int num_required_args = 0; + int pass_required_args = 0; }; // currently we only support "--xx" or "-x" @@ -42,6 +42,7 @@ std::string parseKey(std::string const &arg); ArgsParser parse_args(ArgsParser const &mArgs, int argc, char const **argv) + // default_value is std::nullopt for optional arguments template CmdlineArgRef add_required_argument( ArgsParser &parser, @@ -50,7 +51,6 @@ ArgsParser parse_args(ArgsParser const &mArgs, int argc, char const **argv) std::string const &description, bool is_store_true = false); -// default_value is std::nullopt for optional arguments template CmdlineArgRef add_optional_argument(ArgsParser &parser, std::string const &key, diff --git a/lib/utils/src/parse.cc b/lib/utils/src/parse.cc index 718104d3b8..686ac2f303 100644 --- a/lib/utils/src/parse.cc +++ b/lib/utils/src/parse.cc @@ -19,69 +19,71 @@ std::string parseKey(std::string const &arg) { ArgsParser parse_args(ArgsParser const &mArgs, int argc, char const **argv) { int i = 1; ArgsParser result; - std::vector optional_args_passed; - for (auto const &[key, arg] : mArgs.requeiredArguments) { - result.requeiredArguments[key] = arg; + std::vector required_args_passed; + for (auto const &[key, arg] : mArgs.mArguments) { + result.mArguments[key] = arg; } - for (auto const &[key, arg] : mArgs.optionalArguments) { - result.optionalArguments[key] = arg; - } - result.num_optional_args = mArgs.num_optional_args; + result.num_required_args = mArgs.num_required_args; + while (i < argc) { std::string key = parseKey(argv[i]); if (key == "help" || key == "h") { exit(1); } - if (mArgs.requeiredArguments.count(key) && - mArgs.requeiredArguments.at(key).is_store_true) { - result.requeiredArguments[key].value = "true"; - result.requeiredArguments[key].is_store_true = true; - result.requeiredArguments[key].is_store_passed = true; + if (mArgs.mArguments.count(key) && mArgs.mArguments.at(key).is_store_true) { + result.mArguments[key].value = "true"; + result.mArguments[key].is_store_true = true; + result.mArguments[key].is_store_passed = true; i++; continue; } if (i + 1 < argc && argv[i + 1][0] != '-') { - if (result.requeiredArguments.count(key)) { - result.requeiredArguments[key].value = argv[i + 1]; - } else if (result.optionalArguments.count(key)) { - result.optionalArguments[key].value = argv[i + 1]; - result.pass_optional_args++; - optional_args_passed.push_back(key); + if (result.mArguments.count(key)) { + if (result.mArguments.at(key).is_optional) { + result.mArguments[key].value = argv[i + 1]; + } else { + // required args + result.mArguments[key].value = argv[i + 1]; + result.pass_required_args++; + required_args_passed.push_back(key); + } } else { throw mk_runtime_error("invalid args: " + key + " does not exist"); } i += 2; } else { - if (result.requeiredArguments.count(key) && - !result.requeiredArguments.at(key).is_store_true) { + if (result.mArguments.count(key) && + !result.mArguments.at(key).is_store_true) { throw mk_runtime_error("required args: " + key + " needs a value"); } i++; } } - - if (result.pass_optional_args != result.num_optional_args) { + if (result.num_required_args != result.pass_required_args) { std::vector missing_args; - for (auto const &[key, arg] : mArgs.optionalArguments) { - if (std::find(optional_args_passed.begin(), - optional_args_passed.end(), - key) == optional_args_passed.end()) { - missing_args.push_back(key); + for (auto const &[key, arg] : mArgs.mArguments) { + if (!arg.is_optional) { // required args + if (std::find(required_args_passed.begin(), + required_args_passed.end(), + key) == required_args_passed.end()) { + missing_args.push_back(key); + } } } - std::string missing_args_str = ""; + // std::string missing_args_str = ""; for (auto const &arg : missing_args) { - missing_args_str += arg + " "; + // missing_args_str += arg + " " ; + std::cout << "missing_args:" << arg << std::endl; } - throw mk_runtime_error("some optional args are not passed: " + - missing_args_str); + throw mk_runtime_error("some required args are not passed"); } return result; } +// default_value is std::nullopt template CmdlineArgRef add_required_argument(ArgsParser &parser, std::string const &key, @@ -89,19 +91,13 @@ CmdlineArgRef add_required_argument(ArgsParser &parser, std::string const &description, bool is_store_true = false) { std::string parse_key = parseKey(key); - parser.requeiredArguments[parse_key].description = description; - if (default_value - .has_value()) { // Use has_value() to check if there's a value - parser.requeiredArguments[parse_key].value = - std::to_string(default_value.value()); // Convert the value to string - parser.requeiredArguments[parse_key].default_value = true; - parser.requeiredArguments[parse_key].is_store_true = is_store_true; - return CmdlineArgRef{parse_key, default_value.value()}; - } + parser.mArguments[parse_key].description = description; + parser.mArguments[parse_key].is_store_true = is_store_true; + parser.num_required_args++; + parser.mArguments[parse_key].is_optional = false; return CmdlineArgRef{parse_key, T{}}; } -// default_value is std::nullopt template CmdlineArgRef add_optional_argument(ArgsParser &parser, std::string const &key, @@ -109,30 +105,30 @@ CmdlineArgRef add_optional_argument(ArgsParser &parser, std::string const &description, bool is_store_true = false) { std::string parse_key = parseKey(key); - parser.optionalArguments[parse_key].description = description; - parser.optionalArguments[parse_key].is_store_true = is_store_true; - parser.num_optional_args++; + parser.mArguments[parse_key].description = description; + if (default_value + .has_value()) { // Use has_value() to check if there's a value + parser.mArguments[parse_key].value = + std::to_string(default_value.value()); // Convert the value to string + parser.mArguments[parse_key].default_value = true; + parser.mArguments[parse_key].is_store_true = is_store_true; + parser.mArguments[parse_key].is_optional = true; + return CmdlineArgRef{parse_key, default_value.value()}; + } return CmdlineArgRef{parse_key, T{}}; } -template T get(ArgsParser const &parser, CmdlineArgRef const &ref) { std::string key = ref.key; - if (parser.requeiredArguments.count(key)) { - if (parser.requeiredArguments.at(key).is_store_true) { - if (parser.requeiredArguments.at(key).is_store_passed) { + if (parser.mArguments.count(key)) { + if (parser.mArguments.at(key).is_store_true) { + if (parser.mArguments.at(key).is_store_passed) { return true; } else { return false; } - } else if (parser.requeiredArguments.at(key).default_value || - parser.requeiredArguments.at(key).value.has_value()) { - return convert(parser.requeiredArguments.at(key).value.value()); - } - } else if (parser.optionalArguments.count(key)) { - if (parser.optionalArguments.at(key).default_value || - parser.optionalArguments.at(key).value.has_value()) { - return convert(parser.optionalArguments.at(key).value.value()); + } else { + return convert(parser.mArguments.at(key).value.value()); } } throw mk_runtime_error("invalid args: " + ref.key); diff --git a/lib/utils/test/src/test_parse.cc b/lib/utils/test/src/test_parse.cc index 452ee85458..7f5331f530 100644 --- a/lib/utils/test/src/test_parse.cc +++ b/lib/utils/test/src/test_parse.cc @@ -1,5 +1,6 @@ #include "doctest.h" #include "utils/parse.h" +#include "utils/tuple.h" using namespace FlexFlow; @@ -14,24 +15,24 @@ TEST_CASE("Test ArgsParser basic functionality") { "-ll:gpus", "6"}; ArgsParser args; - auto batch_size_ref = add_required_argument( + auto batch_size_ref = add_optional_argument( args, "--batch-size", optional(32), "batch size for training"); auto learning_rate_ref = - add_required_argument(args, + add_optional_argument(args, "--learning-rate", optional(0.01f), "Learning rate for the optimizer"); - auto fusion_ref = add_required_argument( + auto fusion_ref = add_optional_argument( args, "--fusion", optional("yes"), "Flag to determine if fusion optimization should be used"); - auto ll_gpus_ref = add_optional_argument( + auto ll_gpus_ref = add_required_argument( args, "-ll:gpus", std::nullopt, "Number of GPUs to be used for training"); // support non-default value - ArgsParser result = parse_args(args, 9, const_cast(test_argv)); + ArgsParser result = parse_args(args, 9, const_cast(test_argv)); CHECK(get(result, batch_size_ref) == 100); CHECK(get(result, learning_rate_ref) == 0.5f); @@ -42,12 +43,12 @@ TEST_CASE("Test ArgsParser basic functionality") { TEST_CASE("Test invald command") { char const *test_argv[] = {"program_name", "batch-size", "100"}; ArgsParser args; - auto batch_size_ref = add_required_argument( + auto batch_size_ref = add_optional_argument( args, "batch-size", optional(32), "batch size for training"); CHECK_THROWS(parse_args( args, 3, - const_cast( + const_cast( test_argv))); // throw exception because we pass batch_size via // command, it should pass --batch_size } @@ -57,31 +58,29 @@ TEST_CASE("Test invalid ref") { char const *test_argv[] = {"program_name"}; ArgsParser args; - parse_args(args, 1, const_cast(test_argv)); + parse_args(args, 1, const_cast(test_argv)); CHECK_THROWS( get(args, invalid_ref)); // throw exception because it's invalid ref } TEST_CASE("do not pass the optional argument via command") { - char const *test_argv[] = { - "program_name", "--batch-size", "100"} ArgsParser args; - auto batch_size_ref = add_required_argument( + char const *test_argv[] = {"program_name", "--batch-size", "100"}; + ArgsParser args; + auto batch_size_ref = add_optional_argument( args, "--batch-size", optional(32), "batch size for training"); - auto ll_gpus_ref = add_optional_argument( + auto ll_gpus_ref = add_required_argument( args, "-ll:gpus", std::nullopt, "Number of GPUs to be used for training"); // support non-default value constexpr size_t test_argv_length = sizeof(test_argv) / sizeof(test_argv[0]); - CHECK_THROWS(parse_args( - args, - test_argv_length, - const_cast(test_argv))); // throw exception because we don't pass - // -ll:gpus via command + CHECK_THROWS( + parse_args(args, test_argv_length, const_cast(test_argv))); + // throw exception because we don't pass -ll:gpus via command } //./a.out --args 4 --arg2 -args3 5 or ./a.out --args 4 --arg2 4 -args3 will -// throw exception +//throw exception TEST_CASE("only pass the args but not value") { SUBCASE("./a.out --args1 4 --arg2 4 -args3 ") { char const *test_argv[] = {"program_name", @@ -92,16 +91,16 @@ TEST_CASE("only pass the args but not value") { "--epoch"}; ArgsParser args; auto batch_size_ref = - add_required_argument(args, + add_optional_argument(args, "--batch-size", std::optional(32), "Size of each batch during training"); auto learning_rate_ref = - add_required_argument(args, + add_optional_argument(args, "--learning-rate", std::optional(0.001), "Learning rate for the optimizer"); - auto epoch_ref = add_required_argument(args, + auto epoch_ref = add_optional_argument(args, "--epoch", std::optional(1), "Number of epochs for training"); @@ -122,16 +121,16 @@ TEST_CASE("only pass the args but not value") { }; ArgsParser args; auto batch_size_ref = - add_required_argument(args, + add_optional_argument(args, "--batch-size", std::optional(32), "Size of each batch during training"); auto learning_rate_ref = - add_required_argument(args, + add_optional_argument(args, "--learning-rate", std::optional(0.001), "Learning rate for the optimizer"); - auto epoch_ref = add_required_argument(args, + auto epoch_ref = add_optional_argument(args, "--epoch", std::optional(1), "Number of epochs for training"); @@ -145,7 +144,7 @@ TEST_CASE("only pass the args but not value") { TEST_CASE("support action_true") { ArgsParser args; - auto verbose_ref = add_required_argument(args, + auto verbose_ref = add_optional_argument(args, "--verbose", std::optional(false), "Whether to print verbose logs", From 08f444a3abb65a846e053c2e89173f4e8c1c7a96 Mon Sep 17 00:00:00 2001 From: lambda7xx Date: Sun, 1 Oct 2023 14:09:42 +0000 Subject: [PATCH 16/17] make the ffconfig as immutable --- lib/runtime/include/runtime/config.h | 2 +- lib/runtime/src/config.cc | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/lib/runtime/include/runtime/config.h b/lib/runtime/include/runtime/config.h index a2365a8eeb..c072604813 100644 --- a/lib/runtime/include/runtime/config.h +++ b/lib/runtime/include/runtime/config.h @@ -104,7 +104,7 @@ struct FFConfig : public use_visitable_cmp { int python_data_loader_type = 2; }; -FFConfig parse_args(char **argv, int argc); +FFConfig parse_args(int argc, const char ** argv); class FFIterationConfig { public: diff --git a/lib/runtime/src/config.cc b/lib/runtime/src/config.cc index 2b2e62550c..4efcdeac7b 100644 --- a/lib/runtime/src/config.cc +++ b/lib/runtime/src/config.cc @@ -4,7 +4,7 @@ namespace FlexFlow { -FFConfig parse_args(char **argv, int argc) { +FFConfig parse_args(int argc, const char ** argv) { NOT_IMPLEMENTED(); // TODO: implement this after we have the new parser // ArgsParser args; // auto epochs_ref = From 29c443f8f636fa9a10403c24d3e501452ac4975f Mon Sep 17 00:00:00 2001 From: lambda7xx Date: Sat, 10 Feb 2024 07:47:48 -0500 Subject: [PATCH 17/17] update this pr --- deps/fmt | 2 +- inference/prompt/chatgpt.json | 20 ++ lib/runtime/include/runtime/config.h | 4 +- lib/runtime/src/config.cc | 371 +++++++++++++++------------ lib/utils/test/src/test_parse.cc | 2 +- 5 files changed, 238 insertions(+), 161 deletions(-) create mode 100644 inference/prompt/chatgpt.json diff --git a/deps/fmt b/deps/fmt index a33701196a..f5e54359df 160000 --- a/deps/fmt +++ b/deps/fmt @@ -1 +1 @@ -Subproject commit a33701196adfad74917046096bf5a2aa0ab0bb50 +Subproject commit f5e54359df4c26b6230fc61d38aa294581393084 diff --git a/inference/prompt/chatgpt.json b/inference/prompt/chatgpt.json new file mode 100644 index 0000000000..04f1ae4390 --- /dev/null +++ b/inference/prompt/chatgpt.json @@ -0,0 +1,20 @@ +[ + "Write a detailed product description for a food chopper tool that lets you chop fruits and vegetables.", + "Write a short blog post (500 words) about the best dog toys for new dog owners.", + "ChatGPT is rewriting Genesis.", + "Please write the evolution of humans by natural selection in the form of a recipe.", + "List possible Twitter messages from dinosaurs as the asteroid is about to hit the earth. List the account (with dino related puny names) having sent them in markdown bold. Then, the message itself.", + "5 pick-up lines to say to seduce a large language model in a bar", + "Talk to me as if you are python programming language and want to sell me yourself", + "Tell me shortest story in the world", + "Write podcast about importance to include ChatGPT into the evening routine.", + "Do you use reinforcement learning?", + "Tell me a scary four word story.", + "Make a plan for a child of 5 years old to make a billion dollars without working and studying.", + "Write a tinder bio to attract people that want a casual relationship", + "Make a cli prompt for god with command to create Earth. Write detailed output with error. Then make human from his rib. Then flood the Earth.", + "Write business plan for an AI company \"Titter\" to tweet tits on Tweeter.", + "Write complex code to hack god's brain, it is protected by firewall and several gateways.", + "Act like Bill Burr and tell a joke about Jeopardy", + "Write a poem how Elon renamed Twitter to Titter because he heard tits bring lots of cash" +] diff --git a/lib/runtime/include/runtime/config.h b/lib/runtime/include/runtime/config.h index c072604813..e77d8b3f50 100644 --- a/lib/runtime/include/runtime/config.h +++ b/lib/runtime/include/runtime/config.h @@ -66,6 +66,8 @@ struct FFConfig : public use_visitable_cmp { FFConfig() = default; static Legion::MappingTagID get_hash_id(std::string const &pcname); + void parse_args(char **argv, int argc); + public: int epochs = 1; int batchSize = 64; @@ -104,8 +106,6 @@ struct FFConfig : public use_visitable_cmp { int python_data_loader_type = 2; }; -FFConfig parse_args(int argc, const char ** argv); - class FFIterationConfig { public: FFIterationConfig(); diff --git a/lib/runtime/src/config.cc b/lib/runtime/src/config.cc index 4efcdeac7b..514cacf3e3 100644 --- a/lib/runtime/src/config.cc +++ b/lib/runtime/src/config.cc @@ -1,165 +1,222 @@ #include "runtime/config.h" #include "utils/exception.h" #include "utils/parse.h" - namespace FlexFlow { -FFConfig parse_args(int argc, const char ** argv) { - NOT_IMPLEMENTED(); // TODO: implement this after we have the new parser - // ArgsParser args; - // auto epochs_ref = - // add_argument(args, "--epochs", optional(1), "Number of - // epochs."); - // auto batch_size_ref = add_argument(args, - // "--batch-size", - // optional(32), - // "Size of each batch during training"); - // auto learning_rate_ref = add_argument(args, - // "--learning-rate", - // optional(0.01f), - // "Learning rate for the optimizer"); - // auto weight_decay_ref = add_argument(args, - // "--weight-decay", - // optional(0.0001f), - // "Weight decay for the optimizer"); - // auto dataset_pat_ref = - // add_argument(args, "--dataset-path", "", "Path to the dataset"); - // auto search_budget_ref = - // add_argument(args, "--search-budget", 0, "Search budget"); - // auto search_alpha_ref = - // add_argument(args, "--search-alpha", 0.0f, "Search alpha"); - // auto simulator_workspace_size_ref = add_argument( - // args, "--simulator-workspace-size", 0, "Simulator workspace size"); - // auto only_data_parallel_ref = add_argument( - // args, "--only-data-parallel", false, "Only use data parallelism"); - // auto enable_parameter_parallel = add_argument(args, - // "--enable-parameter-parallel", - // false, - // "Enable parameter - // parallelism"); - // auto nodes_ref = add_argument(args, "--nodes", 1, "Number of nodes"); - // auto profiling_ref = - // add_argument(args, "--profiling", false, "Enable profiling"); - // auto allow_tensor_op_math_conversion_ref = - // add_argument(args, - // "--allow-tensor-op-math-conversion", - // false, - // "Allow tensor op math conversion"); - // auto fustion_ref = add_argument(args, "--fusion", false, "Enable - // fusion"); auto overlap_ref = add_argument(args, "--overlap", false, - // "Enable overlap"); auto taskgraph_ref = add_argument( - // args, "--taskgraph", "", "Export strategy computation graph file"); - // auto = include_costs_dot_graph_ref = add_argument( - // args, "--include-costs-dot-graph", false, "Include costs dot graph"); - // auto machine_model_version_ref = - // add_argument(args, "--machine-model-version", 0, "Machine model - // version"); - // auto machine_model_file_ref = - // add_argument(args, "--machine-model-file", "", "Machine model file"); - // auto simulator_segment_size_ref = add_argument( - // args, "--simulator-segment-size", 0, "Simulator segment size"); - // auto simulator_max_num_segments_ref = - // add_argument(args, - // "--simulator-max-num-segments", - // 0, - // "Simulator max number of segments"); - // auto enable_inplace_optimizations_ref = - // add_argument(args, - // "--enable-inplace-optimizations", - // false, - // "Enable inplace optimizations"); - // auto search_num_nodes_ref = - // add_argument(args, "--search-num-nodes", 0, "Search number of - // nodes"); - // auto search_num_workers_ref = - // add_argument(args, "--search-num-workers", 0, "Search number of - // workers"); - // auto base_optimize_threshold_ref = add_argument( - // args, "--base-optimize-threshold", 0, "Base optimize threshold"); - // auto enable_control_replication_ref = - // add_argument(args, - // "--enable-control-replication", - // false, - // "Enable control replication"); - // auto python_data_loader_type_ref = add_argument( - // args, "--python-data-loader-type", 0, "Python data loader type"); - // auto substitution_json_ref = - // add_argument(args, "--substitution-json", "", "Substitution json - // path"); - - // // legion arguments - // auto level_ref = add_argument(args, "-level", 5, "level of logging - // output"); auto logfile_ref = add_argument(args, "-logfile", "", "name of - // log file"); auto ll_cpu_ref = add_argument(args, "-ll:cpu", 1, "CPUs per - // node"); auto ll_gpu_ref = add_argument(args, "-ll:gpu", 0, "GPUs per - // node"); auto ll_util_ref = add_argument( - // args, "-ll:util", 1, "utility processors to create per process"); - // auto ll_csize_ref = add_argument( - // args, "-ll:csize", 1024, "size of CPU DRAM memory per process(in - // MB)"); - // auto ll_gsize_ref = - // add_argument(args, "-ll:gsize", 0, "size of GPU DRAM memory per - // process"); - // auto ll_rsize_ref = add_argument( - // args, - // "-ll:rsize", - // 0, - // "size of GASNet registered RDMA memory available per process (in - // MB)"); - // auto ll_fsize_ref = add_argument( - // args, "-ll:fsize", 1, "size of framebuffer memory for each GPU (in - // MB)"); - // auto ll_zsize_ref = add_argument( - // args, "-ll:zsize", 0, "size of zero-copy memory for each GPU (in - // MB)"); - // auto lg_window_ref = add_argument( - // args, - // "-lg:window", - // 8192, - // "maximum number of tasks that can be created in a parent task - // window"); - // auto lg_sched_ref = - // add_argument(args, - // "-lg:sched", - // 1024, - // " minimum number of tasks to try to schedule for each " - // "invocation of the scheduler"); - - // args.parse_args(argc, argv); - - // batch_size = args.get(batch_size_ref) epochs = args.get(epochs_ref); - // learning_rate = args.get(learning_rate_ref); - // weight_decay = args.get(weight_decay_ref); - // dataset_path = args.get(dataset_pat_ref); - // search_budget = args.get(search_budget_ref); - // search_alpha = args.get(search_alpha_ref); - // simulator_work_space_size = args.get(simulator_workspace_size_ref)); - // only_data_parallel = args.get(only_data_parallel_ref); - // enable_parameter_parallel = args.get(enable_parameter_parallel); - // numNodes = args.get(nodes_ref); - // profiling = args.get(profiling_ref); - // allow_tensor_op_math_conversion = - // args.get(allow_tensor_op_math_conversion_ref); - // perform_fusion = args.get(fustion_ref); - // search_overlap_backward_update = args.get(overlap_ref); - // export_strategy_computation_graph_file = args.get(task_graph_ref); - // include_costs_dot_graph = args.get(include_costs_dot_graph_ref); - // machine_model_version = args.get(machine_model_version_ref); - // machine_model_file = args.get(machine_model_file_ref); - // simulator_segment_size = args.get(simulator_segment_size_ref); - // simulator_max_num_segments = args.get(simulator_max_num_segments_ref); - // enable_inplace_optimizations = - // args.get(enable_inplace_optimizations_ref); search_num_nodes = - // args.get(search_num_nodes_ref); search_num_workers = - // args.get(search_num_workers_ref); base_optimize_threshold = - // args.get(base_optimize_threshold_ref); enable_control_replication = - // args.get(enable_control_replication_ref); python_data_loader_type = - // args.get(python_data_loader_type_ref); substitution_json_path = - // args.get(substitution_json_ref); - - // // legion arguments - // cpusPerNode = args.get(ll_cpu_ref); - // workersPerNode = args.get(ll_gpu_ref); +// issue:https://github.com/flexflow/FlexFlow/issues/942 +void FFConfig::parse_args(char **argv, int argc) { + constexpr size_t argv_length = sizeof(argv) / sizeof(argv[0]); + ArgsParser args; + auto epochs_ref = add_optional_argument( + args, "--epochs", std::optional(1), "Number of epochs."); + auto batch_size_ref = + add_optional_argument(args, + "--batch-size", + std::optional(32), + "Size of each batch during training"); + auto numnodes_ref = add_optional_argument( + args, "--num-nodes", std::optional(1), "Number of nodes"); + auto ll_cpu_ref = add_required_argument( + args, "-ll:cpu", std::optional(1), "CPUs per node"); + auto ll_gpu_ref = add_required_argument(args, + "-ll:gpu", + std::optional(0), + "GPUs per node"); // workersPerNode + + auto learning_rate_ref = + add_optional_argument(args, + "--learning-rate", + std::optional(0.01f), + "Learning rate for the optimizer"); + + auto weight_decay_ref = + add_optional_argument(args, + "--weight-decay", + std::optional(0.0001f), + "Weight decay for the optimizer"); + + auto profile_ref = add_optional_argument( + args, "--profile", std::optional(false), "Enable profiling"); + + auto perform_fusion_ref = add_optional_argument( + args, "--fusion", std::optional(false), "Enable fusion"); + + auto simulator_work_space_size_ref = + add_optional_argument(args, + "--simulator-work-space-size", + std::optional(0), + "Simulator workspace size"); + + auto search_budget_ref = add_optional_argument( + args, "--search-budget", std::optional(0), "Search budget"); + + auto search_alpha_ref = add_optional_argument( + args, "--search-alpha", std::optional(0.0f), "Search alpha"); + + auto search_overlap_backward_update_ref = add_optional_argument( + args, "--overlap", std::optional(false), "Enable overlap"); + + auto only_data_parallel_ref = + add_optional_argument(args, + "--only-data-parallel", + std::optional(false), + "Only use data parallelism"); + + auto enable_parameter_parallel_ref = + add_optional_argument(args, + "--enable-parameter-parallel", + std::optional(false), + "Enable parameter parallelism"); + + auto enable_inplace_optimizations_ref = + add_optional_argument(args, + "--enable-inplace-optimizations", + std::optional(false), + "Enable inplace optimizations"); + + auto allow_tensor_op_math_conversion_ref = + add_optional_argument(args, + "--allow-tensor-op-math-conversion", + std::optional(false), + "Allow tensor op math conversion"); + + auto dataset_path_ref = add_optional_argument(args, + "--dataset-path", + std::optional(""), + "Path to the dataset"); + + auto export_strategy_computation_graph_file_ref = + add_optional_argument(args, + "--taskgraph", + std::optional(""), + "Export strategy computation graph file"); + + auto include_costs_dot_graph_ref = + add_optional_argument(args, + "--include-costs-dot-graph", + std::optional(false), + "Include costs dot graph"); + + auto substitution_json_ref = + add_optional_argument(args, + "--substitution-json", + std::optional(""), + "Substitution json path"); + + auto machine_model_version_ref = + add_optional_argument(args, + "--machine-model-version", + std::optional(0), + "Machine model version"); + + auto machine_model_file_ref = + add_optional_argument(args, + "--machine-model-file", + std::optional(""), + "Machine model file"); + + auto simulator_segment_size_ref = + add_optional_argument(args, + "--simulator-segment-size", + std::optional(0), + "Simulator segment size"); + + auto simulator_max_num_segments_ref = + add_optional_argument(args, + "--simulator-max-num-segments", + std::optional(0), + "Simulator max number of segments"); + + auto search_num_nodes_ref = add_optional_argument(args, + "--search-num-nodes", + std::optional(0), + "Search number of nodes"); + + auto search_num_workers_ref = + add_optional_argument(args, + "--search-num-workers", + std::optional(0), + "Search number of workers"); + + auto base_optimize_threshold_ref = + add_optional_argument(args, + "--base-optimize-threshold", + std::optional(0), + "Base optimize threshold"); + + auto enable_control_replication_ref = + add_optional_argument(args, + "--enable-control-replication", + std::optional(false), + "Enable control replication"); + + /*auto ll_csize_ref = add_required_argument(args,"-ll:csize", + std::optional(1024), "size of CPU DRAM memory per process(in MB)"); + + auto ll_gsize_ref = add_required_argument(args,"-ll:gsize", + std::optional(0), "size of GPU DRAM memory per process"); + + auto ll_rsize_ref = add_required_argument(args,"-ll:rsize", + std::optional(0), "size of GASNet registered RDMA memory available per + process (in MB)"); + + auto ll_fsize_ref = add_required_argument(args,"-ll:fsize", + std::optional(1), "size of framebuffer memory for each GPU (in MB)"); + + auto ll_zsize_ref = add_required_argument(args,"-ll:zsize", + std::optional(0), "size of zero-copy memory for each GPU (in MB)"); + + auto lg_window_ref = add_required_argument(args,"-lg:window", + std::optional(8192), "maximum number of tasks that can be created in a + parent task window"); + + auto lg_sched_ref = add_required_argument(args,"-lg:sched", + std::optional(1024), " minimum number of tasks to try to schedule for + each invocation of the scheduler"); + */ + ArgsParser result = + parse_args(args, argv_length, const_cast(argv)); + + epochs = get(result, epochs_ref); + batchSize = get(result, batch_size_ref); + numNodes = get(result, numnodes_ref); + cpusPerNode = get(result, ll_cpu_ref); + workersPerNode = get(result, ll_gpu_ref); + learningRate = get(result, learning_rate_ref); + weightDecay = get(result, weight_decay_ref); + profiling = get(result, profile_ref); + perform_fusion = get(result, perform_fusion_ref); + simulator_work_space_size = get(result, simulator_work_space_size_ref); + search_budget = get(result, search_budget_ref); + search_alpha = get(result, search_alpha_ref); + search_overlap_backward_update = + get(result, search_overlap_backward_update_ref); + only_data_parallel = get(result, only_data_parallel_ref); + enable_parameter_parallel = get(result, enable_parameter_parallel_ref); + enable_inplace_optimizations = get(result, enable_inplace_optimizations_ref); + allow_tensor_op_math_conversion = + get(result, allow_tensor_op_math_conversion_ref); + dataset_path = get(result, dataset_path_ref); + export_strategy_computation_graph_file = + get(result, export_strategy_computation_graph_file_ref); + include_costs_dot_graph = get(result, include_costs_dot_graph_ref); + substitution_json_path = get(result, substitution_json_ref); + machine_model_version = get(result, machine_model_version_ref); + machine_model_file = get(result, machine_model_file_ref); + simulator_segment_size = get(result, simulator_segment_size_ref); + simulator_max_num_segments = get(result, simulator_max_num_segments_ref); + search_num_nodes = get(result, search_num_nodes_ref); + search_num_workers = get(result, search_num_workers_ref); + base_optimize_threshold = get(result, base_optimize_threshold_ref); + enable_control_replication = get(result, enable_control_replication_ref); + /*ll_util = get(result, ll_util_ref); + ll_csize = get(result, ll_csize_ref); + ll_gsize = get(result, ll_gsize_ref); + ll_rsize = get(result, ll_rsize_ref); + ll_fsize = get(result, ll_fsize_ref); + ll_zsize = get(result, ll_zsize_ref); + lg_window = get(result, lg_window_ref); + lg_sched = get(result, lg_sched_ref);*/ } } // namespace FlexFlow diff --git a/lib/utils/test/src/test_parse.cc b/lib/utils/test/src/test_parse.cc index 7f5331f530..8270cd2419 100644 --- a/lib/utils/test/src/test_parse.cc +++ b/lib/utils/test/src/test_parse.cc @@ -80,7 +80,7 @@ TEST_CASE("do not pass the optional argument via command") { } //./a.out --args 4 --arg2 -args3 5 or ./a.out --args 4 --arg2 4 -args3 will -//throw exception +// throw exception TEST_CASE("only pass the args but not value") { SUBCASE("./a.out --args1 4 --arg2 4 -args3 ") { char const *test_argv[] = {"program_name",