Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 47 additions & 21 deletions lib/utils/include/utils/graph/cow_ptr_t.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,20 @@

namespace FlexFlow {

template <typename T> struct cow_ptr_t {
cow_ptr_t() { this->set_unique(nullptr); }
cow_ptr_t(std::shared_ptr<T> ptr) { this->set_shared(std::move(ptr)); }
cow_ptr_t(std::unique_ptr<T> ptr) { this->set_unique(std::move(ptr)); }
cow_ptr_t(T const &val) { this->set_unique(make_unique<T>(val)); }
template <typename T>
struct cow_ptr_t {
cow_ptr_t() {
this->set_unique(nullptr);
}
cow_ptr_t(std::shared_ptr<T> ptr) {
this->set_shared(std::move(ptr));
}
cow_ptr_t(std::unique_ptr<T> ptr) {
this->set_unique(std::move(ptr));
}
cow_ptr_t(T const &val) {
this->set_unique(make_unique<T>(val));
}
cow_ptr_t(cow_ptr_t const &other) {
this->set_shared(other.get_shared_ptr());
}
Expand All @@ -24,7 +33,9 @@ template <typename T> struct cow_ptr_t {
using shared_t = std::shared_ptr<T const>;
using unique_t = std::unique_ptr<T>;

T const *get() const { return &this->ref(); }
T const *get() const {
return &this->ref();
}

T const &ref() const {
if (this->has_unique_access()) {
Expand All @@ -34,26 +45,35 @@ template <typename T> struct cow_ptr_t {
}
}

T const &operator*() const { return this->get(); }
T const &operator*() const {
return this->get();
}

T const *operator->() const { return &this->get(); }
T const *operator->() const {
return this->get();
}

std::shared_ptr<T const> get_shared_ptr() const {
if (this->has_unique_access()) {
this->set_shared(shared_ptr(this->get_unique()));
this->set_shared(shared_t(this->get_unique()));
}
return this->get_shared();
}

T *mutable_ptr() const {
if (this->has_unique_access()) {
return *this->get_unique();
return this->get_unique().get();
} else {
this->set_unique(unique_t(this->get_shared()->clone()));
}
auto shared = this->get_shared();
this->set_unique(unique_t(shared->clone()));
auto ptr = mpark::get_if<unique_t>(&this->ptr);
return ptr->get();
}
}

T &mutable_ref() const { return *this->mutable_ptr(); }
T &mutable_ref() const {
return *this->mutable_ptr();
}

bool has_unique_access() const {
return holds_alternative<unique_t>(this->ptr);
Expand All @@ -74,23 +94,29 @@ template <typename T> struct cow_ptr_t {
}

private:
void set_shared(shared_t ptr) {
this->ptr = variant<shared_t>(std::move(ptr));
void set_shared(shared_t ptr) const {
this->ptr =
variant<std::unique_ptr<T>, std::shared_ptr<T const>>(std::move(ptr));
}

void set_unique(unique_t ptr) {
this->ptr = variant<unique_t>(std::move(ptr));
void set_unique(std::unique_ptr<T> ptr) const {
this->ptr =
variant<std::unique_ptr<T>, std::shared_ptr<T const>>(std::move(ptr));
}

std::unique_ptr<T> &get_unique() const { return get<unique_t>(this->ptr); }
std::unique_ptr<T> get_unique() const {
auto ptr = mpark::get_if<unique_t>(&this->ptr);
return std::move(*ptr);
}

std::shared_ptr<T const> &get_shared() const {
return get<shared_t>(this->ptr);
std::shared_ptr<T const> get_shared() const {
auto ptr = mpark::get_if<shared_t>(&this->ptr);
return *ptr;
}

mutable variant<std::unique_ptr<T>, std::shared_ptr<T const>> ptr;
};

} // namespace FlexFlow

#endif
#endif