diff --git a/cpp/src/plasma/client.cc b/cpp/src/plasma/client.cc index 733217d0212..2a6f1836a31 100644 --- a/cpp/src/plasma/client.cc +++ b/cpp/src/plasma/client.cc @@ -23,10 +23,8 @@ #include #endif -#include #include #include -#include #include #include #include @@ -107,7 +105,7 @@ class ARROW_NO_EXPORT PlasmaBuffer : public Buffer { public: ~PlasmaBuffer(); - PlasmaBuffer(PlasmaClient::Impl* client, const ObjectID& object_id, + PlasmaBuffer(std::shared_ptr client, const ObjectID& object_id, const std::shared_ptr& buffer) : Buffer(buffer, 0, buffer->size()), client_(client), object_id_(object_id) { if (buffer->is_mutable()) { @@ -116,7 +114,7 @@ class ARROW_NO_EXPORT PlasmaBuffer : public Buffer { } private: - PlasmaClient::Impl* client_; + std::shared_ptr client_; ObjectID object_id_; }; @@ -155,7 +153,7 @@ struct ClientMmapTableEntry { int count; }; -class PlasmaClient::Impl { +class PlasmaClient::Impl : public std::enable_shared_from_this { public: Impl(); ~Impl(); @@ -558,7 +556,7 @@ Status PlasmaClient::Impl::Get(const std::vector& object_ids, int64_t timeout_ms, std::vector* out) { const auto wrap_buffer = [=](const ObjectID& object_id, const std::shared_ptr& buffer) { - return std::make_shared(this, object_id, buffer); + return std::make_shared(shared_from_this(), object_id, buffer); }; const size_t num_objects = object_ids.size(); *out = std::vector(num_objects); diff --git a/cpp/src/plasma/format/plasma.fbs b/cpp/src/plasma/format/plasma.fbs index 71e6f5c19fa..0258cdff331 100644 --- a/cpp/src/plasma/format/plasma.fbs +++ b/cpp/src/plasma/format/plasma.fbs @@ -216,7 +216,7 @@ table PlasmaStatusRequest { enum ObjectStatus:int { // Object is stored in the local Plasma Store. - Local = 1, + Local, // Object is stored on a remote Plasma store, and it is not stored on the // local Plasma Store. Remote, diff --git a/python/pyarrow/tests/test_plasma.py b/python/pyarrow/tests/test_plasma.py index 32a4ed338f5..1589c1bb304 100644 --- a/python/pyarrow/tests/test_plasma.py +++ b/python/pyarrow/tests/test_plasma.py @@ -840,3 +840,19 @@ def test_use_huge_pages(): use_hugepages=True) as (plasma_store_name, p): plasma_client = plasma.connect(plasma_store_name, "", 64) create_object(plasma_client, 100000000) + + +# This is checking to make sure plasma_clients cannot be destroyed +# before all the PlasmaBuffers that have handles to them are +# destroyed, see ARROW-2448. +@pytest.mark.plasma +def test_plasma_client_sharing(): + import pyarrow.plasma as plasma + + with start_plasma_store() as (plasma_store_name, p): + plasma_client = plasma.connect(plasma_store_name, "", 64) + object_id = plasma_client.put(np.zeros(3)) + buf = plasma_client.get(object_id) + del plasma_client + assert (buf == np.zeros(3)).all() + del buf # This segfaulted pre ARROW-2448.