Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 23 additions & 2 deletions include/mp/proxy-io.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

#include <assert.h>
#include <functional>
#include <optional>
#include <map>
#include <memory>
#include <sstream>
Expand Down Expand Up @@ -60,6 +61,18 @@ struct ProxyClient<Thread> : public ProxyClientBase<Thread, ::capnp::Void>
using ProxyClientBase::ProxyClientBase;
// https://stackoverflow.com/questions/22357887/comparing-two-mapiterators-why-does-it-need-the-copy-constructor-of-stdpair
ProxyClient(const ProxyClient&) = delete;
~ProxyClient();

void setCleanup(std::function<void()> cleanup);

//! Cleanup function to run when the connection is closed. If the Connection
//! gets destroyed before this ProxyClient<Thread> object, this cleanup
//! callback lets it destroy this object and remove its entry in the
//! thread's request_threads or callback_threads map (after resetting
//! m_cleanup so the destructor does not try to access it). But if this
//! object gets destroyed before the Connection, there's no need to run the
//! cleanup function and the destructor will unregister it.
std::optional<CleanupIt> m_cleanup;
};

template <>
Expand Down Expand Up @@ -503,6 +516,14 @@ void ProxyServerBase<Interface, Impl>::invokeDestroy()
m_context.cleanup.clear();
}

using ConnThreads = std::map<Connection*, ProxyClient<Thread>>;
using ConnThread = ConnThreads::iterator;

// Retrieve ProxyClient<Thread> object associated with this connection from a
// map, or create a new one and insert it into the map. Return map iterator and
// inserted bool.
std::tuple<ConnThread, bool> SetThread(ConnThreads& threads, std::mutex& mutex, Connection* connection, std::function<Thread::Client()> make_thread);

