diff --git a/include/mp/proxy-io.h b/include/mp/proxy-io.h index 6d48b140..cc8706a7 100644 --- a/include/mp/proxy-io.h +++ b/include/mp/proxy-io.h @@ -255,7 +255,11 @@ struct Waiter std::function m_fn; }; - +//! Object holding network & rpc state associated with either an incoming server +//! connection, or an outgoing client connection. It must be created and destroyed +//! on the event loop thread. +//! In addition to Cap'n Proto state, it also holds lists of callbacks to run +//! when the connection is closed. class Connection { public: @@ -297,8 +301,23 @@ class Connection //! disconnect() is called. void addAsyncCleanup(std::function fn); + //! Add disconnect handler. + template + void onDisconnect(F&& f) + { + // Add disconnect handler to local TaskSet to ensure it is cancelled and + // will never after connection object is destroyed. But when disconnect + // handler fires, do not call the function f right away, instead add it + // to the EventLoop TaskSet to avoid "Promise callback destroyed itself" + // error in cases where f deletes this Connection object. + m_on_disconnect.add(m_network.onDisconnect().then( + kj::mvCapture(f, [this](F&& f) { m_loop.m_task_set->add(kj::evalLater(kj::mv(f))); }))); + } + EventLoop& m_loop; kj::Own m_stream; + LoggingErrorHandler m_error_handler{m_loop}; + kj::TaskSet m_on_disconnect{m_error_handler}; ::capnp::TwoPartyVatNetwork m_network; ::capnp::RpcSystem<::capnp::rpc::twoparty::VatId> m_rpc_system; @@ -326,6 +345,30 @@ struct ServerVatId ServerVatId() { vat_id.setSide(::capnp::rpc::twoparty::Side::SERVER); } }; +template +ProxyClientBase::ProxyClientBase(typename Interface::Client client, + Connection* connection, + bool destroy_connection) + : m_client(std::move(client)), m_connection(connection), m_destroy_connection(destroy_connection) +{ + { + std::unique_lock lock(m_connection->m_loop.m_mutex); + m_connection->m_loop.addClient(lock); + } + m_cleanup = m_connection->addSyncCleanup([this]() { + // Release client capability by move-assigning to temporary. + { + typename Interface::Client(std::move(self().m_client)); + } + { + std::unique_lock lock(m_connection->m_loop.m_mutex); + m_connection->m_loop.removeClient(lock); + } + m_connection = nullptr; + }); + self().construct(); +} + template ProxyClientBase::~ProxyClientBase() noexcept { @@ -358,6 +401,11 @@ ProxyClientBase::~ProxyClientBase() noexcept std::unique_lock lock(m_connection->m_loop.m_mutex); m_connection->m_loop.removeClient(lock); } + + if (m_destroy_connection) { + delete m_connection; + m_connection = nullptr; + } }); } } diff --git a/include/mp/proxy-types.h b/include/mp/proxy-types.h index f1f6b26a..fd7532b9 100644 --- a/include/mp/proxy-types.h +++ b/include/mp/proxy-types.h @@ -69,7 +69,7 @@ void CustomBuildField(TypeList<>, .emplace(std::piecewise_construct, std::forward_as_tuple(&connection), std::forward_as_tuple( connection.m_threads.add(kj::heap>(thread_context, std::thread{})), - connection)) + &connection, /* destroy_connection= */ false)) .first; } @@ -83,10 +83,11 @@ void CustomBuildField(TypeList<>, // request_thread to point to the calling thread. auto request = connection.m_thread_map.makeThreadRequest(); request.setName(thread_context.thread_name); - request_thread = request_threads - .emplace(std::piecewise_construct, std::forward_as_tuple(&connection), - std::forward_as_tuple(request.send().getResult(), connection)) - .first; // Nonblocking due to capnp request pipelining. + request_thread = + request_threads + .emplace(std::piecewise_construct, std::forward_as_tuple(&connection), + std::forward_as_tuple(request.send().getResult(), &connection, /* destroy_connection= */ false)) + .first; // Nonblocking due to capnp request pipelining. } auto context = output.init(); @@ -126,7 +127,8 @@ auto PassField(TypeList<>, ServerContext& server_context, const Fn& fn, const Ar request_thread = g_thread_context.request_threads .emplace(std::piecewise_construct, std::forward_as_tuple(server.m_connection), - std::forward_as_tuple(context_arg.getCallbackThread(), *server.m_connection)) + std::forward_as_tuple(context_arg.getCallbackThread(), server.m_connection, + /* destroy_connection= */ false)) .first; } else { // If recursive call, avoid remove request_threads map @@ -170,29 +172,6 @@ auto PassField(TypeList<>, ServerContext& server_context, const Fn& fn, const Ar kj::mv(future.promise)); } - -template -ProxyClientBase::ProxyClientBase(typename Interface::Client client, Connection& connection) - : m_client(std::move(client)), m_connection(&connection) -{ - { - std::unique_lock lock(m_connection->m_loop.m_mutex); - m_connection->m_loop.addClient(lock); - } - m_cleanup = m_connection->addSyncCleanup([this]() { - // Release client capability by move-assigning to temporary. - { - typename Interface::Client(std::move(self().m_client)); - } - { - std::unique_lock lock(m_connection->m_loop.m_mutex); - m_connection->m_loop.removeClient(lock); - } - m_connection = nullptr; - }); - self().construct(); -} - template class Emplace { @@ -449,7 +428,8 @@ void ReadFieldUpdate(TypeList, template std::unique_ptr MakeProxyClient(InvokeContext& context, typename Interface::Client&& client) { - return std::make_unique>(std::move(client), context.connection); + return std::make_unique>( + std::move(client), &context.connection, /* destroy_connection= */ false); } template @@ -491,7 +471,8 @@ void ReadFieldNew(TypeList>, { if (input.has()) { using Interface = typename Decay::Calls; - auto client = std::make_shared>(input.get(), invoke_context.connection); + auto client = std::make_shared>( + input.get(), &invoke_context.connection, /* destroy_connection= */ false); emplace(ProxyCallFn{std::move(client)}); } }; diff --git a/include/mp/proxy.h b/include/mp/proxy.h index c69a6a7b..62f267ca 100644 --- a/include/mp/proxy.h +++ b/include/mp/proxy.h @@ -48,7 +48,7 @@ class ProxyClientBase : public Impl_ using Interface = Interface_; using Impl = Impl_; - ProxyClientBase(typename Interface::Client client, Connection& connection); + ProxyClientBase(typename Interface::Client client, Connection* connection, bool destroy_connection); ~ProxyClientBase() noexcept; // Methods called during client construction/destruction that can optionally @@ -60,7 +60,9 @@ class ProxyClientBase : public Impl_ typename Interface::Client m_client; Connection* m_connection; - CleanupIt m_cleanup; + bool m_destroy_connection; + CleanupIt m_cleanup; //!< Pointer to self-cleanup callback registered to handle connection object getting destroyed + //!< before this client object. }; //! Customizable (through template specialization) base class used in generated ProxyClient implementations from diff --git a/src/mp/test/test.cpp b/src/mp/test/test.cpp index 7ecb31a9..3a7d34cc 100644 --- a/src/mp/test/test.cpp +++ b/src/mp/test/test.cpp @@ -26,7 +26,7 @@ KJ_TEST("Call FooInterface methods") auto connection_client = std::make_unique(loop, kj::mv(pipe.ends[0]), true); auto foo_client = std::make_unique>( connection_client->m_rpc_system.bootstrap(ServerVatId().vat_id).castAs(), - *connection_client); + connection_client.get(), /* destroy_connection= */ false); foo_promise.set_value(std::move(foo_client)); disconnect_client = [&] { loop.sync([&] { connection_client.reset(); }); }; @@ -34,7 +34,7 @@ KJ_TEST("Call FooInterface methods") auto foo_server = kj::heap>(new FooImplementation, true, connection); return capnp::Capability::Client(kj::mv(foo_server)); }); - loop.m_task_set->add(connection_server->m_network.onDisconnect().then([&] { connection_server.reset(); })); + connection_server->onDisconnect([&] { connection_server.reset(); }); loop.loop(); });