From 392d8bcc4e9fc30c5c3074b80225f144efdfeadd Mon Sep 17 00:00:00 2001 From: Victor Li Date: Tue, 21 Jan 2025 13:01:46 -0800 Subject: [PATCH 1/2] Change OpCostMetrics.memory to be a nonnegative_int --- .../op_cost_metrics.struct.toml | 3 +- .../get_optimal_machine_mapping.cc | 8 +- ...get_optimal_machine_mapping_with_memory.cc | 18 +- .../machine_mapping_result_with_memory.cc | 20 +- .../utils/nonnegative_int/nonnegative_int.h | 69 +++++ .../utils/nonnegative_int/nonnegative_int.cc | 109 ++++++++ .../utils/nonnegative_int/nonnegative_int.cc | 263 ++++++++++++++++++ 7 files changed, 466 insertions(+), 24 deletions(-) create mode 100644 lib/utils/include/utils/nonnegative_int/nonnegative_int.h create mode 100644 lib/utils/src/utils/nonnegative_int/nonnegative_int.cc create mode 100644 lib/utils/test/src/utils/nonnegative_int/nonnegative_int.cc diff --git a/lib/compiler/include/compiler/cost_estimator/op_cost_metrics.struct.toml b/lib/compiler/include/compiler/cost_estimator/op_cost_metrics.struct.toml index f137935a4d..d2ff3f42e7 100644 --- a/lib/compiler/include/compiler/cost_estimator/op_cost_metrics.struct.toml +++ b/lib/compiler/include/compiler/cost_estimator/op_cost_metrics.struct.toml @@ -7,6 +7,7 @@ features = [ ] includes = [ + "utils/nonnegative_int/nonnegative_int.h" ] [[fields]] @@ -15,4 +16,4 @@ type = "float" [[fields]] name = "memory" -type = "size_t" +type = "::FlexFlow::nonnegative_int" diff --git a/lib/compiler/test/src/compiler/machine_mapping/get_optimal_machine_mapping.cc b/lib/compiler/test/src/compiler/machine_mapping/get_optimal_machine_mapping.cc index f5d5a5ee1b..ac180cd079 100644 --- a/lib/compiler/test/src/compiler/machine_mapping/get_optimal_machine_mapping.cc +++ b/lib/compiler/test/src/compiler/machine_mapping/get_optimal_machine_mapping.cc @@ -146,13 +146,13 @@ TEST_SUITE(FF_TEST_SUITE) { auto map1 = std::unordered_map{{ {map_unmapped_op_cost_estimate_key(k1, mv1), - OpCostMetrics{/*runtime=*/1.0, /*memory=*/0}}, + OpCostMetrics{/*runtime=*/1.0, /*memory=*/nonnegative_int{0}}}, {map_unmapped_op_cost_estimate_key(k2, mv1), - OpCostMetrics{/*runtime=*/2.0, /*memory=*/0}}, + OpCostMetrics{/*runtime=*/2.0, /*memory=*/nonnegative_int{0}}}, {map_unmapped_op_cost_estimate_key(k1, mv2), - OpCostMetrics{/*runtime=*/1.5, /*memory=*/0}}, + OpCostMetrics{/*runtime=*/1.5, /*memory=*/nonnegative_int{0}}}, {map_unmapped_op_cost_estimate_key(k2, mv2), - OpCostMetrics{/*runtime=*/2.5, /*memory=*/0}}, + OpCostMetrics{/*runtime=*/2.5, /*memory=*/nonnegative_int{0}}}, }}; CostEstimator cost_estimator = make_fake_cost_estimator( diff --git a/lib/compiler/test/src/compiler/machine_mapping/memory_optimization/get_optimal_machine_mapping_with_memory.cc b/lib/compiler/test/src/compiler/machine_mapping/memory_optimization/get_optimal_machine_mapping_with_memory.cc index 8761116be2..46662d8023 100644 --- a/lib/compiler/test/src/compiler/machine_mapping/memory_optimization/get_optimal_machine_mapping_with_memory.cc +++ b/lib/compiler/test/src/compiler/machine_mapping/memory_optimization/get_optimal_machine_mapping_with_memory.cc @@ -146,10 +146,10 @@ TEST_SUITE(FF_TEST_SUITE) { CostEstimator cost_estimator = make_fake_cost_estimator( std::unordered_map{{ - {map_unmapped_op_cost_estimate_key(k1, mv1), OpCostMetrics{1.0, 2}}, - {map_unmapped_op_cost_estimate_key(k2, mv1), OpCostMetrics{2.0, 3}}, - {map_unmapped_op_cost_estimate_key(k1, mv2), OpCostMetrics{1.5, 1}}, - {map_unmapped_op_cost_estimate_key(k2, mv2), OpCostMetrics{2.5, 2}}, + {map_unmapped_op_cost_estimate_key(k1, mv1), OpCostMetrics{1.0, nonnegative_int{2}}}, + {map_unmapped_op_cost_estimate_key(k2, mv1), OpCostMetrics{2.0, nonnegative_int{3}}}, + {map_unmapped_op_cost_estimate_key(k1, mv2), OpCostMetrics{1.5, nonnegative_int{1}}}, + {map_unmapped_op_cost_estimate_key(k2, mv2), OpCostMetrics{2.5, nonnegative_int{2}}}, }}, std::unordered_map{{ {TensorSetMovement{{}}, 0.0}, @@ -183,13 +183,13 @@ TEST_SUITE(FF_TEST_SUITE) { cache, context, problem_tree, full_machine_spec, constraints); MachineMappingWithMemoryResult correct = MachineMappingWithMemoryResult{{ MachineMappingForSingleLayer{ - OpCostMetrics{1.0, 2}, + OpCostMetrics{1.0, nonnegative_int{2}}, ParallelLayerGuidObliviousMachineMapping{{ {binary_tree_root_path(), mv1}, }}, }, MachineMappingForSingleLayer{ - OpCostMetrics{1.5, 1}, + OpCostMetrics{1.5, nonnegative_int{1}}, ParallelLayerGuidObliviousMachineMapping{{ {binary_tree_root_path(), mv2}, }}, @@ -214,7 +214,7 @@ TEST_SUITE(FF_TEST_SUITE) { MachineMappingForSingleLayer{ OpCostMetrics{ /*runtime=*/1.0 + 2.0 + 0.1, - /*memory=*/2 + 3, + /*memory=*/nonnegative_int{2 + 3}, }, ParallelLayerGuidObliviousMachineMapping{{ { @@ -232,7 +232,7 @@ TEST_SUITE(FF_TEST_SUITE) { }}, }, MachineMappingForSingleLayer{ - OpCostMetrics{1.5 + 2.5 + 0.1, 1 + 2}, + OpCostMetrics{1.5 + 2.5 + 0.1, nonnegative_int{1 + 2}}, ParallelLayerGuidObliviousMachineMapping{{ { BinaryTreePath{{ @@ -266,7 +266,7 @@ TEST_SUITE(FF_TEST_SUITE) { cache, context, problem_tree, full_machine_spec, constraints); MachineMappingWithMemoryResult correct = MachineMappingWithMemoryResult{{MachineMappingForSingleLayer{ - OpCostMetrics{2.5, 2}, + OpCostMetrics{2.5, nonnegative_int{2}}, ParallelLayerGuidObliviousMachineMapping{{ { BinaryTreePath{{ diff --git a/lib/compiler/test/src/compiler/machine_mapping/memory_optimization/machine_mapping_result_with_memory.cc b/lib/compiler/test/src/compiler/machine_mapping/memory_optimization/machine_mapping_result_with_memory.cc index a47d8713e9..ecfb7cfeb3 100644 --- a/lib/compiler/test/src/compiler/machine_mapping/memory_optimization/machine_mapping_result_with_memory.cc +++ b/lib/compiler/test/src/compiler/machine_mapping/memory_optimization/machine_mapping_result_with_memory.cc @@ -53,15 +53,15 @@ TEST_SUITE(FF_TEST_SUITE) { OpCostMetrics cost1 = OpCostMetrics{ /*runtime=*/2.0, - /*memory=*/2, + /*memory=*/nonnegative_int{2}, }; OpCostMetrics cost2 = OpCostMetrics{ /*runtime=*/4.0, - /*memory=*/1, + /*memory=*/nonnegative_int{1}, }; OpCostMetrics cost3 = OpCostMetrics{ /*runtime=*/2.0, - /*memory=*/3, + /*memory=*/nonnegative_int{3}, }; MachineMappingForSingleLayer mm1 = MachineMappingForSingleLayer{ @@ -183,7 +183,7 @@ TEST_SUITE(FF_TEST_SUITE) { OpCostMetrics pre_cost = OpCostMetrics{ /*runtime=*/2.0, - /*memory=*/2, + /*memory=*/nonnegative_int{2}, }; MachineMappingWithMemoryResult pre = MachineMappingWithMemoryResult{{ MachineMappingForSingleLayer{ @@ -209,7 +209,7 @@ TEST_SUITE(FF_TEST_SUITE) { OpCostMetrics post_cost = OpCostMetrics{ /*runtime=*/4.0, - /*memory=*/1, + /*memory=*/nonnegative_int{1}, }; MachineMappingWithMemoryResult post = MachineMappingWithMemoryResult{{ @@ -378,7 +378,7 @@ TEST_SUITE(FF_TEST_SUITE) { OpCostMetrics lhs_cost = OpCostMetrics{ /*runtime=*/2.0, - /*memory=*/2, + /*memory=*/nonnegative_int{2}, }; MachineMappingWithMemoryResult lhs = MachineMappingWithMemoryResult{{ MachineMappingForSingleLayer{ @@ -404,7 +404,7 @@ TEST_SUITE(FF_TEST_SUITE) { OpCostMetrics rhs_cost = OpCostMetrics{ /*runtime=*/4.0, - /*memory=*/1, + /*memory=*/nonnegative_int{1}, }; MachineMappingWithMemoryResult rhs = MachineMappingWithMemoryResult{{ MachineMappingForSingleLayer{ @@ -519,15 +519,15 @@ TEST_SUITE(FF_TEST_SUITE) { OpCostMetrics cost1 = OpCostMetrics{ /*runtime=*/2.0, - /*memory=*/2, + /*memory=*/nonnegative_int{2}, }; OpCostMetrics cost2 = OpCostMetrics{ /*runtime=*/4.0, - /*memory=*/1, + /*memory=*/nonnegative_int{1}, }; OpCostMetrics cost3 = OpCostMetrics{ /*runtime=*/2.0, - /*memory=*/3, + /*memory=*/nonnegative_int{3}, }; MachineMappingForSingleLayer mm1 = MachineMappingForSingleLayer{ diff --git a/lib/utils/include/utils/nonnegative_int/nonnegative_int.h b/lib/utils/include/utils/nonnegative_int/nonnegative_int.h new file mode 100644 index 0000000000..0749497c56 --- /dev/null +++ b/lib/utils/include/utils/nonnegative_int/nonnegative_int.h @@ -0,0 +1,69 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_NONNEGATIVE_INT_NONNEGATIVE_INT_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_NONNEGATIVE_INT_NONNEGATIVE_INT_H + +#include "rapidcheck.h" + +#include +#include +#include +#include +#include + +namespace FlexFlow { +class nonnegative_int { +public: + nonnegative_int() = delete; + explicit nonnegative_int(int value); + + explicit operator int() const noexcept; + + bool operator<(nonnegative_int const &other) const; + bool operator==(nonnegative_int const &other) const; + bool operator>(nonnegative_int const &other) const; + bool operator<=(nonnegative_int const &other) const; + bool operator!=(nonnegative_int const &other) const; + bool operator>=(nonnegative_int const &other) const; + + bool operator<(int const &other) const; + bool operator==(int const &other) const; + bool operator>(int const &other) const; + bool operator<=(int const &other) const; + bool operator!=(int const &other) const; + bool operator>=(int const &other) const; + + friend bool operator<(int const &lhs, nonnegative_int const &rhs); + friend bool operator==(int const &lhs, nonnegative_int const &rhs); + friend bool operator>(int const &lhs, nonnegative_int const &rhs); + friend bool operator<=(int const &lhs, nonnegative_int const &rhs); + friend bool operator!=(int const &lhs, nonnegative_int const &rhs); + friend bool operator>=(int const &lhs, nonnegative_int const &rhs); + + nonnegative_int operator+(nonnegative_int const &other) const; + + friend std::ostream &operator<<(std::ostream &os, nonnegative_int const &n); + + friend int format_as(nonnegative_int const &); + + int get_value() const; + +private: + int value_; +}; +} // namespace FlexFlow + +namespace nlohmann { +template <> +struct adl_serializer<::FlexFlow::nonnegative_int> { + static ::FlexFlow::nonnegative_int from_json(json const &j); + static void to_json(json &j, ::FlexFlow::nonnegative_int t); +}; +} // namespace nlohmann + +namespace std { +template <> +struct hash<::FlexFlow::nonnegative_int> { + std::size_t operator()(FlexFlow::nonnegative_int const &n) const noexcept; +}; +} // namespace std + +#endif diff --git a/lib/utils/src/utils/nonnegative_int/nonnegative_int.cc b/lib/utils/src/utils/nonnegative_int/nonnegative_int.cc new file mode 100644 index 0000000000..9088cc4bf9 --- /dev/null +++ b/lib/utils/src/utils/nonnegative_int/nonnegative_int.cc @@ -0,0 +1,109 @@ +#include "utils/nonnegative_int/nonnegative_int.h" + +namespace FlexFlow { + +nonnegative_int::nonnegative_int(int value) { + if (value < 0) { + throw std::invalid_argument( + "Value of nonnegative_int type must be nonnegative."); + } + this->value_ = value; +} + +nonnegative_int::operator int() const noexcept { + return this->value_; +} + +bool nonnegative_int::operator<(nonnegative_int const &other) const { + return this->value_ < other.value_; +} +bool nonnegative_int::operator==(nonnegative_int const &other) const { + return this->value_ == other.value_; +} +bool nonnegative_int::operator>(nonnegative_int const &other) const { + return this->value_ > other.value_; +} +bool nonnegative_int::operator<=(nonnegative_int const &other) const { + return this->value_ <= other.value_; +} +bool nonnegative_int::operator!=(nonnegative_int const &other) const { + return this->value_ != other.value_; +} +bool nonnegative_int::operator>=(nonnegative_int const &other) const { + return this->value_ >= other.value_; +} + +bool nonnegative_int::operator<(int const &other) const { + return this->value_ < other; +} +bool nonnegative_int::operator==(int const &other) const { + return this->value_ == other; +} +bool nonnegative_int::operator>(int const &other) const { + return this->value_ > other; +} +bool nonnegative_int::operator<=(int const &other) const { + return this->value_ <= other; +} +bool nonnegative_int::operator!=(int const &other) const { + return this->value_ != other; +} +bool nonnegative_int::operator>=(int const &other) const { + return this->value_ >= other; +} + +bool operator<(int const &lhs, nonnegative_int const &rhs) { + return lhs < rhs.value_; +} +bool operator==(int const &lhs, nonnegative_int const &rhs) { + return lhs == rhs.value_; +} +bool operator>(int const &lhs, nonnegative_int const &rhs) { + return lhs > rhs.value_; +} +bool operator<=(int const &lhs, nonnegative_int const &rhs) { + return lhs <= rhs.value_; +} +bool operator!=(int const &lhs, nonnegative_int const &rhs) { + return lhs != rhs.value_; +} +bool operator>=(int const &lhs, nonnegative_int const &rhs) { + return lhs >= rhs.value_; +} + +nonnegative_int nonnegative_int::operator+(nonnegative_int const &other) const { + return nonnegative_int{this->value_ + other.value_}; +} + +std::ostream &operator<<(std::ostream &os, nonnegative_int const &n) { + os << n.value_; + return os; +} + +int nonnegative_int::get_value() const { + return this->value_; +} + +int format_as(nonnegative_int const &x) { + return x.get_value(); +} +} // namespace FlexFlow + +namespace nlohmann { +::FlexFlow::nonnegative_int + adl_serializer<::FlexFlow::nonnegative_int>::from_json(json const &j) { + return ::FlexFlow::nonnegative_int{j.template get()}; +} + +void adl_serializer<::FlexFlow::nonnegative_int>::to_json( + json &j, ::FlexFlow::nonnegative_int t) { + j = t.get_value(); +} +} // namespace nlohmann + +namespace std { +std::size_t hash<::FlexFlow::nonnegative_int>::operator()( + FlexFlow::nonnegative_int const &n) const noexcept { + return std::hash{}(n.get_value()); +} +} // namespace std diff --git a/lib/utils/test/src/utils/nonnegative_int/nonnegative_int.cc b/lib/utils/test/src/utils/nonnegative_int/nonnegative_int.cc new file mode 100644 index 0000000000..8b8f0d430e --- /dev/null +++ b/lib/utils/test/src/utils/nonnegative_int/nonnegative_int.cc @@ -0,0 +1,263 @@ +#include "utils/nonnegative_int/nonnegative_int.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("nonnegative_int initialization") { + SUBCASE("positive int initialization") { + CHECK_NOTHROW(nonnegative_int{1}); + } + + SUBCASE("zero initialization") { + CHECK_NOTHROW(nonnegative_int{0}); + } + + SUBCASE("negative int initialization") { + CHECK_THROWS(nonnegative_int{-1}); + } + } + + TEST_CASE("nonnegative_int == comparisons") { + nonnegative_int nn_int_1a = nonnegative_int{1}; + nonnegative_int nn_int_1b = nonnegative_int{1}; + nonnegative_int nn_int_2 = nonnegative_int{2}; + SUBCASE("LHS: nonnegative_int, RHS: nonnegative_int, equal") { + CHECK(nn_int_1a == nn_int_1b); + } + SUBCASE("LHS: nonnegative_int, RHS: nonnegative_int, not equal") { + CHECK_FALSE(nn_int_1a == nn_int_2); + } + SUBCASE("LHS: nonnegative_int, RHS: int, equal") { + CHECK(nn_int_1a == 1); + } + SUBCASE("LHS: nonnegative_int, RHS: int, not equal") { + CHECK_FALSE(nn_int_1a == 2); + } + SUBCASE("LHS: int, RHS: nonnegative_int, equal") { + CHECK(1 == nn_int_1b); + } + SUBCASE("LHS: int, RHS: nonnegative_int, not equal") { + CHECK_FALSE(2 == nn_int_1b); + } + } + + TEST_CASE("nonnegative_int != comparisons") { + nonnegative_int nn_int_1a = nonnegative_int{1}; + nonnegative_int nn_int_1b = nonnegative_int{1}; + nonnegative_int nn_int_2 = nonnegative_int{2}; + SUBCASE("LHS: nonnegative_int, RHS: nonnegative_int, equal") { + CHECK_FALSE(nn_int_1a != nn_int_1b); + } + SUBCASE("LHS: nonnegative_int, RHS: nonnegative_int, not equal") { + CHECK(nn_int_1a != nn_int_2); + } + SUBCASE("LHS: nonnegative_int, RHS: int, equal") { + CHECK_FALSE(nn_int_1a != 1); + } + SUBCASE("LHS: nonnegative_int, RHS: int, not equal") { + CHECK(nn_int_1a != 2); + } + SUBCASE("LHS: int, RHS: nonnegative_int, equal") { + CHECK_FALSE(1 != nn_int_1b); + } + SUBCASE("LHS: int, RHS: nonnegative_int, not equal") { + CHECK(2 != nn_int_1b); + } + } + + TEST_CASE("nonnegative_int < comparisons") { + nonnegative_int nn_int_1a = nonnegative_int{1}; + nonnegative_int nn_int_1b = nonnegative_int{1}; + nonnegative_int nn_int_2 = nonnegative_int{2}; + SUBCASE("LHS: nonnegative_int, RHS: nonnegative_int, less than") { + CHECK(nn_int_1a < nn_int_2); + } + SUBCASE("LHS: nonnegative_int, RHS: nonnegative_int, equals") { + CHECK_FALSE(nn_int_1a < nn_int_1b); + } + SUBCASE("LHS: nonnegative_int, RHS: nonnegative_int, greater than") { + CHECK_FALSE(nn_int_2 < nn_int_1b); + } + SUBCASE("LHS: nonnegative_int, RHS: int, less than") { + CHECK(nn_int_1a < 2); + } + SUBCASE("LHS: nonnegative_int, RHS: int, equals") { + CHECK_FALSE(nn_int_1a < 1); + } + SUBCASE("LHS: nonnegative_int, RHS: int, greater than") { + CHECK_FALSE(nn_int_2 < 1); + } + SUBCASE("LHS: int, RHS: nonnegative_int, less than") { + CHECK(1 < nn_int_2); + } + SUBCASE("LHS: int, RHS: nonnegative_int, equals") { + CHECK_FALSE(1 < nn_int_1b); + } + SUBCASE("LHS: int, RHS: nonnegative_int, greater than") { + CHECK_FALSE(2 < nn_int_1b); + } + } + + TEST_CASE("nonnegative_int <= comparisons") { + nonnegative_int nn_int_1a = nonnegative_int{1}; + nonnegative_int nn_int_1b = nonnegative_int{1}; + nonnegative_int nn_int_2 = nonnegative_int{2}; + SUBCASE("LHS: nonnegative_int, RHS: nonnegative_int, less than") { + CHECK(nn_int_1a <= nn_int_2); + } + SUBCASE("LHS: nonnegative_int, RHS: nonnegative_int, equals") { + CHECK(nn_int_1a <= nn_int_1b); + } + SUBCASE("LHS: nonnegative_int, RHS: nonnegative_int, greater than") { + CHECK_FALSE(nn_int_2 <= nn_int_1b); + } + SUBCASE("LHS: nonnegative_int, RHS: int, less than") { + CHECK(nn_int_1a <= 2); + } + SUBCASE("LHS: nonnegative_int, RHS: int, equals") { + CHECK(nn_int_1a <= 1); + } + SUBCASE("LHS: nonnegative_int, RHS: int, greater than") { + CHECK_FALSE(nn_int_2 <= 1); + } + SUBCASE("LHS: int, RHS: nonnegative_int, less than") { + CHECK(1 <= nn_int_2); + } + SUBCASE("LHS: int, RHS: nonnegative_int, equals") { + CHECK(1 <= nn_int_1b); + } + SUBCASE("LHS: int, RHS: nonnegative_int, greater than") { + CHECK_FALSE(2 <= nn_int_1b); + } + } + + TEST_CASE("nonnegative_int > comparisons") { + nonnegative_int nn_int_1a = nonnegative_int{1}; + nonnegative_int nn_int_1b = nonnegative_int{1}; + nonnegative_int nn_int_2 = nonnegative_int{2}; + SUBCASE("LHS: nonnegative_int, RHS: nonnegative_int, less than") { + CHECK_FALSE(nn_int_1a > nn_int_2); + } + SUBCASE("LHS: nonnegative_int, RHS: nonnegative_int, equals") { + CHECK_FALSE(nn_int_1a > nn_int_1b); + } + SUBCASE("LHS: nonnegative_int, RHS: nonnegative_int, greater than") { + CHECK(nn_int_2 > nn_int_1b); + } + SUBCASE("LHS: nonnegative_int, RHS: int, less than") { + CHECK_FALSE(nn_int_1a > 2); + } + SUBCASE("LHS: nonnegative_int, RHS: int, equals") { + CHECK_FALSE(nn_int_1a > 1); + } + SUBCASE("LHS: nonnegative_int, RHS: int, greater than") { + CHECK(nn_int_2 > 1); + } + SUBCASE("LHS: int, RHS: nonnegative_int, less than") { + CHECK_FALSE(1 > nn_int_2); + } + SUBCASE("LHS: int, RHS: nonnegative_int, equals") { + CHECK_FALSE(1 > nn_int_1b); + } + SUBCASE("LHS: int, RHS: nonnegative_int, greater than") { + CHECK(2 > nn_int_1b); + } + } + + TEST_CASE("nonnegative_int >= comparisons") { + nonnegative_int nn_int_1a = nonnegative_int{1}; + nonnegative_int nn_int_1b = nonnegative_int{1}; + nonnegative_int nn_int_2 = nonnegative_int{2}; + SUBCASE("LHS: nonnegative_int, RHS: nonnegative_int, less than") { + CHECK_FALSE(nn_int_1a >= nn_int_2); + } + SUBCASE("LHS: nonnegative_int, RHS: nonnegative_int, equals") { + CHECK(nn_int_1a >= nn_int_1b); + } + SUBCASE("LHS: nonnegative_int, RHS: nonnegative_int, greater than") { + CHECK(nn_int_2 >= nn_int_1b); + } + SUBCASE("LHS: nonnegative_int, RHS: int, less than") { + CHECK_FALSE(nn_int_1a >= 2); + } + SUBCASE("LHS: nonnegative_int, RHS: int, equals") { + CHECK(nn_int_1a >= 1); + } + SUBCASE("LHS: nonnegative_int, RHS: int, greater than") { + CHECK(nn_int_2 >= 1); + } + SUBCASE("LHS: int, RHS: nonnegative_int, less than") { + CHECK_FALSE(1 >= nn_int_2); + } + SUBCASE("LHS: int, RHS: nonnegative_int, equals") { + CHECK(1 >= nn_int_1b); + } + SUBCASE("LHS: int, RHS: nonnegative_int, greater than") { + CHECK(2 >= nn_int_1b); + } + } + + TEST_CASE("nonnegative_int + operation") { + nonnegative_int nn_int_1a = nonnegative_int{1}; + nonnegative_int nn_int_1b = nonnegative_int{1}; + nonnegative_int nn_int_2 = nonnegative_int{2}; + SUBCASE("LHS: nonnegative_int, RHS: nonnegative_int") { + CHECK(nn_int_1a + nn_int_1b == nn_int_2); + } + } + + TEST_CASE("adl_serializer") { + SUBCASE("to_json") { + nonnegative_int input = nonnegative_int{5}; + + nlohmann::json result = input; + nlohmann::json correct = 5; + + CHECK(result == correct); + } + + SUBCASE("from_json") { + nlohmann::json input = 5; + + nonnegative_int result = input.template get(); + nonnegative_int correct = nonnegative_int{5}; + + CHECK(result == correct); + } + } + + TEST_CASE("std::hash") { + nonnegative_int nn_int_1a = nonnegative_int{1}; + nonnegative_int nn_int_1b = nonnegative_int{1}; + nonnegative_int nn_int_2 = nonnegative_int{2}; + std::hash hash_fn; + SUBCASE("Identical values have the same hash") { + CHECK(hash_fn(nn_int_1a) == hash_fn(nn_int_1b)); + } + SUBCASE("Different values have different hashes") { + CHECK(hash_fn(nn_int_1a) != hash_fn(nn_int_2)); + } + SUBCASE("Unordered set works with nonnegative_int") { + std::unordered_set<::FlexFlow::nonnegative_int> nonnegative_int_set; + nonnegative_int_set.insert(nn_int_1a); + nonnegative_int_set.insert(nn_int_1b); + nonnegative_int_set.insert(nn_int_2); + + CHECK(nonnegative_int_set.size() == 2); + } + } + + TEST_CASE("nonnegative int >> operator") { + nonnegative_int nn_int_1 = nonnegative_int{1}; + std::ostringstream oss; + oss << nn_int_1; + + CHECK(oss.str() == "1"); + } + + TEST_CASE("fmt::to_string(nonnegative_int)") { + nonnegative_int nn_int_1 = nonnegative_int{1}; + CHECK(fmt::to_string(nn_int_1) == "1"); + } +} From efc73f537b532c0f4f92553e2b45bc1f3bcc7228 Mon Sep 17 00:00:00 2001 From: Victor Li Date: Tue, 21 Jan 2025 13:01:46 -0800 Subject: [PATCH 2/2] Change OpCostMetrics.memory to be a nonnegative_int --- .../get_optimal_machine_mapping_with_memory.cc | 12 ++++++++---- .../include/utils/nonnegative_int/nonnegative_int.h | 1 + .../src/utils/nonnegative_int/nonnegative_int.cc | 2 +- 3 files changed, 10 insertions(+), 5 deletions(-) diff --git a/lib/compiler/test/src/compiler/machine_mapping/memory_optimization/get_optimal_machine_mapping_with_memory.cc b/lib/compiler/test/src/compiler/machine_mapping/memory_optimization/get_optimal_machine_mapping_with_memory.cc index 46662d8023..9706f1c75f 100644 --- a/lib/compiler/test/src/compiler/machine_mapping/memory_optimization/get_optimal_machine_mapping_with_memory.cc +++ b/lib/compiler/test/src/compiler/machine_mapping/memory_optimization/get_optimal_machine_mapping_with_memory.cc @@ -146,10 +146,14 @@ TEST_SUITE(FF_TEST_SUITE) { CostEstimator cost_estimator = make_fake_cost_estimator( std::unordered_map{{ - {map_unmapped_op_cost_estimate_key(k1, mv1), OpCostMetrics{1.0, nonnegative_int{2}}}, - {map_unmapped_op_cost_estimate_key(k2, mv1), OpCostMetrics{2.0, nonnegative_int{3}}}, - {map_unmapped_op_cost_estimate_key(k1, mv2), OpCostMetrics{1.5, nonnegative_int{1}}}, - {map_unmapped_op_cost_estimate_key(k2, mv2), OpCostMetrics{2.5, nonnegative_int{2}}}, + {map_unmapped_op_cost_estimate_key(k1, mv1), + OpCostMetrics{1.0, nonnegative_int{2}}}, + {map_unmapped_op_cost_estimate_key(k2, mv1), + OpCostMetrics{2.0, nonnegative_int{3}}}, + {map_unmapped_op_cost_estimate_key(k1, mv2), + OpCostMetrics{1.5, nonnegative_int{1}}}, + {map_unmapped_op_cost_estimate_key(k2, mv2), + OpCostMetrics{2.5, nonnegative_int{2}}}, }}, std::unordered_map{{ {TensorSetMovement{{}}, 0.0}, diff --git a/lib/utils/include/utils/nonnegative_int/nonnegative_int.h b/lib/utils/include/utils/nonnegative_int/nonnegative_int.h index e87bdbcfd9..0749497c56 100644 --- a/lib/utils/include/utils/nonnegative_int/nonnegative_int.h +++ b/lib/utils/include/utils/nonnegative_int/nonnegative_int.h @@ -39,6 +39,7 @@ class nonnegative_int { friend bool operator>=(int const &lhs, nonnegative_int const &rhs); nonnegative_int operator+(nonnegative_int const &other) const; + friend std::ostream &operator<<(std::ostream &os, nonnegative_int const &n); friend int format_as(nonnegative_int const &); diff --git a/lib/utils/test/src/utils/nonnegative_int/nonnegative_int.cc b/lib/utils/test/src/utils/nonnegative_int/nonnegative_int.cc index 8b8f0d430e..73d382d830 100644 --- a/lib/utils/test/src/utils/nonnegative_int/nonnegative_int.cc +++ b/lib/utils/test/src/utils/nonnegative_int/nonnegative_int.cc @@ -198,7 +198,7 @@ TEST_SUITE(FF_TEST_SUITE) { } } - TEST_CASE("nonnegative_int + operation") { + TEST_CASE("nonnegative_int + operation") { nonnegative_int nn_int_1a = nonnegative_int{1}; nonnegative_int nn_int_1b = nonnegative_int{1}; nonnegative_int nn_int_2 = nonnegative_int{2};