struct ThreadContext
{
//! Identifying string for debug.
Expand All @@ -517,7 +538,7 @@ struct ThreadContext
//! `callbackThread` argument it passes in the request, used by the server
//! in case it needs to make callbacks into the client that need to execute
//! while the client is waiting. This will be set to a local thread object.
std::map<Connection*, ProxyClient<Thread>> callback_threads;
ConnThreads callback_threads;

//! When client is making a request to a server, this is the `thread`
//! argument it passes in the request, used to control which thread on
Expand All @@ -526,7 +547,7 @@ struct ThreadContext
//! by makeThread. If a client call is being made from a thread currently
//! handling a server request, this will be set to the `callbackThread`
//! request thread argument passed in that request.
std::map<Connection*, ProxyClient<Thread>> request_threads;
ConnThreads request_threads;

//! Whether this thread is a capnp event loop thread. Not really used except
//! to assert false if there's an attempt to execute a blocking operation
Expand Down
78 changes: 37 additions & 41 deletions include/mp/proxy-types.h
Original file line number Diff line number Diff line change
Expand Up @@ -62,36 +62,33 @@ void CustomBuildField(TypeList<>,
{
auto& connection = invoke_context.connection;
auto& thread_context = invoke_context.thread_context;
auto& request_threads = thread_context.request_threads;
auto& callback_threads = thread_context.callback_threads;

auto callback_thread = callback_threads.find(&connection);
if (callback_thread == callback_threads.end()) {
callback_thread =
callback_threads
.emplace(std::piecewise_construct, std::forward_as_tuple(&connection),
std::forward_as_tuple(
connection.m_threads.add(kj::heap<ProxyServer<Thread>>(thread_context, std::thread{})),
&connection, /* destroy_connection= */ false))
.first;
}

auto request_thread = request_threads.find(&connection);
if (request_thread == request_threads.end()) {
// This code will only run if IPC client call is being made for the
// first time on a new thread. After the first call, subsequent calls
// Create local Thread::Server object corresponding to the current thread
// and pass a Thread::Client reference to it in the Context.callbackThread
// field so the function being called can make callbacks to this thread.
// Also store the Thread::Client reference in the callback_threads map so
// future calls over this connection can reuse it.
auto [callback_thread, _]{SetThread(
thread_context.callback_threads, thread_context.waiter->m_mutex, &connection,
[&] { return connection.m_threads.add(kj::heap<ProxyServer<Thread>>(thread_context, std::thread{})); })};

// Call remote ThreadMap.makeThread function so server will create a
// dedicated worker thread to run function calls from this thread. Store the
// Thread::Client reference it returns in the request_threads map.
auto make_request_thread{[&]{
// This code will only run if an IPC client call is being made for the
// first time on this thread. After the first call, subsequent calls
// will use the existing request thread. This code will also never run at
// all if the current thread is a request thread created for a different
// IPC client, because in that case PassField code (below) will have set
// request_thread to point to the calling thread.
auto request = connection.m_thread_map.makeThreadRequest();
request.setName(thread_context.thread_name);
request_thread =
request_threads
.emplace(std::piecewise_construct, std::forward_as_tuple(&connection),
std::forward_as_tuple(request.send().getResult(), &connection, /* destroy_connection= */ false))
.first; // Nonblocking due to capnp request pipelining.
}
return request.send().getResult(); // Nonblocking due to capnp request pipelining.
}};
auto [request_thread, _1]{SetThread(
thread_context.request_threads, thread_context.waiter->m_mutex,
&connection, make_request_thread)};

auto context = output.init();
context.setThread(request_thread->second.m_client);
Expand Down Expand Up @@ -143,24 +140,23 @@ auto PassField(Priority<1>, TypeList<>, ServerContext& server_context, const Fn&
// call. In this case, the callbackThread value should point
// to the same thread already in the map, so there is no
// need to update the map.
auto& request_threads = g_thread_context.request_threads;
auto request_thread = request_threads.find(server.m_context.connection);
if (request_thread == request_threads.end()) {
request_thread =
g_thread_context.request_threads
.emplace(std::piecewise_construct, std::forward_as_tuple(server.m_context.connection),
std::forward_as_tuple(context_arg.getCallbackThread(), server.m_context.connection,
/* destroy_connection= */ false))
.first;
} else {
// The requests_threads map already has an entry for
// this connection, so this must be a recursive call.
// Avoid modifying the map in this case by resetting the
// request_thread iterator, so the KJ_DEFER statement
// below doesn't do anything.
request_thread = request_threads.end();
}
KJ_DEFER(if (request_thread != request_threads.end()) request_threads.erase(request_thread));
auto& thread_context = g_thread_context;
auto& request_threads = thread_context.request_threads;
auto [request_thread, inserted]{SetThread(
request_threads, thread_context.waiter->m_mutex,
server.m_context.connection,
[&] { return context_arg.getCallbackThread(); })};

// If an entry was inserted into the requests_threads map,
// remove it after calling fn.invoke. If an entry was not
// inserted, one already existed, meaning this must be a
// recursive call (IPC call calling back to the caller which
// makes another IPC call), so avoid modifying the map.
auto erase_thread{inserted ? request_thread : request_threads.end()};
KJ_DEFER(if (erase_thread != request_threads.end()) {
std::unique_lock<std::mutex> lock(thread_context.waiter->m_mutex);
request_threads.erase(erase_thread);
});
fn.invoke(server_context, args...);
}
KJ_IF_MAYBE(exception, kj::runCatchingExceptions([&]() {
Expand Down
34 changes: 34 additions & 0 deletions src/mp/proxy.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -262,6 +262,40 @@ void EventLoop::startAsyncThread(std::unique_lock<std::mutex>& lock)
}
}

std::tuple<ConnThread, bool> SetThread(ConnThreads& threads, std::mutex& mutex, Connection* connection, std::function<Thread::Client()> make_thread)
{
std::unique_lock<std::mutex> lock(mutex);
auto thread = threads.find(connection);
if (thread != threads.end()) return {thread, false};
thread = threads.emplace(
std::piecewise_construct, std::forward_as_tuple(connection),
std::forward_as_tuple(make_thread(), connection, /* destroy_connection= */ false)).first;
thread->second.setCleanup([&threads, &mutex, thread] {
// Connection is being destroyed before thread client is, so reset
// thread client m_cleanup member so thread client destructor does not
// try unregister this callback after connection is destroyed.
thread->second.m_cleanup.reset();
// Remove connection pointer about to be destroyed from the map
std::unique_lock<std::mutex> lock(mutex);
threads.erase(thread);
});
return {thread, true};
}

ProxyClient<Thread>::~ProxyClient()
{
if (m_cleanup) {
m_context.connection->removeSyncCleanup(*m_cleanup);
}
}

void ProxyClient<Thread>::setCleanup(std::function<void()> cleanup)
{
assert(cleanup);
assert(!m_cleanup);
m_cleanup = m_context.connection->addSyncCleanup(cleanup);
}

ProxyServer<Thread>::ProxyServer(ThreadContext& thread_context, std::thread&& thread)
: m_thread_context(thread_context), m_thread(std::move(thread))
{
Expand Down