From 2878139737356591ef85a65648b11f41aedfe895 Mon Sep 17 00:00:00 2001 From: lambda7xx Date: Thu, 8 Jun 2023 08:55:28 +0000 Subject: [PATCH 1/2] add cow_ptr_t.h --- lib/utils/include/utils/graph/cow_ptr_t.h | 68 ++++++++++++++++------- 1 file changed, 48 insertions(+), 20 deletions(-) diff --git a/lib/utils/include/utils/graph/cow_ptr_t.h b/lib/utils/include/utils/graph/cow_ptr_t.h index 9ebe396778..3974fcf468 100644 --- a/lib/utils/include/utils/graph/cow_ptr_t.h +++ b/lib/utils/include/utils/graph/cow_ptr_t.h @@ -8,11 +8,20 @@ namespace FlexFlow { -template struct cow_ptr_t { - cow_ptr_t() { this->set_unique(nullptr); } - cow_ptr_t(std::shared_ptr ptr) { this->set_shared(std::move(ptr)); } - cow_ptr_t(std::unique_ptr ptr) { this->set_unique(std::move(ptr)); } - cow_ptr_t(T const &val) { this->set_unique(make_unique(val)); } +template +struct cow_ptr_t { + cow_ptr_t() { + this->set_unique(nullptr); + } + cow_ptr_t(std::shared_ptr ptr) { + this->set_shared(std::move(ptr)); + } + cow_ptr_t(std::unique_ptr ptr) { + this->set_unique(std::move(ptr)); + } + cow_ptr_t(T const &val) { + this->set_unique(make_unique(val)); + } cow_ptr_t(cow_ptr_t const &other) { this->set_shared(other.get_shared_ptr()); } @@ -24,7 +33,9 @@ template struct cow_ptr_t { using shared_t = std::shared_ptr; using unique_t = std::unique_ptr; - T const *get() const { return &this->ref(); } + T const *get() const { + return &this->ref(); + } T const &ref() const { if (this->has_unique_access()) { @@ -34,26 +45,37 @@ template struct cow_ptr_t { } } - T const &operator*() const { return this->get(); } + T const &operator*() const { + return this->get(); + } - T const *operator->() const { return &this->get(); } + T const *operator->() const { + return this->get(); + } std::shared_ptr get_shared_ptr() const { if (this->has_unique_access()) { - this->set_shared(shared_ptr(this->get_unique())); + this->set_shared(shared_t(this->get_unique())); } return this->get_shared(); } T *mutable_ptr() const { if (this->has_unique_access()) { - return *this->get_unique(); + return this->get_unique().get(); } else { - this->set_unique(unique_t(this->get_shared()->clone())); + auto shared = this->get_shared(); + this->set_unique(unique_t(shared->clone())); + if (auto ptr = mpark::get_if(&this->ptr)) { + return ptr->get(); + } + return nullptr; } } - T &mutable_ref() const { return *this->mutable_ptr(); } + T &mutable_ref() const { + return *this->mutable_ptr(); + } bool has_unique_access() const { return holds_alternative(this->ptr); @@ -74,18 +96,24 @@ template struct cow_ptr_t { } private: - void set_shared(shared_t ptr) { - this->ptr = variant(std::move(ptr)); + void set_shared(shared_t ptr) const { + this->ptr = + variant, std::shared_ptr>(std::move(ptr)); } - void set_unique(unique_t ptr) { - this->ptr = variant(std::move(ptr)); + void set_unique(std::unique_ptr ptr) const { + this->ptr = + variant, std::shared_ptr>(std::move(ptr)); } - std::unique_ptr &get_unique() const { return get(this->ptr); } + std::unique_ptr get_unique() const { + auto ptr = mpark::get_if(&this->ptr); + return std::move(*ptr); + } - std::shared_ptr &get_shared() const { - return get(this->ptr); + std::shared_ptr get_shared() const { + auto ptr = mpark::get_if(&this->ptr); + return *ptr; } mutable variant, std::shared_ptr> ptr; @@ -93,4 +121,4 @@ template struct cow_ptr_t { } // namespace FlexFlow -#endif +#endif \ No newline at end of file From 45729668078be7869426f4eba65f8bba57c02cdd Mon Sep 17 00:00:00 2001 From: lambda7xx Date: Thu, 8 Jun 2023 08:59:21 +0000 Subject: [PATCH 2/2] remove the nullptr for cow_ptr_t --- lib/utils/include/utils/graph/cow_ptr_t.h | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/lib/utils/include/utils/graph/cow_ptr_t.h b/lib/utils/include/utils/graph/cow_ptr_t.h index 3974fcf468..028b286af4 100644 --- a/lib/utils/include/utils/graph/cow_ptr_t.h +++ b/lib/utils/include/utils/graph/cow_ptr_t.h @@ -66,11 +66,9 @@ struct cow_ptr_t { } else { auto shared = this->get_shared(); this->set_unique(unique_t(shared->clone())); - if (auto ptr = mpark::get_if(&this->ptr)) { - return ptr->get(); - } - return nullptr; - } + auto ptr = mpark::get_if(&this->ptr); + return ptr->get(); + } } T &mutable_ref() const {