From ca2cfedd2b3f2be7df1c737f36d88c726a4d9a2f Mon Sep 17 00:00:00 2001 From: Ryan Ofsky Date: Wed, 24 Jul 2024 16:16:28 -0400 Subject: [PATCH] types: Add Custom{Build,Read,Pass}Message hooks Add CustomBuildMessage, CustomReadMessage, and CustomPassMessage hook functions. These functions can be defined to use custom code to convert C++ objects to and from Cap'n Proto messages. They work similarly to existing CustomBuildField, CustomReadField, and CustomPassField hooks, except they can be defined as normal functions not template functions, so they should be easier to use and require less verbosity in most cases, although they are less flexible. The unit tests added here are new but the feature was originally implemented in https://github.com/bitcoin/bitcoin/pull/10102 but and is being ported because it's a general purpose feature. --- include/mp/proxy-types.h | 75 ++++++++++++++++++++++++++++++++++++++++ test/mp/test/foo-types.h | 39 +++++++++++++++++++++ test/mp/test/foo.capnp | 10 ++++++ test/mp/test/foo.h | 12 +++++++ test/mp/test/test.cpp | 10 ++++++ 5 files changed, 146 insertions(+) diff --git a/include/mp/proxy-types.h b/include/mp/proxy-types.h index f81cfa03..f120268e 100644 --- a/include/mp/proxy-types.h +++ b/include/mp/proxy-types.h @@ -662,6 +662,19 @@ decltype(auto) CustomReadField(TypeList param, return read_dest.update([&](auto& value) { ReadOne<0>(param, invoke_context, input, value); }); } +//! Overload CustomReadField to serialize objects that have CustomReadMessage +//! overloads. Defining a CustomReadMessage overload is simpler than defining a +//! CustomReadField overload because it only requires defining a normal +//! function, not a template function, but less flexible. +template +decltype(auto) CustomReadField(TypeList, Priority<2>, InvokeContext& invoke_context, Reader&& reader, + ReadDest&& read_dest, + decltype(CustomReadMessage(invoke_context, reader.get(), + std::declval()))* enable = nullptr) +{ + return read_dest.update([&](auto& value) { if (reader.has()) CustomReadMessage(invoke_context, reader.get(), value); }); +} + template decltype(auto) ReadField(TypeList, Args&&... args) { @@ -719,6 +732,17 @@ bool CustomHasValue(InvokeContext& invoke_context, Values&&... value) return true; } +//! Overload CustomBuildField to serialize objects that have CustomBuildMessage +//! overloads. Defining a CustomBuildMessage overload is simpler than defining a +//! CustomBuildField overload because it only requires defining a normal +//! function, not a template function, but less flexible. +template +void CustomBuildField(TypeList, Priority<2>, InvokeContext& invoke_context, Value&& value, Output&& output, + decltype(CustomBuildMessage(invoke_context, value, std::move(output.get())))* enable = nullptr) +{ + CustomBuildMessage(invoke_context, value, std::move(output.init())); +} + template void BuildField(TypeList, Context& context, Output&& output, Values&&... values) { @@ -1389,10 +1413,61 @@ struct ServerExcept : Parent } }; +//! Helper for CustomPassField below. Call Accessor::get method if it has one, +//! otherwise return capnp::Void. +template +decltype(auto) MaybeGet(Message&& message, decltype(Accessor::get(message))* enable = nullptr) +{ + return Accessor::get(message); +} + +template +::capnp::Void MaybeGet(...) +{ + return {}; +} + +//! Helper for CustomPassField below. Call Accessor::init method if it has one, +//! otherwise do nothing. +template +decltype(auto) MaybeInit(Message&& message, decltype(Accessor::get(message))* enable = nullptr) +{ + return Accessor::init(message); +} + +template +::capnp::Void MaybeInit(...) +{ + return {}; +} + +//! Overload CustomPassField to serialize objects that have CustomPassMessage +//! overloads. Defining a CustomPassMessage overload is simpler than defining a +//! CustomPassField overload because it only requires defining a normal +//! function, not a template function, but less flexible. +template +auto CustomPassField(TypeList, ServerContext& server_context, Fn&& fn, Args&&... args) + -> decltype(CustomPassMessage(server_context, MaybeGet(server_context.call_context.getParams()), + MaybeGet(server_context.call_context.getResults()), nullptr)) +{ + CustomPassMessage(server_context, MaybeGet(server_context.call_context.getParams()), + MaybeInit(server_context.call_context.getResults()), + [&](LocalTypes... param) { fn.invoke(server_context, std::forward(args)..., param...); }); +} + template void CustomPassField(); //! PassField override calling CustomPassField function, if it exists. +//! Defining a CustomPassField or CustomPassMessage overload is useful for +//! input/output parameters. If an overload is not defined these parameters will +//! just be deserialized on the server side with ReadField into a temporary +//! variable, then the server method will be called passing the temporary +//! variable as a parameter, then the temporary variable will be serialized and +//! sent back to the client with BuildField. But if a PassField or PassMessage +//! overload is defined, the overload is called with a callback to invoke and +//! pass parameters to the server side function, and run arbitrary code before +//! and after invoking the function. template auto PassField(Priority<2>, Args&&... args) -> decltype(CustomPassField(std::forward(args)...)) { diff --git a/test/mp/test/foo-types.h b/test/mp/test/foo-types.h index 41961eb7..347be20b 100644 --- a/test/mp/test/foo-types.h +++ b/test/mp/test/foo-types.h @@ -28,6 +28,45 @@ decltype(auto) CustomReadField(TypeList, Priority<1>, InvokeContext& } } // namespace test + +inline void CustomBuildMessage(InvokeContext& invoke_context, + const test::FooMessage& src, + test::messages::FooMessage::Builder&& builder) +{ + builder.setMessage(src.message + " build"); +} + +inline void CustomReadMessage(InvokeContext& invoke_context, + const test::messages::FooMessage::Reader& reader, + test::FooMessage& dest) +{ + dest.message = std::string{reader.getMessage()} + " read"; +} + +inline void CustomBuildMessage(InvokeContext& invoke_context, + const test::FooMutable& src, + test::messages::FooMutable::Builder&& builder) +{ + builder.setMessage(src.message + " build"); +} + +inline void CustomReadMessage(InvokeContext& invoke_context, + const test::messages::FooMutable::Reader& reader, + test::FooMutable& dest) +{ + dest.message = std::string{reader.getMessage()} + " read"; +} + +inline void CustomPassMessage(InvokeContext& invoke_context, + const test::messages::FooMutable::Reader& reader, + test::messages::FooMutable::Builder builder, + std::function&& fn) +{ + test::FooMutable mut; + mut.message = std::string{reader.getMessage()} + " pass"; + fn(mut); + builder.setMessage(mut.message + " return"); +} } // namespace mp #endif // MP_TEST_FOO_TYPES_H diff --git a/test/mp/test/foo.capnp b/test/mp/test/foo.capnp index 9080ee3c..caa87ea8 100644 --- a/test/mp/test/foo.capnp +++ b/test/mp/test/foo.capnp @@ -25,6 +25,8 @@ interface FooInterface $Proxy.wrap("mp::test::FooImplementation") { callbackExtended @10 (context :Proxy.Context, callback :ExtendedCallback, arg: Int32) -> (result :Int32); passCustom @11 (arg :FooCustom) -> (result :FooCustom); passEmpty @12 (arg :FooEmpty) -> (result :FooEmpty); + passMessage @13 (arg :FooMessage) -> (result :FooMessage); + passMutable @14 (arg :FooMutable) -> (arg :FooMutable); } interface FooCallback $Proxy.wrap("mp::test::FooCallback") { @@ -50,6 +52,14 @@ struct FooCustom $Proxy.wrap("mp::test::FooCustom") { struct FooEmpty $Proxy.wrap("mp::test::FooEmpty") { } +struct FooMessage { + message @0 :Text; +} + +struct FooMutable { + message @0 :Text; +} + struct Pair(T1, T2) { first @0 :T1; second @1 :T2; diff --git a/test/mp/test/foo.h b/test/mp/test/foo.h index 43f4205b..a98fa19e 100644 --- a/test/mp/test/foo.h +++ b/test/mp/test/foo.h @@ -31,6 +31,16 @@ struct FooEmpty { }; +struct FooMessage +{ + std::string message; +}; + +struct FooMutable +{ + std::string message; +}; + class FooCallback { public: @@ -60,6 +70,8 @@ class FooImplementation int callbackExtended(ExtendedCallback& callback, int arg) { return callback.callExtended(arg); } FooCustom passCustom(FooCustom foo) { return foo; } FooEmpty passEmpty(FooEmpty foo) { return foo; } + FooMessage passMessage(FooMessage foo) { foo.message += " call"; return foo; } + void passMutable(FooMutable& foo) { foo.message += " call"; } std::shared_ptr m_callback; }; diff --git a/test/mp/test/test.cpp b/test/mp/test/test.cpp index ed845cd6..f58b0c69 100644 --- a/test/mp/test/test.cpp +++ b/test/mp/test/test.cpp @@ -108,6 +108,16 @@ KJ_TEST("Call FooInterface methods") foo->passEmpty(FooEmpty{}); + FooMessage message1; + message1.message = "init"; + FooMessage message2{foo->passMessage(message1)}; + KJ_EXPECT(message2.message == "init build read call build read"); + + FooMutable mut; + mut.message = "init"; + foo->passMutable(mut); + KJ_EXPECT(mut.message == "init build pass call return read"); + disconnect_client(); thread.join();