Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
172 changes: 144 additions & 28 deletions be/src/vec/common/cow.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
#include <boost/smart_ptr/intrusive_ref_counter.hpp>
#include <initializer_list>


/** Copy-on-write shared ptr.
* Allows to work with shared immutable objects and sometimes unshare and mutate you own unique copy.
*
Expand Down Expand Up @@ -92,36 +93,158 @@
* to use std::unique_ptr for it somehow.
*/
template <typename Derived>
class COW : public boost::intrusive_ref_counter<Derived> {
private:
class COW {
std::atomic_uint ref_counter;

protected:
COW() : ref_counter(0) {}

COW(COW const&) : ref_counter(0) {}

COW& operator=(COW const&) {
return *this;
}

unsigned int use_count() const {
return ref_counter.load();
}

void add_ref() {
++ref_counter;
}

void release_ref() {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Rethink here may cause mem leak? How boost::intrusive_ref_counter prevent the problem

if (--ref_counter == 0) {
delete static_cast<const Derived*>(this);
}
}

Derived* derived() { return static_cast<Derived*>(this); }

const Derived* derived() const { return static_cast<const Derived*>(this); }

template <typename T>
class IntrusivePtr : public boost::intrusive_ptr<T> {
class intrusive_ptr {
public:
using boost::intrusive_ptr<T>::intrusive_ptr;
intrusive_ptr() : t(nullptr) {}

intrusive_ptr(T* t, bool add_ref=true) : t(t) {
if (t && add_ref) ((std::remove_const_t<T>*)t)->add_ref();
}

template <typename U>
intrusive_ptr(intrusive_ptr<U> const& rhs) : t(rhs.get()) {
if (t) ((std::remove_const_t<T>*)t)->add_ref();
}

intrusive_ptr(intrusive_ptr const& rhs) : t(rhs.get()) {
if (t) ((std::remove_const_t<T>*)t)->add_ref();
}

~intrusive_ptr() {
if (t) ((std::remove_const_t<T>*)t)->release_ref();
}

template <typename U>
intrusive_ptr& operator=(intrusive_ptr<U> const& rhs) {
intrusive_ptr(rhs).swap(*this);
return *this;
}

intrusive_ptr(intrusive_ptr&& rhs) : t(rhs.t) {
rhs.t = nullptr;
}

intrusive_ptr& operator=(intrusive_ptr&& rhs) {
intrusive_ptr(static_cast<intrusive_ptr&&>(rhs)).swap(*this);
return *this;
}

template<class U> friend class intrusive_ptr;

template<class U>
intrusive_ptr(intrusive_ptr<U>&& rhs) : t(rhs.t) {
rhs.t = nullptr;
}

template<class U>
intrusive_ptr& operator=(intrusive_ptr<U>&& rhs) {
intrusive_ptr(static_cast<intrusive_ptr<U>&&>(rhs)).swap(*this);
return *this;
}

intrusive_ptr& operator=(intrusive_ptr const& rhs) {
intrusive_ptr(rhs).swap(*this);
return *this;
}

intrusive_ptr& operator=(T* rhs) {
intrusive_ptr(rhs).swap(*this);
return *this;
}

void reset() {
intrusive_ptr().swap(*this);
}

void reset(T* rhs) {
intrusive_ptr(rhs).swap(*this);
}

void reset(T* rhs, bool add_ref) {
intrusive_ptr(rhs, add_ref).swap(*this);
}

T* get() const {
return t;
}

T* detach() {
T* ret = t;
t = nullptr;
return ret;
}

void swap(intrusive_ptr& rhs) {
T* tmp = t;
t = rhs.t;
rhs.t = tmp;
}

T& operator*() const& {
return *t;
}

T& operator*() const& { return boost::intrusive_ptr<T>::operator*(); }
T&& operator*() const&& {
return const_cast<typename std::remove_const<T>::type&&>(
*boost::intrusive_ptr<T>::get());
return const_cast<std::remove_const_t<T>&&>(*t);
}

T* operator->() const {
return t;
}

operator bool() const {
return t != nullptr;
}

operator T*() const {
return t;
}

private:
T* t;
};

protected:
template <typename T>
class mutable_ptr : public IntrusivePtr<T> {
class mutable_ptr : public intrusive_ptr<T> {
private:
using Base = IntrusivePtr<T>;
using Base = intrusive_ptr<T>;

template <typename>
friend class COW;
template <typename, typename>
friend class COWHelper;
template <typename> friend class COW;
template <typename, typename> friend class COWHelper;

explicit mutable_ptr(T* ptr) : Base(ptr) {}

public:
/// Copy: not possible.
mutable_ptr(const mutable_ptr&) = delete;
Expand All @@ -144,17 +267,14 @@ class COW : public boost::intrusive_ref_counter<Derived> {

protected:
template <typename T>
class immutable_ptr : public IntrusivePtr<const T> {
class immutable_ptr : public intrusive_ptr<const T> {
private:
using Base = IntrusivePtr<const T>;
using Base = intrusive_ptr<const T>;

template <typename>
friend class COW;
template <typename, typename>
friend class COWHelper;
template <typename> friend class COW;
template <typename, typename> friend class COWHelper;

explicit immutable_ptr(const T* ptr) : Base(ptr) {}

public:
/// Copy from immutable ptr: ok.
immutable_ptr(const immutable_ptr&) = default;
Expand Down Expand Up @@ -198,8 +318,8 @@ class COW : public boost::intrusive_ref_counter<Derived> {
}

public:
Ptr get_ptr() const { return static_cast<Ptr>(derived()); }
MutablePtr get_ptr() { return static_cast<MutablePtr>(derived()); }
Ptr get_ptr() const { return Ptr(derived()); }
MutablePtr get_ptr() { return MutablePtr(derived()); }

protected:
MutablePtr shallow_mutate() const {
Expand Down Expand Up @@ -294,10 +414,6 @@ class COW : public boost::intrusive_ref_counter<Derived> {
*/
template <typename Base, typename Derived>
class COWHelper : public Base {
private:
Derived* derived() { return static_cast<Derived*>(this); }
const Derived* derived() const { return static_cast<const Derived*>(this); }

public:
using Ptr = typename Base::template immutable_ptr<Derived>;
using MutablePtr = typename Base::template mutable_ptr<Derived>;
Expand All @@ -313,7 +429,7 @@ class COWHelper : public Base {
}

typename Base::MutablePtr clone() const override {
return typename Base::MutablePtr(new Derived(*derived()));
return typename Base::MutablePtr(new Derived(static_cast<const Derived&>(*this)));
}

protected:
Expand Down
2 changes: 1 addition & 1 deletion be/src/vec/exec/vanalytic_eval_node.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -391,7 +391,7 @@ BlockRowPos VAnalyticEvalNode::_compare_row_to_find_end(int idx, BlockRowPos sta
}

//check whether need get column again, maybe same as first init
if (start_column != start_next_block_column) {
if (start_column.get() != start_next_block_column.get()) {
start_init_row_num = 0;
start.block_num = start_block_num;
start_column = _input_blocks[start.block_num].get_by_position(idx).column;
Expand Down
2 changes: 1 addition & 1 deletion be/src/vec/functions/function_cast.h
Original file line number Diff line number Diff line change
Expand Up @@ -1110,7 +1110,7 @@ class FunctionCast final : public IFunctionBase {
const auto& tmp_res = tmp_block.get_by_position(tmp_res_index);

/// May happen in fuzzy tests. For debug purpose.
if (!tmp_res.column) {
if (!tmp_res.column.get()) {
return Status::RuntimeError(fmt::format(
"Couldn't convert {} to {} in prepare_remove_nullable wrapper.",
block.get_by_position(arguments[0]).type->get_name(),
Expand Down