From 4697c137257ac0e995498b82735fc1f87038a9dd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C3=81kos=20Vandra-Meyer?= Date: Sun, 16 Nov 2025 14:50:04 +0100 Subject: [PATCH 01/23] make context ref mut --- example-service/src/client.rs | 7 +- example-service/src/server.rs | 2 +- plugins/src/lib.rs | 8 +-- plugins/tests/service.rs | 14 ++-- tarpc/examples/compression.rs | 4 +- tarpc/examples/custom_transport.rs | 4 +- tarpc/examples/pubsub.rs | 23 ++++--- tarpc/examples/readme.rs | 4 +- tarpc/examples/tls_over_tcp.rs | 4 +- tarpc/examples/tracing.rs | 10 +-- tarpc/src/client.rs | 9 ++- tarpc/src/client/stub.rs | 6 +- tarpc/src/client/stub/load_balance.rs | 10 +-- tarpc/src/client/stub/mock.rs | 2 +- tarpc/src/client/stub/retry.rs | 2 +- tarpc/src/context.rs | 2 +- tarpc/src/lib.rs | 9 +-- tarpc/src/server.rs | 64 ++++++++++--------- tarpc/src/server/incoming.rs | 5 +- tarpc/src/server/request_hook.rs | 22 ++++--- tarpc/src/server/request_hook/after.rs | 4 +- tarpc/src/server/request_hook/before.rs | 16 +++-- .../server/request_hook/before_and_after.rs | 6 +- tarpc/src/transport/channel.rs | 6 +- tarpc/tests/dataservice.rs | 4 +- tarpc/tests/service_functional.rs | 52 +++++++++------ 26 files changed, 164 insertions(+), 135 deletions(-) diff --git a/example-service/src/client.rs b/example-service/src/client.rs index 8a4ff72eb..c73122c07 100644 --- a/example-service/src/client.rs +++ b/example-service/src/client.rs @@ -34,10 +34,13 @@ async fn main() -> anyhow::Result<()> { let client = WorldClient::new(client::Config::default(), transport.await?).spawn(); let hello = async move { + let mut context = context::current(); + let mut context2 = context::current(); + // Send the request twice, just to be safe! ;) tokio::select! { - hello1 = client.hello(context::current(), format!("{}1", flags.name)) => { hello1 } - hello2 = client.hello(context::current(), format!("{}2", flags.name)) => { hello2 } + hello1 = client.hello(&mut context, format!("{}1", flags.name)) => { hello1 } + hello2 = client.hello(&mut context2, format!("{}2", flags.name)) => { hello2 } } } .instrument(tracing::info_span!("Two Hellos")) diff --git a/example-service/src/server.rs b/example-service/src/server.rs index 896280c3d..1efab549d 100644 --- a/example-service/src/server.rs +++ b/example-service/src/server.rs @@ -35,7 +35,7 @@ struct Flags { struct HelloServer(SocketAddr); impl World for HelloServer { - async fn hello(self, _: context::Context, name: String) -> String { + async fn hello(self, _: &mut context::Context, name: String) -> String { let sleep_time = Duration::from_millis(Uniform::new_inclusive(1, 10).sample(&mut thread_rng())); time::sleep(sleep_time).await; diff --git a/plugins/src/lib.rs b/plugins/src/lib.rs index da6443edf..55ec2730e 100644 --- a/plugins/src/lib.rs +++ b/plugins/src/lib.rs @@ -401,7 +401,7 @@ fn collect_cfg_attrs(rpcs: &[RpcMethod]) -> Vec> { /// #[derive(Clone)] /// struct CalculatorServer; /// impl Calculator for CalculatorServer { -/// async fn add(self, context: Context, a: i32, b: i32) -> i32 { +/// async fn add(self, context: &mut Context, a: i32, b: i32) -> i32 { /// a + b /// } /// } @@ -558,7 +558,7 @@ impl ServiceGenerator<'_> { )| { quote! { #( #attrs )* - async fn #ident(self, context: ::tarpc::context::Context, #( #args ),*) -> #output; + async fn #ident(self, context: &mut ::tarpc::context::Context, #( #args ),*) -> #output; } }, ); @@ -622,7 +622,7 @@ impl ServiceGenerator<'_> { type Resp = #response_ident; - async fn serve(self, ctx: ::tarpc::context::Context, req: #request_ident) + async fn serve(self, ctx: &mut ::tarpc::context::Context, req: #request_ident) -> ::core::result::Result<#response_ident, ::tarpc::ServerError> { match req { #( @@ -786,7 +786,7 @@ impl ServiceGenerator<'_> { #( #[allow(unused)] #( #method_attrs )* - #vis fn #method_idents(&self, ctx: ::tarpc::context::Context, #( #args ),*) + #vis fn #method_idents<'a>(&'a self, ctx: &'a mut ::tarpc::context::Context, #( #args ),*) -> impl ::core::future::Future> + '_ { let request = #request_ident::#camel_case_idents { #( #arg_pats ),* }; let resp = self.0.call(ctx, request); diff --git a/plugins/tests/service.rs b/plugins/tests/service.rs index 26ee1ec39..d38492bd7 100644 --- a/plugins/tests/service.rs +++ b/plugins/tests/service.rs @@ -12,15 +12,15 @@ fn att_service_trait() { } impl Foo for () { - async fn two_part(self, _: context::Context, s: String, i: i32) -> (String, i32) { + async fn two_part(self, _: &mut context::Context, s: String, i: i32) -> (String, i32) { (s, i) } - async fn bar(self, _: context::Context, s: String) -> String { + async fn bar(self, _: &mut context::Context, s: String) -> String { s } - async fn baz(self, _: context::Context) {} + async fn baz(self, _: &mut context::Context) {} } } @@ -39,18 +39,18 @@ fn raw_idents() { impl r#trait for () { async fn r#await( self, - _: context::Context, + _: &mut context::Context, r#struct: r#yield, r#enum: i32, ) -> (r#yield, i32) { (r#struct, r#enum) } - async fn r#fn(self, _: context::Context, r#impl: r#yield) -> r#yield { + async fn r#fn(self, _: &mut context::Context, r#impl: r#yield) -> r#yield { r#impl } - async fn r#async(self, _: context::Context) {} + async fn r#async(self, _: &mut context::Context) {} } } @@ -64,7 +64,7 @@ fn service_with_cfg_rpc() { } impl Foo for () { - async fn foo(self, _: context::Context) {} + async fn foo(self, _: &mut context::Context) {} } } diff --git a/tarpc/examples/compression.rs b/tarpc/examples/compression.rs index d66261d19..783f2618a 100644 --- a/tarpc/examples/compression.rs +++ b/tarpc/examples/compression.rs @@ -108,7 +108,7 @@ pub trait World { struct HelloServer; impl World for HelloServer { - async fn hello(self, _: context::Context, name: String) -> String { + async fn hello(self, _: &mut context::Context, name: String) -> String { format!("Hey, {name}!") } } @@ -134,7 +134,7 @@ async fn main() -> anyhow::Result<()> { println!( "{}", - client.hello(context::current(), "friend".into()).await? + client.hello(&mut context::current(), "friend".into()).await? ); Ok(()) } diff --git a/tarpc/examples/custom_transport.rs b/tarpc/examples/custom_transport.rs index 5f5386785..c99825d08 100644 --- a/tarpc/examples/custom_transport.rs +++ b/tarpc/examples/custom_transport.rs @@ -21,7 +21,7 @@ pub trait PingService { struct Service; impl PingService for Service { - async fn ping(self, _: Context) {} + async fn ping(self, _: &mut Context) {} } #[tokio::main] @@ -52,7 +52,7 @@ async fn main() -> anyhow::Result<()> { let transport = transport::new(codec_builder.new_framed(conn), Bincode::default()); PingServiceClient::new(Default::default(), transport) .spawn() - .ping(tarpc::context::current()) + .ping(&mut tarpc::context::current()) .await?; Ok(()) diff --git a/tarpc/examples/pubsub.rs b/tarpc/examples/pubsub.rs index d61f68c48..c89f9e736 100644 --- a/tarpc/examples/pubsub.rs +++ b/tarpc/examples/pubsub.rs @@ -80,11 +80,11 @@ struct Subscriber { } impl subscriber::Subscriber for Subscriber { - async fn topics(self, _: context::Context) -> Vec { + async fn topics(self, _: &mut context::Context) -> Vec { self.topics.clone() } - async fn receive(self, _: context::Context, topic: String, message: String) { + async fn receive(self, _: &mut context::Context, topic: String, message: String) { info!(local_addr = %self.local_addr, %topic, %message, "ReceivedMessage") } } @@ -210,7 +210,7 @@ impl Publisher { subscriber: subscriber::SubscriberClient, ) { // Populate the topics - if let Ok(topics) = subscriber.topics(context::current()).await { + if let Ok(topics) = subscriber.topics(&mut context::current()).await { self.clients.lock().unwrap().insert( subscriber_addr, Subscription { @@ -263,15 +263,20 @@ impl Publisher { } impl publisher::Publisher for Publisher { - async fn publish(self, _: context::Context, topic: String, message: String) { + async fn publish(self, _: &mut context::Context, topic: String, message: String) { info!("received message to publish."); let mut subscribers = match self.subscriptions.read().unwrap().get(&topic) { None => return, Some(subscriptions) => subscriptions.clone(), }; let mut publications = Vec::new(); + + for client in subscribers.values_mut() { - publications.push(client.receive(context::current(), topic.clone(), message.clone())); + publications.push(async { + let mut context = context::current(); + client.receive(&mut context, topic.clone(), message.clone()).await + }); } // Ignore failing subscribers. In a real pubsub, you'd want to continually retry until // subscribers ack. Of course, a lot would be different in a real pubsub :) @@ -342,26 +347,26 @@ async fn main() -> anyhow::Result<()> { .spawn(); publisher - .publish(context::current(), "calculus".into(), "sqrt(2)".into()) + .publish(&mut context::current(), "calculus".into(), "sqrt(2)".into()) .await?; publisher .publish( - context::current(), + &mut context::current(), "cool shorts".into(), "hello to all".into(), ) .await?; publisher - .publish(context::current(), "history".into(), "napoleon".to_string()) + .publish(&mut context::current(), "history".into(), "napoleon".to_string()) .await?; drop(_subscriber0); publisher .publish( - context::current(), + &mut context::current(), "cool shorts".into(), "hello to who?".into(), ) diff --git a/tarpc/examples/readme.rs b/tarpc/examples/readme.rs index c328bd884..bb3deadc7 100644 --- a/tarpc/examples/readme.rs +++ b/tarpc/examples/readme.rs @@ -23,7 +23,7 @@ pub trait World { struct HelloServer; impl World for HelloServer { - async fn hello(self, _: context::Context, name: String) -> String { + async fn hello(self, _: &mut context::Context, name: String) -> String { format!("Hello, {name}!") } } @@ -46,7 +46,7 @@ async fn main() -> anyhow::Result<()> { // The client has an RPC method for each RPC defined in the annotated trait. It takes the same // args as defined, with the addition of a Context, which is always the first arg. The Context // specifies a deadline and trace information which can be helpful in debugging requests. - let hello = client.hello(context::current(), "Stim".to_string()).await?; + let hello = client.hello(&mut context::current(), "Stim".to_string()).await?; println!("{hello}"); diff --git a/tarpc/examples/tls_over_tcp.rs b/tarpc/examples/tls_over_tcp.rs index 968f76c17..d81ea74a1 100644 --- a/tarpc/examples/tls_over_tcp.rs +++ b/tarpc/examples/tls_over_tcp.rs @@ -33,7 +33,7 @@ pub trait PingService { struct Service; impl PingService for Service { - async fn ping(self, _: Context) -> String { + async fn ping(self, _: &mut Context) -> String { "🔒".to_owned() } } @@ -146,7 +146,7 @@ async fn main() -> anyhow::Result<()> { let transport = transport::new(codec_builder.new_framed(stream), Bincode::default()); let answer = PingServiceClient::new(Default::default(), transport) .spawn() - .ping(tarpc::context::current()) + .ping(&mut tarpc::context::current()) .await?; println!("ping answer: {answer}"); diff --git a/tarpc/examples/tracing.rs b/tarpc/examples/tracing.rs index 79a7026c0..1bace43ce 100644 --- a/tarpc/examples/tracing.rs +++ b/tarpc/examples/tracing.rs @@ -56,7 +56,7 @@ pub mod double { struct AddServer; impl AddService for AddServer { - async fn add(self, _: context::Context, x: i32, y: i32) -> i32 { + async fn add(self, _: &mut context::Context, x: i32, y: i32) -> i32 { x + y } } @@ -70,9 +70,9 @@ impl DoubleService for DoubleServer where Stub: AddStub + Clone + Send + Sync + 'static, { - async fn double(self, _: context::Context, x: i32) -> Result { + async fn double(self, _: &mut context::Context, x: i32) -> Result { self.add_client - .add(context::current(), x, x) + .add(&mut context::current(), x, x) .await .map_err(|e| e.to_string()) } @@ -193,9 +193,9 @@ async fn main() -> anyhow::Result<()> { let double_client = double::DoubleClient::new(client::Config::default(), to_double_server).spawn(); - let ctx = context::current(); + let mut ctx = context::current(); for _ in 1..=5 { - tracing::info!("{:?}", double_client.double(ctx, 1).await?); + tracing::info!("{:?}", double_client.double(&mut ctx, 1).await?); } tracer_provider.shutdown()?; diff --git a/tarpc/src/client.rs b/tarpc/src/client.rs index 3cf9ff07a..96afc4c5f 100644 --- a/tarpc/src/client.rs +++ b/tarpc/src/client.rs @@ -128,7 +128,7 @@ where otel.kind = "client", otel.name = %request.name()) )] - pub async fn call(&self, mut ctx: context::Context, request: Req) -> Result { + pub async fn call(&self, ctx: &mut context::Context, request: Req) -> Result { let span = Span::current(); ctx.trace_context = trace::Context::try_from(&span).unwrap_or_else(|_| { tracing::trace!( @@ -153,7 +153,10 @@ where }; self.to_dispatch .send(DispatchRequest { - ctx, + ctx: context::Context { + deadline: ctx.deadline, + trace_context: ctx.trace_context.clone(), + }, span, request_id, request, @@ -881,7 +884,7 @@ mod tests { let (dispatch, channel, _server_channel) = set_up(); drop(dispatch); // error on send - let resp = channel.call(current(), "hi".to_string()).await; + let resp = channel.call(&mut current(), "hi".to_string()).await; assert_matches!(resp, Err(RpcError::Shutdown)); } diff --git a/tarpc/src/client/stub.rs b/tarpc/src/client/stub.rs index 85746b7f2..2647c1321 100644 --- a/tarpc/src/client/stub.rs +++ b/tarpc/src/client/stub.rs @@ -24,7 +24,7 @@ pub trait Stub { type Resp; /// Calls a remote service. - async fn call(&self, ctx: context::Context, request: Self::Req) + async fn call(&self, ctx: &mut context::Context, request: Self::Req) -> Result; } @@ -35,7 +35,7 @@ where type Req = Req; type Resp = Resp; - async fn call(&self, ctx: context::Context, request: Req) -> Result { + async fn call(&self, ctx: &mut context::Context, request: Req) -> Result { Self::call(self, ctx, request).await } } @@ -46,7 +46,7 @@ where { type Req = S::Req; type Resp = S::Resp; - async fn call(&self, ctx: context::Context, req: Self::Req) -> Result { + async fn call(&self, ctx: &mut context::Context, req: Self::Req) -> Result { self.clone().serve(ctx, req).await.map_err(RpcError::Server) } } diff --git a/tarpc/src/client/stub/load_balance.rs b/tarpc/src/client/stub/load_balance.rs index d28a3c137..6c0f7b0df 100644 --- a/tarpc/src/client/stub/load_balance.rs +++ b/tarpc/src/client/stub/load_balance.rs @@ -20,7 +20,7 @@ mod round_robin { async fn call( &self, - ctx: context::Context, + ctx: &mut context::Context, request: Self::Req, ) -> Result { let next = self.stubs.next(); @@ -119,7 +119,7 @@ mod consistent_hash { async fn call( &self, - ctx: context::Context, + ctx: &mut context::Context, request: Self::Req, ) -> Result { let index = usize::try_from(self.hasher.hash_one(&request) % self.stubs_len).expect( @@ -200,13 +200,13 @@ mod consistent_hash { )?; for _ in 0..2 { - let resp = stub.call(context::current(), 'a').await?; + let resp = stub.call(&mut context::current(), 'a').await?; assert_eq!(resp, 1); - let resp = stub.call(context::current(), 'b').await?; + let resp = stub.call(&mut context::current(), 'b').await?; assert_eq!(resp, 2); - let resp = stub.call(context::current(), 'c').await?; + let resp = stub.call(&mut context::current(), 'c').await?; assert_eq!(resp, 3); } diff --git a/tarpc/src/client/stub/mock.rs b/tarpc/src/client/stub/mock.rs index 145c14c1f..6f0540797 100644 --- a/tarpc/src/client/stub/mock.rs +++ b/tarpc/src/client/stub/mock.rs @@ -30,7 +30,7 @@ where type Req = Req; type Resp = Resp; - async fn call(&self, _: context::Context, request: Self::Req) -> Result { + async fn call(&self, _: &mut context::Context, request: Self::Req) -> Result { self.responses .get(&request) .cloned() diff --git a/tarpc/src/client/stub/retry.rs b/tarpc/src/client/stub/retry.rs index a07b05fc5..18c84f25f 100644 --- a/tarpc/src/client/stub/retry.rs +++ b/tarpc/src/client/stub/retry.rs @@ -18,7 +18,7 @@ where async fn call( &self, - ctx: context::Context, + ctx: &mut context::Context, request: Self::Req, ) -> Result { let request = Arc::new(request); diff --git a/tarpc/src/context.rs b/tarpc/src/context.rs index 8e77cf223..f59d34dd9 100644 --- a/tarpc/src/context.rs +++ b/tarpc/src/context.rs @@ -21,7 +21,7 @@ use tracing_opentelemetry::OpenTelemetrySpanExt; /// /// The context should not be stored directly in a server implementation, because the context will /// be different for each request in scope. -#[derive(Clone, Copy, Debug)] +#[derive(Debug)] #[non_exhaustive] #[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))] pub struct Context { diff --git a/tarpc/src/lib.rs b/tarpc/src/lib.rs index 7e1944305..17a06ec57 100644 --- a/tarpc/src/lib.rs +++ b/tarpc/src/lib.rs @@ -125,7 +125,7 @@ //! //! impl World for HelloServer { //! // Each defined rpc generates an async fn that serves the RPC -//! async fn hello(self, _: context::Context, name: String) -> String { +//! async fn hello(self, _: &mut context::Context, name: String) -> String { //! format!("Hello, {name}!") //! } //! } @@ -158,7 +158,7 @@ //! # struct HelloServer; //! # impl World for HelloServer { //! // Each defined rpc generates an async fn that serves the RPC -//! # async fn hello(self, _: context::Context, name: String) -> String { +//! # async fn hello(self, _: &mut context::Context, name: String) -> String { //! # format!("Hello, {name}!") //! # } //! # } @@ -184,7 +184,8 @@ //! // The client has an RPC method for each RPC defined in the annotated trait. It takes the same //! // args as defined, with the addition of a Context, which is always the first arg. The Context //! // specifies a deadline and trace information which can be helpful in debugging requests. -//! let hello = client.hello(context::current(), "Stim".to_string()).await?; +//! let mut context = context::current(); +//! let hello = client.hello(&mut context, "Stim".to_string()).await?; //! //! println!("{hello}"); //! @@ -279,7 +280,7 @@ pub enum ClientMessage { } /// A request from a client to a server. -#[derive(Clone, Copy, Debug)] +#[derive(Debug)] #[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))] pub struct Request { /// Trace context, deadline, and other cross-cutting concerns. diff --git a/tarpc/src/server.rs b/tarpc/src/server.rs index da3b3ae21..d0cca7ad4 100644 --- a/tarpc/src/server.rs +++ b/tarpc/src/server.rs @@ -76,7 +76,7 @@ pub trait Serve { type Resp; /// Responds to a single request. - async fn serve(self, ctx: context::Context, req: Self::Req) -> Result; + async fn serve(self, ctx: &mut context::Context, req: Self::Req) -> Result; } /// A Serve wrapper around a Fn. @@ -102,10 +102,9 @@ impl Copy for ServeFn where F: Copy {} /// Creates a [`Serve`] wrapper around a `FnOnce(context::Context, Req) -> impl Future>`. -pub fn serve(f: F) -> ServeFn +pub fn serve(f: F) -> ServeFn where - F: FnOnce(context::Context, Req) -> Fut, - Fut: Future>, + for<'a> F: FnOnce(&'a mut context::Context, Req) -> Pin> + 'a + Send>>, { ServeFn { f, @@ -113,16 +112,15 @@ where } } -impl Serve for ServeFn +impl Serve for ServeFn where Req: RequestName, - F: FnOnce(context::Context, Req) -> Fut, - Fut: Future>, + for<'a> F: FnOnce(&'a mut context::Context, Req) -> Pin> + 'a + Send>>, { type Req = Req; type Resp = Resp; - async fn serve(self, ctx: context::Context, req: Req) -> Result { + async fn serve(self, ctx: &mut context::Context, req: Req) -> Result { (self.f)(ctx, req).await } } @@ -360,10 +358,11 @@ where /// let mut requests = server.requests(); /// tokio::spawn(async move { /// while let Some(Ok(request)) = requests.next().await { - /// tokio::spawn(request.execute(serve(|_, i| async move { Ok(i + 1) }))); + /// tokio::spawn(request.execute(serve(|_, i| async move { Ok(i + 1) }.boxed()))); /// } /// }); - /// assert_eq!(client.call(context::current(), 1).await.unwrap(), 2); + /// let mut context = context::current(); + /// assert_eq!(client.call(&mut context, 1).await.unwrap(), 2); /// } /// ``` fn requests(self) -> Requests @@ -399,12 +398,13 @@ where /// let client = client::new(client::Config::default(), tx).spawn(); /// let channel = BaseChannel::with_defaults(rx); /// tokio::spawn( - /// channel.execute(serve(|_, i: i32| async move { Ok(i + 1) })) + /// channel.execute(serve(|_, i: i32| async move { Ok(i + 1) }.boxed())) /// .for_each(|response| async move { /// tokio::spawn(response); - /// })); + /// }.boxed())); + /// let mut context = context::current(); /// assert_eq!( - /// client.call(context::current(), 1).await.unwrap(), + /// client.call(&mut context, 1).await.unwrap(), /// 2); /// } /// ``` @@ -748,11 +748,12 @@ where /// let requests = BaseChannel::new(server::Config::default(), rx).requests(); /// let client = client::new(client::Config::default(), tx).spawn(); /// tokio::spawn( - /// requests.execute(serve(|_, i| async move { Ok(i + 1) })) + /// requests.execute(serve(|_, i| async move { Ok(i + 1) }.boxed())) /// .for_each(|response| async move { /// tokio::spawn(response); - /// })); - /// assert_eq!(client.call(context::current(), 1).await.unwrap(), 2); + /// }.boxed())); + /// let mut context = context::current(); + /// assert_eq!(client.call(&mut context, 1).await.unwrap(), 2); /// } /// ``` pub fn execute(self, serve: S) -> impl Stream> @@ -855,11 +856,11 @@ impl InFlightRequest { /// tokio::spawn(async move { /// let mut requests = server.requests(); /// while let Some(Ok(in_flight_request)) = requests.next().await { - /// in_flight_request.execute(serve(|_, i| async move { Ok(i + 1) })).await; + /// in_flight_request.execute(serve(|_, i| async move { Ok(i + 1) }.boxed())).await; /// } - /// /// }); - /// assert_eq!(client.call(context::current(), 1).await.unwrap(), 2); + /// let mut context = context::current(); + /// assert_eq!(client.call(&mut context, 1).await.unwrap(), 2); /// } /// ``` /// @@ -875,7 +876,7 @@ impl InFlightRequest { span, request: Request { - context, + mut context, message, id: request_id, }, @@ -883,7 +884,7 @@ impl InFlightRequest { span.record("otel.name", message.name()); let _ = Abortable::new( async move { - let message = serve.serve(context, message).await; + let message = serve.serve(&mut context, message).await; tracing::debug!("CompleteRequest"); let response = Response { request_id, @@ -977,6 +978,7 @@ mod tests { task::Poll, time::{Duration, Instant}, }; + use tracing_subscriber::filter::FilterExt; fn test_channel() -> ( Pin, Response>>>>, @@ -1039,8 +1041,8 @@ mod tests { #[tokio::test] async fn test_serve() { - let serve = serve(|_, i| async move { Ok(i) }); - assert_matches!(serve.serve(context::current(), 7).await, Ok(7)); + let serve = serve(|_, i| async move { Ok(i) }.boxed()); + assert_matches!(serve.serve(&mut context::current(), 7).await, Ok(7)); } #[tokio::test] @@ -1060,14 +1062,14 @@ mod tests { let some_time = Instant::now() + Duration::from_secs(37); let some_other_time = Instant::now() + Duration::from_secs(83); - let serve = serve(move |ctx: context::Context, i| async move { + let serve = serve(move |ctx: &mut context::Context, i| async move { assert_eq!(ctx.deadline, some_time); Ok(i) - }); + }.boxed()); let deadline_hook = serve.before(SetDeadline(some_time)); let mut ctx = context::current(); ctx.deadline = some_other_time; - deadline_hook.serve(ctx, 7).await?; + deadline_hook.serve(&mut ctx, 7).await?; Ok(()) } @@ -1101,21 +1103,21 @@ mod tests { } } - let serve = serve(move |_: context::Context, i| async move { Ok(i) }); + let serve = serve(move |_: &mut context::Context, i| async move { Ok(i) }.boxed()); serve .before_and_after(PrintLatency::new()) - .serve(context::current(), 7) + .serve(&mut context::current(), 7) .await?; Ok(()) } #[tokio::test] async fn serve_before_error_aborts_request() -> anyhow::Result<()> { - let serve = serve(|_, _| async { panic!("Shouldn't get here") }); + let serve = serve(|_, _| async { panic!("Shouldn't get here") }.boxed()); let deadline_hook = serve.before(|_: &mut context::Context, _: &i32| async { Err(ServerError::new(io::ErrorKind::Other, "oops".into())) }); - let resp: Result = deadline_hook.serve(context::current(), 7).await; + let resp: Result = deadline_hook.serve(&mut context::current(), 7).await; assert_matches!(resp, Err(_)); Ok(()) } @@ -1320,7 +1322,7 @@ mod tests { Poll::Ready(Some(Ok(request))) => request, result => panic!("Unexpected result: {result:?}"), }; - request.execute(serve(|_, _| async { Ok(()) })).await; + request.execute(serve(|_, _| async { Ok(()) }.boxed())).await; assert!( requests .as_mut() diff --git a/tarpc/src/server/incoming.rs b/tarpc/src/server/incoming.rs index 428eb1a7d..eddf3794e 100644 --- a/tarpc/src/server/incoming.rs +++ b/tarpc/src/server/incoming.rs @@ -63,9 +63,10 @@ where /// /// let incoming = stream::once(async move { /// BaseChannel::new(server::Config::default(), rx) -/// }).execute(serve(|_, i| async move { Ok(i + 1) })); +/// }).execute(serve(|_, i| async move { Ok(i + 1) }.boxed())); /// tokio::spawn(spawn_incoming(incoming)); -/// assert_eq!(client.call(context::current(), 1).await.unwrap(), 2); +/// let mut context = context::current(); +/// assert_eq!(client.call(&mut context, 1).await.unwrap(), 2); /// } /// ``` pub async fn spawn_incoming( diff --git a/tarpc/src/server/request_hook.rs b/tarpc/src/server/request_hook.rs index 66cf2878c..64b97453a 100644 --- a/tarpc/src/server/request_hook.rs +++ b/tarpc/src/server/request_hook.rs @@ -43,11 +43,11 @@ pub trait RequestHook: Serve { /// # Example /// /// ```rust - /// use futures::{executor::block_on, future}; + /// use futures::{executor::block_on, future, FutureExt}; /// use tarpc::{context, ServerError, server::{Serve, request_hook::RequestHook, serve}}; /// use std::io; /// - /// let serve = serve(|_ctx, i| async move { Ok(i + 1) }) + /// let serve = serve(|_ctx, i| async move { Ok(i + 1) }.boxed()) /// .before(|_ctx: &mut context::Context, req: &i32| { /// future::ready( /// if *req == 1 { @@ -58,7 +58,8 @@ pub trait RequestHook: Serve { /// Ok(()) /// }) /// }); - /// let response = serve.serve(context::current(), 1); + /// let mut context = context::current(); + /// let response = serve.serve(&mut context, 1); /// assert!(block_on(response).is_err()); /// ``` fn before(self, hook: Hook) -> HookThenServe @@ -80,7 +81,7 @@ pub trait RequestHook: Serve { /// # Example /// /// ```rust - /// use futures::{executor::block_on, future}; + /// use futures::{executor::block_on, future, FutureExt}; /// use tarpc::{context, ServerError, server::{Serve, request_hook::RequestHook, serve}}; /// use std::io; /// @@ -93,15 +94,15 @@ pub trait RequestHook: Serve { /// } else { /// Ok(i + 1) /// } - /// }) + /// }.boxed()) /// .after(|_ctx: &mut context::Context, resp: &mut Result| { /// if let Err(e) = resp { /// eprintln!("server error: {e:?}"); /// } /// future::ready(()) /// }); - /// - /// let response = serve.serve(context::current(), 1); + /// let mut context = context::current(); + /// let response = serve.serve(&mut context, 1); /// assert!(block_on(response).is_err()); /// ``` fn after(self, hook: Hook) -> ServeThenHook @@ -123,7 +124,7 @@ pub trait RequestHook: Serve { /// # Example /// /// ```rust - /// use futures::{executor::block_on, future}; + /// use futures::{executor::block_on, future, FutureExt}; /// use tarpc::{ /// context, ServerError, /// server::{Serve, serve, request_hook::{BeforeRequest, AfterRequest, RequestHook}} @@ -151,8 +152,9 @@ pub trait RequestHook: Serve { /// /// let serve = serve(|_ctx, i| async move { /// Ok(i + 1) - /// }).before_and_after(PrintLatency(Instant::now())); - /// let response = serve.serve(context::current(), 1); + /// }.boxed()).before_and_after(PrintLatency(Instant::now())); + /// let mut context = context::current(); + /// let response = serve.serve(&mut context, 1); /// assert!(block_on(response).is_ok()); /// ``` fn before_and_after( diff --git a/tarpc/src/server/request_hook/after.rs b/tarpc/src/server/request_hook/after.rs index b2ef9ccbd..e2c49b2f1 100644 --- a/tarpc/src/server/request_hook/after.rs +++ b/tarpc/src/server/request_hook/after.rs @@ -59,14 +59,14 @@ where async fn serve( self, - mut ctx: context::Context, + ctx: &mut context::Context, req: Serv::Req, ) -> Result { let ServeThenHook { serve, mut hook, .. } = self; let mut resp = serve.serve(ctx, req).await; - hook.after(&mut ctx, &mut resp).await; + hook.after(ctx, &mut resp).await; resp } } diff --git a/tarpc/src/server/request_hook/before.rs b/tarpc/src/server/request_hook/before.rs index e72e28a42..ad04cc784 100644 --- a/tarpc/src/server/request_hook/before.rs +++ b/tarpc/src/server/request_hook/before.rs @@ -87,13 +87,13 @@ where async fn serve( self, - mut ctx: context::Context, + ctx: &mut context::Context, req: Self::Req, ) -> Result { let HookThenServe { serve, mut hook, .. } = self; - hook.before(&mut ctx, &req).await?; + hook.before(ctx, &req).await?; serve.serve(ctx, req).await } } @@ -103,7 +103,7 @@ where /// Example /// /// ```rust -/// use futures::{executor::block_on, future}; +/// use futures::{executor::block_on, future, FutureExt}; /// use tarpc::{context, ServerError, server::{Serve, serve, request_hook::{self, /// BeforeRequest, BeforeRequestList}}}; /// use std::{cell::Cell, io}; @@ -120,8 +120,9 @@ where /// i.set(2); /// Ok(()) /// }) -/// .serving(serve(|_ctx, i| async move { Ok(i + 1) })); -/// let response = serve.clone().serve(context::current(), 1); +/// .serving(serve(|_ctx, i| async move { Ok(i + 1) }.boxed())); +/// let mut context = context::current(); +/// let response = serve.clone().serve(&mut context, 1); /// assert!(block_on(response).is_ok()); /// assert!(i.get() == 2); /// ``` @@ -209,8 +210,9 @@ fn before_request_list() { i.set(2); Ok(()) }) - .serving(serve(|_ctx, i| async move { Ok(i + 1) })); - let response = serve.clone().serve(context::current(), 1); + .serving(serve(|_ctx, i| async move { Ok(i + 1) }.boxed())); + let mut context = context::current(); + let response = serve.clone().serve(&mut context, 1); assert!(block_on(response).is_ok()); assert!(i.get() == 2); } diff --git a/tarpc/src/server/request_hook/before_and_after.rs b/tarpc/src/server/request_hook/before_and_after.rs index 0761a7df3..e06f34113 100644 --- a/tarpc/src/server/request_hook/before_and_after.rs +++ b/tarpc/src/server/request_hook/before_and_after.rs @@ -46,13 +46,13 @@ where type Req = Req; type Resp = Resp; - async fn serve(self, mut ctx: context::Context, req: Req) -> Result { + async fn serve(self, ctx: &mut context::Context, req: Req) -> Result { let HookThenServeThenHook { serve, mut hook, .. } = self; - hook.before(&mut ctx, &req).await?; + hook.before(ctx, &req).await?; let mut resp = serve.serve(ctx, req).await; - hook.after(&mut ctx, &mut resp).await; + hook.after(ctx, &mut resp).await; resp } } diff --git a/tarpc/src/transport/channel.rs b/tarpc/src/transport/channel.rs index 0268300dc..3c0c420aa 100644 --- a/tarpc/src/transport/channel.rs +++ b/tarpc/src/transport/channel.rs @@ -198,7 +198,7 @@ mod tests { format!("{request:?} is not an int"), ) }) - })) + }.boxed())) .for_each(|channel| async move { tokio::spawn(channel.for_each(|response| response)); }), @@ -206,8 +206,8 @@ mod tests { let client = client::new(client::Config::default(), client_channel).spawn(); - let response1 = client.call(context::current(), "123".into()).await; - let response2 = client.call(context::current(), "abc".into()).await; + let response1 = client.call(&mut context::current(), "123".into()).await; + let response2 = client.call(&mut context::current(), "abc".into()).await; trace!("response1: {:?}, response2: {:?}", response1, response2); diff --git a/tarpc/tests/dataservice.rs b/tarpc/tests/dataservice.rs index 18bb3a997..e051b434e 100644 --- a/tarpc/tests/dataservice.rs +++ b/tarpc/tests/dataservice.rs @@ -22,7 +22,7 @@ pub trait ColorProtocol { struct ColorServer; impl ColorProtocol for ColorServer { - async fn get_opposite_color(self, _: context::Context, color: TestData) -> TestData { + async fn get_opposite_color(self, _: &mut context::Context, color: TestData) -> TestData { match color { TestData::White => TestData::Black, TestData::Black => TestData::White, @@ -53,7 +53,7 @@ async fn test_call() -> anyhow::Result<()> { let client = ColorProtocolClient::new(client::Config::default(), transport).spawn(); let color = client - .get_opposite_color(context::current(), TestData::White) + .get_opposite_color(&mut context::current(), TestData::White) .await?; assert_eq!(color, TestData::Black); diff --git a/tarpc/tests/service_functional.rs b/tarpc/tests/service_functional.rs index 06542b43b..1005ae116 100644 --- a/tarpc/tests/service_functional.rs +++ b/tarpc/tests/service_functional.rs @@ -22,11 +22,11 @@ trait Service { struct Server; impl Service for Server { - async fn add(self, _: context::Context, x: i32, y: i32) -> i32 { + async fn add(self, _: &mut context::Context, x: i32, y: i32) -> i32 { x + y } - async fn hey(self, _: context::Context, name: String) -> String { + async fn hey(self, _: &mut context::Context, name: String) -> String { format!("Hey, {name}.") } } @@ -38,10 +38,10 @@ async fn sequential() { let channel = BaseChannel::with_defaults(rx); tokio::spawn( channel - .execute(tarpc::server::serve(|_, i: u32| async move { Ok(i + 1) })) + .execute(tarpc::server::serve(|_, i: u32| async move { Ok(i + 1) }.boxed())) .for_each(|response| response), ); - assert_eq!(client.call(context::current(), 1).await.unwrap(), 2); + assert_eq!(client.call(&mut context::current(), 1).await.unwrap(), 2); } #[tokio::test] @@ -55,7 +55,7 @@ async fn dropped_channel_aborts_in_flight_requests() -> anyhow::Result<()> { struct LoopServer; impl Loop for LoopServer { - async fn r#loop(self, _: context::Context) { + async fn r#loop(self, _: &mut context::Context) { loop { futures::pending!(); } @@ -73,7 +73,7 @@ async fn dropped_channel_aborts_in_flight_requests() -> anyhow::Result<()> { let mut ctx = context::current(); ctx.deadline = Instant::now() + Duration::from_secs(60 * 60); - let _ = client.r#loop(ctx).await; + let _ = client.r#loop(&mut ctx).await; }); let mut requests = BaseChannel::with_defaults(rx).requests(); @@ -112,9 +112,9 @@ async fn serde_tcp() -> anyhow::Result<()> { let transport = serde_transport::tcp::connect(addr, Json::default).await?; let client = ServiceClient::new(client::Config::default(), transport).spawn(); - assert_matches!(client.add(context::current(), 1, 2).await, Ok(3)); + assert_matches!(client.add(&mut context::current(), 1, 2).await, Ok(3)); assert_matches!( - client.hey(context::current(), "Tim".to_string()).await, + client.hey(&mut context::current(), "Tim".to_string()).await, Ok(ref s) if s == "Hey, Tim." ); @@ -145,8 +145,8 @@ async fn serde_uds() -> anyhow::Result<()> { let client = ServiceClient::new(client::Config::default(), transport).spawn(); // Save results using socket so we can clean the socket even if our test assertions fail - let res1 = client.add(context::current(), 1, 2).await; - let res2 = client.hey(context::current(), "Tim".to_string()).await; + let res1 = client.add(&mut context::current(), 1, 2).await; + let res2 = client.hey(&mut context::current(), "Tim".to_string()).await; assert_matches!(res1, Ok(3)); assert_matches!(res2, Ok(ref s) if s == "Hey, Tim."); @@ -169,12 +169,15 @@ async fn concurrent() -> anyhow::Result<()> { let client = ServiceClient::new(client::Config::default(), tx).spawn(); - let req1 = client.add(context::current(), 1, 2); - let req2 = client.add(context::current(), 3, 4); - let req3 = client.hey(context::current(), "Tim".to_string()); + let mut context = context::current(); + let req1 = client.add(&mut context, 1, 2); assert_matches!(req1.await, Ok(3)); + + let req2 = client.add(&mut context, 3, 4); assert_matches!(req2.await, Ok(7)); + + let req3 = client.hey(&mut context, "Tim".to_string()); assert_matches!(req3.await, Ok(ref s) if s == "Hey, Tim."); Ok(()) @@ -195,9 +198,13 @@ async fn concurrent_join() -> anyhow::Result<()> { let client = ServiceClient::new(client::Config::default(), tx).spawn(); - let req1 = client.add(context::current(), 1, 2); - let req2 = client.add(context::current(), 3, 4); - let req3 = client.hey(context::current(), "Tim".to_string()); + let mut context1 = context::current(); + let mut context2 = context::current(); + let mut context3 = context::current(); + + let req1 = client.add(&mut context1, 1, 2); + let req2 = client.add(&mut context2, 3, 4); + let req3 = client.hey(&mut context3, "Tim".to_string()); let (resp1, resp2, resp3) = join!(req1, req2, req3); assert_matches!(resp1, Ok(3)); @@ -225,8 +232,11 @@ async fn concurrent_join_all() -> anyhow::Result<()> { let client = ServiceClient::new(client::Config::default(), tx).spawn(); - let req1 = client.add(context::current(), 1, 2); - let req2 = client.add(context::current(), 3, 4); + let mut context1 = context::current(); + let mut context2 = context::current(); + + let req1 = client.add(&mut context1, 1, 2); + let req2 = client.add(&mut context2, 3, 4); let responses = join_all(vec![req1, req2]).await; assert_matches!(responses[0], Ok(3)); @@ -245,7 +255,7 @@ async fn counter() -> anyhow::Result<()> { struct CountService(u32); impl Counter for &mut CountService { - async fn count(self, _: context::Context) -> u32 { + async fn count(self, _: &mut context::Context) -> u32 { self.0 += 1; self.0 } @@ -262,8 +272,8 @@ async fn counter() -> anyhow::Result<()> { }); let client = CounterClient::new(client::Config::default(), tx).spawn(); - assert_matches!(client.count(context::current()).await, Ok(1)); - assert_matches!(client.count(context::current()).await, Ok(2)); + assert_matches!(client.count(&mut context::current()).await, Ok(1)); + assert_matches!(client.count(&mut context::current()).await, Ok(2)); Ok(()) } From 1b605a3c48bfdd61db0467191c66bd23550b9c3c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C3=81kos=20Vandra-Meyer?= Date: Sun, 23 Nov 2025 20:57:48 +0100 Subject: [PATCH 02/23] run cargo fmt --- tarpc/examples/compression.rs | 4 +++- tarpc/examples/pubsub.rs | 11 ++++++++--- tarpc/examples/readme.rs | 4 +++- tarpc/src/client/stub.rs | 13 ++++++++++--- tarpc/src/server.rs | 31 +++++++++++++++++++++++-------- tarpc/src/transport/channel.rs | 19 +++++++++++-------- tarpc/tests/service_functional.rs | 4 +++- 7 files changed, 61 insertions(+), 25 deletions(-) diff --git a/tarpc/examples/compression.rs b/tarpc/examples/compression.rs index 783f2618a..e703cc676 100644 --- a/tarpc/examples/compression.rs +++ b/tarpc/examples/compression.rs @@ -134,7 +134,9 @@ async fn main() -> anyhow::Result<()> { println!( "{}", - client.hello(&mut context::current(), "friend".into()).await? + client + .hello(&mut context::current(), "friend".into()) + .await? ); Ok(()) } diff --git a/tarpc/examples/pubsub.rs b/tarpc/examples/pubsub.rs index c89f9e736..4e132616f 100644 --- a/tarpc/examples/pubsub.rs +++ b/tarpc/examples/pubsub.rs @@ -271,11 +271,12 @@ impl publisher::Publisher for Publisher { }; let mut publications = Vec::new(); - for client in subscribers.values_mut() { publications.push(async { let mut context = context::current(); - client.receive(&mut context, topic.clone(), message.clone()).await + client + .receive(&mut context, topic.clone(), message.clone()) + .await }); } // Ignore failing subscribers. In a real pubsub, you'd want to continually retry until @@ -359,7 +360,11 @@ async fn main() -> anyhow::Result<()> { .await?; publisher - .publish(&mut context::current(), "history".into(), "napoleon".to_string()) + .publish( + &mut context::current(), + "history".into(), + "napoleon".to_string(), + ) .await?; drop(_subscriber0); diff --git a/tarpc/examples/readme.rs b/tarpc/examples/readme.rs index bb3deadc7..c00c270f0 100644 --- a/tarpc/examples/readme.rs +++ b/tarpc/examples/readme.rs @@ -46,7 +46,9 @@ async fn main() -> anyhow::Result<()> { // The client has an RPC method for each RPC defined in the annotated trait. It takes the same // args as defined, with the addition of a Context, which is always the first arg. The Context // specifies a deadline and trace information which can be helpful in debugging requests. - let hello = client.hello(&mut context::current(), "Stim".to_string()).await?; + let hello = client + .hello(&mut context::current(), "Stim".to_string()) + .await?; println!("{hello}"); diff --git a/tarpc/src/client/stub.rs b/tarpc/src/client/stub.rs index 2647c1321..14b6edf30 100644 --- a/tarpc/src/client/stub.rs +++ b/tarpc/src/client/stub.rs @@ -24,8 +24,11 @@ pub trait Stub { type Resp; /// Calls a remote service. - async fn call(&self, ctx: &mut context::Context, request: Self::Req) - -> Result; + async fn call( + &self, + ctx: &mut context::Context, + request: Self::Req, + ) -> Result; } impl Stub for Channel @@ -46,7 +49,11 @@ where { type Req = S::Req; type Resp = S::Resp; - async fn call(&self, ctx: &mut context::Context, req: Self::Req) -> Result { + async fn call( + &self, + ctx: &mut context::Context, + req: Self::Req, + ) -> Result { self.clone().serve(ctx, req).await.map_err(RpcError::Server) } } diff --git a/tarpc/src/server.rs b/tarpc/src/server.rs index d0cca7ad4..e08365964 100644 --- a/tarpc/src/server.rs +++ b/tarpc/src/server.rs @@ -76,7 +76,11 @@ pub trait Serve { type Resp; /// Responds to a single request. - async fn serve(self, ctx: &mut context::Context, req: Self::Req) -> Result; + async fn serve( + self, + ctx: &mut context::Context, + req: Self::Req, + ) -> Result; } /// A Serve wrapper around a Fn. @@ -104,7 +108,10 @@ impl Copy for ServeFn where F: Copy {} /// Result>`. pub fn serve(f: F) -> ServeFn where - for<'a> F: FnOnce(&'a mut context::Context, Req) -> Pin> + 'a + Send>>, + for<'a> F: FnOnce( + &'a mut context::Context, + Req, + ) -> Pin> + 'a + Send>>, { ServeFn { f, @@ -115,7 +122,10 @@ where impl Serve for ServeFn where Req: RequestName, - for<'a> F: FnOnce(&'a mut context::Context, Req) -> Pin> + 'a + Send>>, + for<'a> F: FnOnce( + &'a mut context::Context, + Req, + ) -> Pin> + 'a + Send>>, { type Req = Req; type Resp = Resp; @@ -1062,10 +1072,13 @@ mod tests { let some_time = Instant::now() + Duration::from_secs(37); let some_other_time = Instant::now() + Duration::from_secs(83); - let serve = serve(move |ctx: &mut context::Context, i| async move { - assert_eq!(ctx.deadline, some_time); - Ok(i) - }.boxed()); + let serve = serve(move |ctx: &mut context::Context, i| { + async move { + assert_eq!(ctx.deadline, some_time); + Ok(i) + } + .boxed() + }); let deadline_hook = serve.before(SetDeadline(some_time)); let mut ctx = context::current(); ctx.deadline = some_other_time; @@ -1322,7 +1335,9 @@ mod tests { Poll::Ready(Some(Ok(request))) => request, result => panic!("Unexpected result: {result:?}"), }; - request.execute(serve(|_, _| async { Ok(()) }.boxed())).await; + request + .execute(serve(|_, _| async { Ok(()) }.boxed())) + .await; assert!( requests .as_mut() diff --git a/tarpc/src/transport/channel.rs b/tarpc/src/transport/channel.rs index 3c0c420aa..e064e6813 100644 --- a/tarpc/src/transport/channel.rs +++ b/tarpc/src/transport/channel.rs @@ -191,14 +191,17 @@ mod tests { tokio::spawn( stream::once(future::ready(server_channel)) .map(BaseChannel::with_defaults) - .execute(serve(|_ctx, request: String| async move { - request.parse::().map_err(|_| { - ServerError::new( - io::ErrorKind::InvalidInput, - format!("{request:?} is not an int"), - ) - }) - }.boxed())) + .execute(serve(|_ctx, request: String| { + async move { + request.parse::().map_err(|_| { + ServerError::new( + io::ErrorKind::InvalidInput, + format!("{request:?} is not an int"), + ) + }) + } + .boxed() + })) .for_each(|channel| async move { tokio::spawn(channel.for_each(|response| response)); }), diff --git a/tarpc/tests/service_functional.rs b/tarpc/tests/service_functional.rs index 1005ae116..f3adda2fb 100644 --- a/tarpc/tests/service_functional.rs +++ b/tarpc/tests/service_functional.rs @@ -38,7 +38,9 @@ async fn sequential() { let channel = BaseChannel::with_defaults(rx); tokio::spawn( channel - .execute(tarpc::server::serve(|_, i: u32| async move { Ok(i + 1) }.boxed())) + .execute(tarpc::server::serve(|_, i: u32| { + async move { Ok(i + 1) }.boxed() + })) .for_each(|response| response), ); assert_eq!(client.call(&mut context::current(), 1).await.unwrap(), 2); From 02ca335e504c2051ecf58d235d4990310dd81af1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C3=81kos=20Vandra-Meyer?= Date: Sun, 23 Nov 2025 20:58:52 +0100 Subject: [PATCH 03/23] cargo clippy --- tarpc/src/client.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tarpc/src/client.rs b/tarpc/src/client.rs index 96afc4c5f..9ef7a1acb 100644 --- a/tarpc/src/client.rs +++ b/tarpc/src/client.rs @@ -155,7 +155,7 @@ where .send(DispatchRequest { ctx: context::Context { deadline: ctx.deadline, - trace_context: ctx.trace_context.clone(), + trace_context: ctx.trace_context, }, span, request_id, From 8e1dce47fd473d84ff80b3186a8c8b162965726c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C3=81kos=20Vandra-Meyer?= Date: Tue, 18 Nov 2025 14:19:35 +0100 Subject: [PATCH 04/23] separate context into shared, client and server contexts. only transmit shared context between client and server --- example-service/src/client.rs | 5 +- example-service/src/server.rs | 2 +- plugins/src/lib.rs | 10 +- plugins/tests/service.rs | 14 +-- tarpc/examples/compression.rs | 6 +- tarpc/examples/custom_transport.rs | 6 +- tarpc/examples/pubsub.rs | 20 ++-- tarpc/examples/readme.rs | 6 +- tarpc/examples/tls_over_tcp.rs | 6 +- tarpc/examples/tracing.rs | 9 +- tarpc/src/client.rs | 29 +++--- tarpc/src/client/in_flight_requests.rs | 6 +- tarpc/src/client/stub.rs | 23 +++-- tarpc/src/client/stub/load_balance.rs | 10 +- tarpc/src/client/stub/mock.rs | 2 +- tarpc/src/client/stub/retry.rs | 2 +- tarpc/src/context.rs | 98 ++++++++++++++++--- tarpc/src/lib.rs | 8 +- tarpc/src/server.rs | 90 +++++++---------- tarpc/src/server/incoming.rs | 2 +- tarpc/src/server/request_hook.rs | 14 +-- tarpc/src/server/request_hook/after.rs | 8 +- tarpc/src/server/request_hook/before.rs | 18 ++-- .../server/request_hook/before_and_after.rs | 2 +- tarpc/src/server/testing.rs | 2 +- tarpc/src/transport/channel.rs | 4 +- tarpc/tests/dataservice.rs | 4 +- tarpc/tests/service_functional.rs | 36 +++---- 28 files changed, 246 insertions(+), 196 deletions(-) diff --git a/example-service/src/client.rs b/example-service/src/client.rs index c73122c07..dc7104bfd 100644 --- a/example-service/src/client.rs +++ b/example-service/src/client.rs @@ -10,6 +10,7 @@ use std::{net::SocketAddr, time::Duration}; use tarpc::{client, context, tokio_serde::formats::Json}; use tokio::time::sleep; use tracing::Instrument; +use tarpc::context::ClientContext; #[derive(Parser)] struct Flags { @@ -34,8 +35,8 @@ async fn main() -> anyhow::Result<()> { let client = WorldClient::new(client::Config::default(), transport.await?).spawn(); let hello = async move { - let mut context = context::current(); - let mut context2 = context::current(); + let mut context = ClientContext::current(); + let mut context2 = ClientContext::current(); // Send the request twice, just to be safe! ;) tokio::select! { diff --git a/example-service/src/server.rs b/example-service/src/server.rs index 1efab549d..0845783c7 100644 --- a/example-service/src/server.rs +++ b/example-service/src/server.rs @@ -35,7 +35,7 @@ struct Flags { struct HelloServer(SocketAddr); impl World for HelloServer { - async fn hello(self, _: &mut context::Context, name: String) -> String { + async fn hello(self, _: &mut context::ServerContext, name: String) -> String { let sleep_time = Duration::from_millis(Uniform::new_inclusive(1, 10).sample(&mut thread_rng())); time::sleep(sleep_time).await; diff --git a/plugins/src/lib.rs b/plugins/src/lib.rs index 55ec2730e..886b85b48 100644 --- a/plugins/src/lib.rs +++ b/plugins/src/lib.rs @@ -375,7 +375,7 @@ fn collect_cfg_attrs(rpcs: &[RpcMethod]) -> Vec> { /// # Example /// /// ```no_run -/// use tarpc::{client, transport, service, server::{self, Channel}, context::Context}; +/// use tarpc::{client, transport, service, server::{self, Channel}, context::ServerContext}; /// /// #[service] /// pub trait Calculator { @@ -401,7 +401,7 @@ fn collect_cfg_attrs(rpcs: &[RpcMethod]) -> Vec> { /// #[derive(Clone)] /// struct CalculatorServer; /// impl Calculator for CalculatorServer { -/// async fn add(self, context: &mut Context, a: i32, b: i32) -> i32 { +/// async fn add(self, context: &mut ServerContext, a: i32, b: i32) -> i32 { /// a + b /// } /// } @@ -558,7 +558,7 @@ impl ServiceGenerator<'_> { )| { quote! { #( #attrs )* - async fn #ident(self, context: &mut ::tarpc::context::Context, #( #args ),*) -> #output; + async fn #ident(self, context: &mut ::tarpc::context::ServerContext, #( #args ),*) -> #output; } }, ); @@ -622,7 +622,7 @@ impl ServiceGenerator<'_> { type Resp = #response_ident; - async fn serve(self, ctx: &mut ::tarpc::context::Context, req: #request_ident) + async fn serve(self, ctx: &mut ::tarpc::context::ServerContext, req: #request_ident) -> ::core::result::Result<#response_ident, ::tarpc::ServerError> { match req { #( @@ -786,7 +786,7 @@ impl ServiceGenerator<'_> { #( #[allow(unused)] #( #method_attrs )* - #vis fn #method_idents<'a>(&'a self, ctx: &'a mut ::tarpc::context::Context, #( #args ),*) + #vis fn #method_idents<'a>(&'a self, ctx: &'a mut ::tarpc::context::ClientContext, #( #args ),*) -> impl ::core::future::Future> + '_ { let request = #request_ident::#camel_case_idents { #( #arg_pats ),* }; let resp = self.0.call(ctx, request); diff --git a/plugins/tests/service.rs b/plugins/tests/service.rs index d38492bd7..b03f3470f 100644 --- a/plugins/tests/service.rs +++ b/plugins/tests/service.rs @@ -12,15 +12,15 @@ fn att_service_trait() { } impl Foo for () { - async fn two_part(self, _: &mut context::Context, s: String, i: i32) -> (String, i32) { + async fn two_part(self, _: &mut context::ServerContext, s: String, i: i32) -> (String, i32) { (s, i) } - async fn bar(self, _: &mut context::Context, s: String) -> String { + async fn bar(self, _: &mut context::ServerContext, s: String) -> String { s } - async fn baz(self, _: &mut context::Context) {} + async fn baz(self, _: &mut context::ServerContext) {} } } @@ -39,18 +39,18 @@ fn raw_idents() { impl r#trait for () { async fn r#await( self, - _: &mut context::Context, + _: &mut context::ServerContext, r#struct: r#yield, r#enum: i32, ) -> (r#yield, i32) { (r#struct, r#enum) } - async fn r#fn(self, _: &mut context::Context, r#impl: r#yield) -> r#yield { + async fn r#fn(self, _: &mut context::ServerContext, r#impl: r#yield) -> r#yield { r#impl } - async fn r#async(self, _: &mut context::Context) {} + async fn r#async(self, _: &mut context::ServerContext) {} } } @@ -64,7 +64,7 @@ fn service_with_cfg_rpc() { } impl Foo for () { - async fn foo(self, _: &mut context::Context) {} + async fn foo(self, _: &mut context::ServerContext) {} } } diff --git a/tarpc/examples/compression.rs b/tarpc/examples/compression.rs index e703cc676..663236731 100644 --- a/tarpc/examples/compression.rs +++ b/tarpc/examples/compression.rs @@ -108,7 +108,7 @@ pub trait World { struct HelloServer; impl World for HelloServer { - async fn hello(self, _: &mut context::Context, name: String) -> String { + async fn hello(self, _: &mut context::ServerContext, name: String) -> String { format!("Hey, {name}!") } } @@ -134,9 +134,7 @@ async fn main() -> anyhow::Result<()> { println!( "{}", - client - .hello(&mut context::current(), "friend".into()) - .await? + client.hello(&mut context::ClientContext::current(), "friend".into()).await? ); Ok(()) } diff --git a/tarpc/examples/custom_transport.rs b/tarpc/examples/custom_transport.rs index c99825d08..1c682173d 100644 --- a/tarpc/examples/custom_transport.rs +++ b/tarpc/examples/custom_transport.rs @@ -5,7 +5,7 @@ // https://opensource.org/licenses/MIT. use futures::prelude::*; -use tarpc::context::Context; +use tarpc::context::{ClientContext, ServerContext}; use tarpc::serde_transport as transport; use tarpc::server::{BaseChannel, Channel}; use tarpc::tokio_serde::formats::Bincode; @@ -21,7 +21,7 @@ pub trait PingService { struct Service; impl PingService for Service { - async fn ping(self, _: &mut Context) {} + async fn ping(self, _: &mut ServerContext) {} } #[tokio::main] @@ -52,7 +52,7 @@ async fn main() -> anyhow::Result<()> { let transport = transport::new(codec_builder.new_framed(conn), Bincode::default()); PingServiceClient::new(Default::default(), transport) .spawn() - .ping(&mut tarpc::context::current()) + .ping(&mut ClientContext::current()) .await?; Ok(()) diff --git a/tarpc/examples/pubsub.rs b/tarpc/examples/pubsub.rs index 4e132616f..83c1371b9 100644 --- a/tarpc/examples/pubsub.rs +++ b/tarpc/examples/pubsub.rs @@ -80,11 +80,11 @@ struct Subscriber { } impl subscriber::Subscriber for Subscriber { - async fn topics(self, _: &mut context::Context) -> Vec { + async fn topics(self, _: &mut context::ServerContext) -> Vec { self.topics.clone() } - async fn receive(self, _: &mut context::Context, topic: String, message: String) { + async fn receive(self, _: &mut context::ServerContext, topic: String, message: String) { info!(local_addr = %self.local_addr, %topic, %message, "ReceivedMessage") } } @@ -210,7 +210,7 @@ impl Publisher { subscriber: subscriber::SubscriberClient, ) { // Populate the topics - if let Ok(topics) = subscriber.topics(&mut context::current()).await { + if let Ok(topics) = subscriber.topics(&mut context::ClientContext::current()).await { self.clients.lock().unwrap().insert( subscriber_addr, Subscription { @@ -263,7 +263,7 @@ impl Publisher { } impl publisher::Publisher for Publisher { - async fn publish(self, _: &mut context::Context, topic: String, message: String) { + async fn publish(self, _: &mut context::ServerContext, topic: String, message: String) { info!("received message to publish."); let mut subscribers = match self.subscriptions.read().unwrap().get(&topic) { None => return, @@ -271,12 +271,10 @@ impl publisher::Publisher for Publisher { }; let mut publications = Vec::new(); + for client in subscribers.values_mut() { publications.push(async { - let mut context = context::current(); - client - .receive(&mut context, topic.clone(), message.clone()) - .await + client.receive(&mut context::ClientContext::current(), topic.clone(), message.clone()).await }); } // Ignore failing subscribers. In a real pubsub, you'd want to continually retry until @@ -360,11 +358,7 @@ async fn main() -> anyhow::Result<()> { .await?; publisher - .publish( - &mut context::current(), - "history".into(), - "napoleon".to_string(), - ) + .publish(&mut context::current(), "history".into(), "napoleon".to_string()) .await?; drop(_subscriber0); diff --git a/tarpc/examples/readme.rs b/tarpc/examples/readme.rs index c00c270f0..60daf4e45 100644 --- a/tarpc/examples/readme.rs +++ b/tarpc/examples/readme.rs @@ -23,7 +23,7 @@ pub trait World { struct HelloServer; impl World for HelloServer { - async fn hello(self, _: &mut context::Context, name: String) -> String { + async fn hello(self, _: &mut context::ServerContext, name: String) -> String { format!("Hello, {name}!") } } @@ -46,9 +46,7 @@ async fn main() -> anyhow::Result<()> { // The client has an RPC method for each RPC defined in the annotated trait. It takes the same // args as defined, with the addition of a Context, which is always the first arg. The Context // specifies a deadline and trace information which can be helpful in debugging requests. - let hello = client - .hello(&mut context::current(), "Stim".to_string()) - .await?; + let hello = client.hello(&mut context::ClientContext::current(), "Stim".to_string()).await?; println!("{hello}"); diff --git a/tarpc/examples/tls_over_tcp.rs b/tarpc/examples/tls_over_tcp.rs index d81ea74a1..cc3c1690b 100644 --- a/tarpc/examples/tls_over_tcp.rs +++ b/tarpc/examples/tls_over_tcp.rs @@ -18,7 +18,7 @@ use tokio_rustls::rustls::{ }; use tokio_rustls::{TlsAcceptor, TlsConnector}; -use tarpc::context::Context; +use tarpc::context::{ClientContext, ServerContext}; use tarpc::serde_transport as transport; use tarpc::server::{BaseChannel, Channel}; use tarpc::tokio_serde::formats::Bincode; @@ -33,7 +33,7 @@ pub trait PingService { struct Service; impl PingService for Service { - async fn ping(self, _: &mut Context) -> String { + async fn ping(self, _: &mut ServerContext) -> String { "🔒".to_owned() } } @@ -146,7 +146,7 @@ async fn main() -> anyhow::Result<()> { let transport = transport::new(codec_builder.new_framed(stream), Bincode::default()); let answer = PingServiceClient::new(Default::default(), transport) .spawn() - .ping(&mut tarpc::context::current()) + .ping(&mut ClientContext::current()) .await?; println!("ping answer: {answer}"); diff --git a/tarpc/examples/tracing.rs b/tarpc/examples/tracing.rs index 1bace43ce..be1b539c1 100644 --- a/tarpc/examples/tracing.rs +++ b/tarpc/examples/tracing.rs @@ -56,7 +56,7 @@ pub mod double { struct AddServer; impl AddService for AddServer { - async fn add(self, _: &mut context::Context, x: i32, y: i32) -> i32 { + async fn add(self, _: &mut context::ServerContext, x: i32, y: i32) -> i32 { x + y } } @@ -70,9 +70,9 @@ impl DoubleService for DoubleServer where Stub: AddStub + Clone + Send + Sync + 'static, { - async fn double(self, _: &mut context::Context, x: i32) -> Result { + async fn double(self, _: &mut context::ServerContext, x: i32) -> Result { self.add_client - .add(&mut context::current(), x, x) + .add(&mut context::ClientContext::current(), x, x) .await .map_err(|e| e.to_string()) } @@ -193,9 +193,8 @@ async fn main() -> anyhow::Result<()> { let double_client = double::DoubleClient::new(client::Config::default(), to_double_server).spawn(); - let mut ctx = context::current(); for _ in 1..=5 { - tracing::info!("{:?}", double_client.double(&mut ctx, 1).await?); + tracing::info!("{:?}", double_client.double(&mut context::ClientContext::current(), 1).await?); } tracer_provider.shutdown()?; diff --git a/tarpc/src/client.rs b/tarpc/src/client.rs index 9ef7a1acb..8d3b9f4a7 100644 --- a/tarpc/src/client.rs +++ b/tarpc/src/client.rs @@ -128,7 +128,7 @@ where otel.kind = "client", otel.name = %request.name()) )] - pub async fn call(&self, ctx: &mut context::Context, request: Req) -> Result { + pub async fn call(&self, ctx: &mut context::SharedContext, request: Req) -> Result { let span = Span::current(); ctx.trace_context = trace::Context::try_from(&span).unwrap_or_else(|_| { tracing::trace!( @@ -153,10 +153,7 @@ where }; self.to_dispatch .send(DispatchRequest { - ctx: context::Context { - deadline: ctx.deadline, - trace_context: ctx.trace_context, - }, + ctx: ctx.clone(), span, request_id, request, @@ -460,7 +457,7 @@ where fn poll_next_cancellation( mut self: Pin<&mut Self>, cx: &mut Context<'_>, - ) -> Poll>>> { + ) -> Poll>>> { ready!(self.ensure_writeable(cx)?); loop { @@ -516,13 +513,15 @@ where let request = ClientMessage::Request(Request { id: request_id, message: request, - context: context::Context { - deadline: ctx.deadline, - trace_context: ctx.trace_context, - }, + context: ctx.clone(), }); + + //TODO: Feels like we could avoid either saving the request context in insert_request + // or submitting the context in start_request. + let full_context = context::ClientContext::new(ctx); + self.in_flight_requests() - .insert_request(request_id, ctx, span.clone(), response_completion) + .insert_request(request_id, full_context, span.clone(), response_completion) .expect("Request IDs should be unique"); match self.start_send(request) { Ok(()) => tracing::debug!("SendRequest"), @@ -717,7 +716,7 @@ mod tests { dispatch .in_flight_requests - .insert_request(0, context::current(), Span::current(), tx) + .insert_request(0, ClientContext::current(), Span::current(), tx) .unwrap(); server_channel .send(Response { @@ -884,7 +883,7 @@ mod tests { let (dispatch, channel, _server_channel) = set_up(); drop(dispatch); // error on send - let resp = channel.call(&mut current(), "hi".to_string()).await; + let resp = channel.call(&mut ClientContext::current(), "hi".to_string()).await; assert_matches!(resp, Err(RpcError::Shutdown)); } @@ -1094,7 +1093,7 @@ mod tests { let request_id = u64::try_from(channel.next_request_id.fetch_add(1, Ordering::Relaxed)).unwrap(); let request = DispatchRequest { - ctx: context::current(), + ctx: SharedContext::current(), span: Span::current(), request_id, request: request.to_string(), @@ -1119,7 +1118,7 @@ mod tests { let request_id = u64::try_from(channel.next_request_id.fetch_add(1, Ordering::Relaxed)).unwrap(); let request = DispatchRequest { - ctx: context::current(), + ctx: SharedContext::current(), span: Span::current(), request_id, request: request.to_string(), diff --git a/tarpc/src/client/in_flight_requests.rs b/tarpc/src/client/in_flight_requests.rs index 1776a74a0..a368a5a48 100644 --- a/tarpc/src/client/in_flight_requests.rs +++ b/tarpc/src/client/in_flight_requests.rs @@ -29,7 +29,7 @@ impl Default for InFlightRequests { #[derive(Debug)] struct RequestData { - ctx: context::Context, + ctx: context::ClientContext, span: Span, response_completion: oneshot::Sender, /// The key to remove the timer for the request's deadline. @@ -56,7 +56,7 @@ impl InFlightRequests { pub fn insert_request( &mut self, request_id: u64, - ctx: context::Context, + ctx: context::ClientContext, span: Span, response_completion: oneshot::Sender, ) -> Result<(), AlreadyExistsError> { @@ -106,7 +106,7 @@ impl InFlightRequests { /// Cancels a request without completing (typically used when a request handle was dropped /// before the request completed). - pub fn cancel_request(&mut self, request_id: u64) -> Option<(context::Context, Span)> { + pub fn cancel_request(&mut self, request_id: u64) -> Option<(context::ClientContext, Span)> { if let Some(request_data) = self.request_data.remove(&request_id) { self.request_data.compact(0.1); self.deadlines.remove(&request_data.deadline_key); diff --git a/tarpc/src/client/stub.rs b/tarpc/src/client/stub.rs index 14b6edf30..c7dc12008 100644 --- a/tarpc/src/client/stub.rs +++ b/tarpc/src/client/stub.rs @@ -24,11 +24,8 @@ pub trait Stub { type Resp; /// Calls a remote service. - async fn call( - &self, - ctx: &mut context::Context, - request: Self::Req, - ) -> Result; + async fn call(&self, ctx: &mut context::ClientContext, request: Self::Req) + -> Result; } impl Stub for Channel @@ -38,7 +35,7 @@ where type Req = Req; type Resp = Resp; - async fn call(&self, ctx: &mut context::Context, request: Req) -> Result { + async fn call(&self, ctx: &mut context::ClientContext, request: Req) -> Result { Self::call(self, ctx, request).await } } @@ -49,11 +46,13 @@ where { type Req = S::Req; type Resp = S::Resp; - async fn call( - &self, - ctx: &mut context::Context, - req: Self::Req, - ) -> Result { - self.clone().serve(ctx, req).await.map_err(RpcError::Server) + async fn call(&self, ctx: &mut context::ClientContext, req: Self::Req) -> Result { + let mut server_ctx = context::ServerContext::new(ctx.shared_context.clone()); + + let res = self.clone().serve(&mut server_ctx, req).await.map_err(RpcError::Server); + + ctx.shared_context = server_ctx.shared_context; + + res } } diff --git a/tarpc/src/client/stub/load_balance.rs b/tarpc/src/client/stub/load_balance.rs index 6c0f7b0df..bf70ebe2a 100644 --- a/tarpc/src/client/stub/load_balance.rs +++ b/tarpc/src/client/stub/load_balance.rs @@ -20,7 +20,7 @@ mod round_robin { async fn call( &self, - ctx: &mut context::Context, + ctx: &mut context::ClientContext, request: Self::Req, ) -> Result { let next = self.stubs.next(); @@ -119,7 +119,7 @@ mod consistent_hash { async fn call( &self, - ctx: &mut context::Context, + ctx: &mut context::ClientContext, request: Self::Req, ) -> Result { let index = usize::try_from(self.hasher.hash_one(&request) % self.stubs_len).expect( @@ -200,13 +200,13 @@ mod consistent_hash { )?; for _ in 0..2 { - let resp = stub.call(&mut context::current(), 'a').await?; + let resp = stub.call(&mut context::ClientContext::current(), 'a').await?; assert_eq!(resp, 1); - let resp = stub.call(&mut context::current(), 'b').await?; + let resp = stub.call(&mut context::ClientContext::current(), 'b').await?; assert_eq!(resp, 2); - let resp = stub.call(&mut context::current(), 'c').await?; + let resp = stub.call(&mut context::ClientContext::current(), 'c').await?; assert_eq!(resp, 3); } diff --git a/tarpc/src/client/stub/mock.rs b/tarpc/src/client/stub/mock.rs index 6f0540797..451544433 100644 --- a/tarpc/src/client/stub/mock.rs +++ b/tarpc/src/client/stub/mock.rs @@ -30,7 +30,7 @@ where type Req = Req; type Resp = Resp; - async fn call(&self, _: &mut context::Context, request: Self::Req) -> Result { + async fn call(&self, _: &mut context::ClientContext, request: Self::Req) -> Result { self.responses .get(&request) .cloned() diff --git a/tarpc/src/client/stub/retry.rs b/tarpc/src/client/stub/retry.rs index 18c84f25f..d93daa156 100644 --- a/tarpc/src/client/stub/retry.rs +++ b/tarpc/src/client/stub/retry.rs @@ -18,7 +18,7 @@ where async fn call( &self, - ctx: &mut context::Context, + ctx: &mut context::ClientContext, request: Self::Req, ) -> Result { let request = Arc::new(request); diff --git a/tarpc/src/context.rs b/tarpc/src/context.rs index f59d34dd9..a96c49095 100644 --- a/tarpc/src/context.rs +++ b/tarpc/src/context.rs @@ -14,6 +14,7 @@ use std::{ convert::TryFrom, time::{Duration, Instant}, }; +use std::ops::{Deref, DerefMut}; use tracing_opentelemetry::OpenTelemetrySpanExt; /// A request context that carries request-scoped information like deadlines and trace information. @@ -21,10 +22,10 @@ use tracing_opentelemetry::OpenTelemetrySpanExt; /// /// The context should not be stored directly in a server implementation, because the context will /// be different for each request in scope. -#[derive(Debug)] +#[derive(Debug, Clone)] #[non_exhaustive] #[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))] -pub struct Context { +pub struct SharedContext { /// When the client expects the request to be complete by. The server should cancel the request /// if it is not complete by this time. #[cfg_attr(feature = "serde1", serde(default = "ten_seconds_from_now"))] @@ -38,6 +39,86 @@ pub struct Context { pub trace_context: trace::Context, } +/// Request context that carries request-scoped server side information like deadlines and trace information +/// as well as any server side extensions defined by the transport, hooks or service implementations. +/// It is build from the shared context sent from client to server. +/// +/// The context should not be stored directly in a server implementation, because the context will +/// be different for each request in scope. +#[derive(Debug)] +pub struct ServerContext { + /// Shared context sent from client to server which contains information used by both sides. + pub shared_context: SharedContext, +} + +impl ServerContext { + /// Creates a new ServerContext from the given SharedContext with no extensions. + pub fn new(shared_context: SharedContext) -> Self { + Self { + shared_context, + } + } + + /// Creates a new ServerContext for the current shared context with no extensions. + pub fn current() -> Self { + Self::new(SharedContext::current()) + } +} + +impl Deref for ServerContext { + type Target = SharedContext; + + fn deref(&self) -> &Self::Target { + &self.shared_context + } +} +impl DerefMut for ServerContext { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.shared_context + } +} + +/// Request context that carries request-scoped client side information like deadlines and trace information +/// as well as any server side extensions defined by the transport, hooks and stubs. +/// The shared part of the context is sent from client to server, while the client side extensions are only seen on the client side. +/// +/// The context should not be stored directly in a stub implementation, because the context will +/// be different for each request in scope. +#[derive(Debug)] +pub struct ClientContext { + /// Shared context sent from client to server which contains information used by both sides. + pub shared_context: SharedContext, + +} + +impl ClientContext { + /// Creates a new ServerContext from the given SharedContext with no extensions. + pub fn new(shared_context: SharedContext) -> Self { + Self { + shared_context, + } + } + + /// Creates a new ServerContext for the current shared context with no extensions. + pub fn current() -> Self { + Self::new(SharedContext::current()) + } +} + +impl Deref for ClientContext { + type Target = SharedContext; + + fn deref(&self) -> &Self::Target { + &self.shared_context + } +} + +impl DerefMut for ClientContext { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.shared_context + } +} + #[cfg(feature = "serde1")] mod absolute_to_relative_time { pub use serde::{Deserialize, Deserializer, Serialize, Serializer}; @@ -91,17 +172,12 @@ mod absolute_to_relative_time { } } -assert_impl_all!(Context: Send, Sync); +assert_impl_all!(SharedContext: Send, Sync); fn ten_seconds_from_now() -> Instant { Instant::now() + Duration::from_secs(10) } -/// Returns the context for the current request, or a default Context if no request is active. -pub fn current() -> Context { - Context::current() -} - #[derive(Clone)] struct Deadline(Instant); @@ -111,7 +187,7 @@ impl Default for Deadline { } } -impl Context { +impl SharedContext { /// Returns the context for the current request, or a default Context if no request is active. pub fn current() -> Self { let span = tracing::Span::current(); @@ -137,11 +213,11 @@ impl Context { pub(crate) trait SpanExt { /// Sets the given context on this span. Newly-created spans will be children of the given /// context's trace context. - fn set_context(&self, context: &Context); + fn set_context(&self, context: &SharedContext); } impl SpanExt for tracing::Span { - fn set_context(&self, context: &Context) { + fn set_context(&self, context: &SharedContext) { self.set_parent( opentelemetry::Context::new() .with_remote_span_context(opentelemetry::trace::SpanContext::new( diff --git a/tarpc/src/lib.rs b/tarpc/src/lib.rs index 17a06ec57..a83efae02 100644 --- a/tarpc/src/lib.rs +++ b/tarpc/src/lib.rs @@ -125,7 +125,7 @@ //! //! impl World for HelloServer { //! // Each defined rpc generates an async fn that serves the RPC -//! async fn hello(self, _: &mut context::Context, name: String) -> String { +//! async fn hello(self, _: &mut context::ServerContext, name: String) -> String { //! format!("Hello, {name}!") //! } //! } @@ -158,7 +158,7 @@ //! # struct HelloServer; //! # impl World for HelloServer { //! // Each defined rpc generates an async fn that serves the RPC -//! # async fn hello(self, _: &mut context::Context, name: String) -> String { +//! # async fn hello(self, _: &mut context::ServerContext, name: String) -> String { //! # format!("Hello, {name}!") //! # } //! # } @@ -184,7 +184,7 @@ //! // The client has an RPC method for each RPC defined in the annotated trait. It takes the same //! // args as defined, with the addition of a Context, which is always the first arg. The Context //! // specifies a deadline and trace information which can be helpful in debugging requests. -//! let mut context = context::current(); +//! let mut context = context::ClientContext::current(); //! let hello = client.hello(&mut context, "Stim".to_string()).await?; //! //! println!("{hello}"); @@ -284,7 +284,7 @@ pub enum ClientMessage { #[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))] pub struct Request { /// Trace context, deadline, and other cross-cutting concerns. - pub context: context::Context, + pub context: context::SharedContext, /// Uniquely identifies the request across all requests sent over a single channel. pub id: u64, /// The request body. diff --git a/tarpc/src/server.rs b/tarpc/src/server.rs index e08365964..3b01d207d 100644 --- a/tarpc/src/server.rs +++ b/tarpc/src/server.rs @@ -76,11 +76,7 @@ pub trait Serve { type Resp; /// Responds to a single request. - async fn serve( - self, - ctx: &mut context::Context, - req: Self::Req, - ) -> Result; + async fn serve(self, ctx: &mut context::ServerContext, req: Self::Req) -> Result; } /// A Serve wrapper around a Fn. @@ -108,10 +104,7 @@ impl Copy for ServeFn where F: Copy {} /// Result>`. pub fn serve(f: F) -> ServeFn where - for<'a> F: FnOnce( - &'a mut context::Context, - Req, - ) -> Pin> + 'a + Send>>, + for<'a> F: FnOnce(&'a mut context::ServerContext, Req) -> Pin> + 'a + Send>>, { ServeFn { f, @@ -122,15 +115,12 @@ where impl Serve for ServeFn where Req: RequestName, - for<'a> F: FnOnce( - &'a mut context::Context, - Req, - ) -> Pin> + 'a + Send>>, + for<'a> F: FnOnce(&'a mut context::ServerContext, Req) -> Pin> + 'a + Send>>, { type Req = Req; type Resp = Resp; - async fn serve(self, ctx: &mut context::Context, req: Req) -> Result { + async fn serve(self, ctx: &mut context::ServerContext, req: Req) -> Result { (self.f)(ctx, req).await } } @@ -371,7 +361,7 @@ where /// tokio::spawn(request.execute(serve(|_, i| async move { Ok(i + 1) }.boxed()))); /// } /// }); - /// let mut context = context::current(); + /// let mut context = context::ClientContext::current(); /// assert_eq!(client.call(&mut context, 1).await.unwrap(), 2); /// } /// ``` @@ -412,7 +402,7 @@ where /// .for_each(|response| async move { /// tokio::spawn(response); /// }.boxed())); - /// let mut context = context::current(); + /// let mut context = context::ClientContext::current(); /// assert_eq!( /// client.call(&mut context, 1).await.unwrap(), /// 2); @@ -762,7 +752,7 @@ where /// .for_each(|response| async move { /// tokio::spawn(response); /// }.boxed())); - /// let mut context = context::current(); + /// let mut context = context::ClientContext::current(); /// assert_eq!(client.call(&mut context, 1).await.unwrap(), 2); /// } /// ``` @@ -869,7 +859,7 @@ impl InFlightRequest { /// in_flight_request.execute(serve(|_, i| async move { Ok(i + 1) }.boxed())).await; /// } /// }); - /// let mut context = context::current(); + /// let mut context = context::ClientContext::current(); /// assert_eq!(client.call(&mut context, 1).await.unwrap(), 2); /// } /// ``` @@ -886,15 +876,16 @@ impl InFlightRequest { span, request: Request { - mut context, + context, message, id: request_id, }, } = self; span.record("otel.name", message.name()); + let mut full_context = context::ServerContext::new(context); let _ = Abortable::new( async move { - let message = serve.serve(&mut context, message).await; + let message = serve.serve(&mut full_context, message).await; tracing::debug!("CompleteRequest"); let response = Response { request_id, @@ -1037,7 +1028,7 @@ mod tests { fn fake_request(req: Req) -> ClientMessage { ClientMessage::Request(Request { - context: context::current(), + context: context::SharedContext::current(), id: 0, message: req, }) @@ -1052,7 +1043,7 @@ mod tests { #[tokio::test] async fn test_serve() { let serve = serve(|_, i| async move { Ok(i) }.boxed()); - assert_matches!(serve.serve(&mut context::current(), 7).await, Ok(7)); + assert_matches!(serve.serve(&mut context::ServerContext::current(), 7).await, Ok(7)); } #[tokio::test] @@ -1061,7 +1052,7 @@ mod tests { impl BeforeRequest for SetDeadline { async fn before( &mut self, - ctx: &mut context::Context, + ctx: &mut context::ServerContext, _: &Req, ) -> Result<(), ServerError> { ctx.deadline = self.0; @@ -1072,15 +1063,12 @@ mod tests { let some_time = Instant::now() + Duration::from_secs(37); let some_other_time = Instant::now() + Duration::from_secs(83); - let serve = serve(move |ctx: &mut context::Context, i| { - async move { - assert_eq!(ctx.deadline, some_time); - Ok(i) - } - .boxed() - }); + let serve = serve(move |ctx: &mut context::ServerContext, i| async move { + assert_eq!(ctx.deadline, some_time); + Ok(i) + }.boxed()); let deadline_hook = serve.before(SetDeadline(some_time)); - let mut ctx = context::current(); + let mut ctx = context::ServerContext::current(); ctx.deadline = some_other_time; deadline_hook.serve(&mut ctx, 7).await?; Ok(()) @@ -1103,7 +1091,7 @@ mod tests { impl BeforeRequest for PrintLatency { async fn before( &mut self, - _: &mut context::Context, + _: &mut context::ServerContext, _: &Req, ) -> Result<(), ServerError> { self.start = Instant::now(); @@ -1111,15 +1099,15 @@ mod tests { } } impl AfterRequest for PrintLatency { - async fn after(&mut self, _: &mut context::Context, _: &mut Result) { + async fn after(&mut self, _: &mut context::ServerContext, _: &mut Result) { tracing::debug!("Elapsed: {:?}", self.start.elapsed()); } } - let serve = serve(move |_: &mut context::Context, i| async move { Ok(i) }.boxed()); + let serve = serve(move |_: &mut context::ServerContext, i| async move { Ok(i) }.boxed()); serve .before_and_after(PrintLatency::new()) - .serve(&mut context::current(), 7) + .serve(&mut context::ServerContext::current(), 7) .await?; Ok(()) } @@ -1127,10 +1115,10 @@ mod tests { #[tokio::test] async fn serve_before_error_aborts_request() -> anyhow::Result<()> { let serve = serve(|_, _| async { panic!("Shouldn't get here") }.boxed()); - let deadline_hook = serve.before(|_: &mut context::Context, _: &i32| async { + let deadline_hook = serve.before(|_: &mut context::ServerContext, _: &i32| async { Err(ServerError::new(io::ErrorKind::Other, "oops".into())) }); - let resp: Result = deadline_hook.serve(&mut context::current(), 7).await; + let resp: Result = deadline_hook.serve(&mut context::ServerContext::current(), 7).await; assert_matches!(resp, Err(_)); Ok(()) } @@ -1143,14 +1131,14 @@ mod tests { .as_mut() .start_request(Request { id: 0, - context: context::current(), + context: context::SharedContext::current(), message: (), }) .unwrap(); assert_matches!( channel.as_mut().start_request(Request { id: 0, - context: context::current(), + context: context::SharedContext::current(), message: () }), Err(AlreadyExistsError) @@ -1166,7 +1154,7 @@ mod tests { .as_mut() .start_request(Request { id: 0, - context: context::current(), + context: context::SharedContext::current(), message: (), }) .unwrap(); @@ -1174,7 +1162,7 @@ mod tests { .as_mut() .start_request(Request { id: 1, - context: context::current(), + context: context::SharedContext::current(), message: (), }) .unwrap(); @@ -1197,7 +1185,7 @@ mod tests { .as_mut() .start_request(Request { id: 0, - context: context::current(), + context: context::SharedContext::current(), message: (), }) .unwrap(); @@ -1226,7 +1214,7 @@ mod tests { .as_mut() .start_request(Request { id: 0, - context: context::current(), + context: context::SharedContext::current(), message: (), }) .unwrap(); @@ -1268,7 +1256,7 @@ mod tests { .as_mut() .start_request(Request { id: 0, - context: context::current(), + context: context::SharedContext::current(), message: (), }) .unwrap(); @@ -1291,7 +1279,7 @@ mod tests { .as_mut() .start_request(Request { id: 0, - context: context::current(), + context: context::SharedContext::current(), message: (), }) .unwrap(); @@ -1335,9 +1323,7 @@ mod tests { Poll::Ready(Some(Ok(request))) => request, result => panic!("Unexpected result: {result:?}"), }; - request - .execute(serve(|_, _| async { Ok(()) }.boxed())) - .await; + request.execute(serve(|_, _| async { Ok(()) }.boxed())).await; assert!( requests .as_mut() @@ -1358,7 +1344,7 @@ mod tests { .channel_pin_mut() .start_request(Request { id: 0, - context: context::current(), + context: context::SharedContext::current(), message: (), }) .unwrap(); @@ -1388,7 +1374,7 @@ mod tests { .channel_pin_mut() .start_request(Request { id: 1, - context: context::current(), + context: context::SharedContext::current(), message: (), }) .unwrap(); @@ -1409,7 +1395,7 @@ mod tests { .channel_pin_mut() .start_request(Request { id: 0, - context: context::current(), + context: context::SharedContext::current(), message: (), }) .unwrap(); @@ -1428,7 +1414,7 @@ mod tests { .channel_pin_mut() .start_request(Request { id: 1, - context: context::current(), + context: context::SharedContext::current(), message: (), }) .unwrap(); diff --git a/tarpc/src/server/incoming.rs b/tarpc/src/server/incoming.rs index eddf3794e..cb01021f5 100644 --- a/tarpc/src/server/incoming.rs +++ b/tarpc/src/server/incoming.rs @@ -65,7 +65,7 @@ where /// BaseChannel::new(server::Config::default(), rx) /// }).execute(serve(|_, i| async move { Ok(i + 1) }.boxed())); /// tokio::spawn(spawn_incoming(incoming)); -/// let mut context = context::current(); +/// let mut context = context::ClientContext::current(); /// assert_eq!(client.call(&mut context, 1).await.unwrap(), 2); /// } /// ``` diff --git a/tarpc/src/server/request_hook.rs b/tarpc/src/server/request_hook.rs index 64b97453a..38b0998bf 100644 --- a/tarpc/src/server/request_hook.rs +++ b/tarpc/src/server/request_hook.rs @@ -48,7 +48,7 @@ pub trait RequestHook: Serve { /// use std::io; /// /// let serve = serve(|_ctx, i| async move { Ok(i + 1) }.boxed()) - /// .before(|_ctx: &mut context::Context, req: &i32| { + /// .before(|_ctx: &mut context::ServerContext, req: &i32| { /// future::ready( /// if *req == 1 { /// Err(ServerError::new( @@ -58,7 +58,7 @@ pub trait RequestHook: Serve { /// Ok(()) /// }) /// }); - /// let mut context = context::current(); + /// let mut context = context::ServerContext::current(); /// let response = serve.serve(&mut context, 1); /// assert!(block_on(response).is_err()); /// ``` @@ -95,13 +95,13 @@ pub trait RequestHook: Serve { /// Ok(i + 1) /// } /// }.boxed()) - /// .after(|_ctx: &mut context::Context, resp: &mut Result| { + /// .after(|_ctx: &mut context::ServerContext, resp: &mut Result| { /// if let Err(e) = resp { /// eprintln!("server error: {e:?}"); /// } /// future::ready(()) /// }); - /// let mut context = context::current(); + /// let mut context = context::ServerContext::current(); /// let response = serve.serve(&mut context, 1); /// assert!(block_on(response).is_err()); /// ``` @@ -134,7 +134,7 @@ pub trait RequestHook: Serve { /// struct PrintLatency(Instant); /// /// impl BeforeRequest for PrintLatency { - /// async fn before(&mut self, _: &mut context::Context, _: &Req) -> Result<(), ServerError> { + /// async fn before(&mut self, _: &mut context::ServerContext, _: &Req) -> Result<(), ServerError> { /// self.0 = Instant::now(); /// Ok(()) /// } @@ -143,7 +143,7 @@ pub trait RequestHook: Serve { /// impl AfterRequest for PrintLatency { /// async fn after( /// &mut self, - /// _: &mut context::Context, + /// _: &mut context::ServerContext, /// _: &mut Result, /// ) { /// tracing::info!("Elapsed: {:?}", self.0.elapsed()); @@ -153,7 +153,7 @@ pub trait RequestHook: Serve { /// let serve = serve(|_ctx, i| async move { /// Ok(i + 1) /// }.boxed()).before_and_after(PrintLatency(Instant::now())); - /// let mut context = context::current(); + /// let mut context = context::ServerContext::current(); /// let response = serve.serve(&mut context, 1); /// assert!(block_on(response).is_ok()); /// ``` diff --git a/tarpc/src/server/request_hook/after.rs b/tarpc/src/server/request_hook/after.rs index e2c49b2f1..d9e676ca4 100644 --- a/tarpc/src/server/request_hook/after.rs +++ b/tarpc/src/server/request_hook/after.rs @@ -15,15 +15,15 @@ pub trait AfterRequest { /// The function that is called after request execution. /// /// The hook can modify the request context and the response. - async fn after(&mut self, ctx: &mut context::Context, resp: &mut Result); + async fn after(&mut self, ctx: &mut context::ServerContext, resp: &mut Result); } impl AfterRequest for F where - F: FnMut(&mut context::Context, &mut Result) -> Fut, + F: FnMut(&mut context::ServerContext, &mut Result) -> Fut, Fut: Future, { - async fn after(&mut self, ctx: &mut context::Context, resp: &mut Result) { + async fn after(&mut self, ctx: &mut context::ServerContext, resp: &mut Result) { self(ctx, resp).await } } @@ -59,7 +59,7 @@ where async fn serve( self, - ctx: &mut context::Context, + ctx: &mut context::ServerContext, req: Serv::Req, ) -> Result { let ServeThenHook { diff --git a/tarpc/src/server/request_hook/before.rs b/tarpc/src/server/request_hook/before.rs index ad04cc784..4a1b2ad8a 100644 --- a/tarpc/src/server/request_hook/before.rs +++ b/tarpc/src/server/request_hook/before.rs @@ -19,7 +19,7 @@ pub trait BeforeRequest { /// /// This function can also modify the request context. This could be used, for example, to /// enforce a maximum deadline on all requests. - async fn before(&mut self, ctx: &mut context::Context, req: &Req) -> Result<(), ServerError>; + async fn before(&mut self, ctx: &mut context::ServerContext, req: &Req) -> Result<(), ServerError>; } /// A list of hooks that run in order before request execution. @@ -34,7 +34,7 @@ pub trait BeforeRequestList: BeforeRequest { /// Same as `then`, but helps the compiler with type inference when Next is a closure. fn then_fn< - Next: FnMut(&mut context::Context, &Req) -> Fut, + Next: FnMut(&mut context::ServerContext, &Req) -> Fut, Fut: Future>, >( self, @@ -56,10 +56,10 @@ pub trait BeforeRequestList: BeforeRequest { impl BeforeRequest for F where - F: FnMut(&mut context::Context, &Req) -> Fut, + F: FnMut(&mut context::ServerContext, &Req) -> Fut, Fut: Future>, { - async fn before(&mut self, ctx: &mut context::Context, req: &Req) -> Result<(), ServerError> { + async fn before(&mut self, ctx: &mut context::ServerContext, req: &Req) -> Result<(), ServerError> { self(ctx, req).await } } @@ -87,7 +87,7 @@ where async fn serve( self, - ctx: &mut context::Context, + ctx: &mut context::ServerContext, req: Self::Req, ) -> Result { let HookThenServe { @@ -121,7 +121,7 @@ where /// Ok(()) /// }) /// .serving(serve(|_ctx, i| async move { Ok(i + 1) }.boxed())); -/// let mut context = context::current(); +/// let mut context = context::ServerContext::current(); /// let response = serve.clone().serve(&mut context, 1); /// assert!(block_on(response).is_ok()); /// assert!(i.get() == 2); @@ -141,7 +141,7 @@ pub struct BeforeRequestNil; impl, Rest: BeforeRequest> BeforeRequest for BeforeRequestCons { - async fn before(&mut self, ctx: &mut context::Context, req: &Req) -> Result<(), ServerError> { + async fn before(&mut self, ctx: &mut context::ServerContext, req: &Req) -> Result<(), ServerError> { let BeforeRequestCons(first, rest) = self; first.before(ctx, req).await?; rest.before(ctx, req).await?; @@ -150,7 +150,7 @@ impl, Rest: BeforeRequest> BeforeRequest BeforeRequest for BeforeRequestNil { - async fn before(&mut self, _: &mut context::Context, _: &Req) -> Result<(), ServerError> { + async fn before(&mut self, _: &mut context::ServerContext, _: &Req) -> Result<(), ServerError> { Ok(()) } } @@ -211,7 +211,7 @@ fn before_request_list() { Ok(()) }) .serving(serve(|_ctx, i| async move { Ok(i + 1) }.boxed())); - let mut context = context::current(); + let mut context = context::ServerContext::current(); let response = serve.clone().serve(&mut context, 1); assert!(block_on(response).is_ok()); assert!(i.get() == 2); diff --git a/tarpc/src/server/request_hook/before_and_after.rs b/tarpc/src/server/request_hook/before_and_after.rs index e06f34113..af37427af 100644 --- a/tarpc/src/server/request_hook/before_and_after.rs +++ b/tarpc/src/server/request_hook/before_and_after.rs @@ -46,7 +46,7 @@ where type Req = Req; type Resp = Resp; - async fn serve(self, ctx: &mut context::Context, req: Req) -> Result { + async fn serve(self, ctx: &mut context::ServerContext, req: Req) -> Result { let HookThenServeThenHook { serve, mut hook, .. } = self; diff --git a/tarpc/src/server/testing.rs b/tarpc/src/server/testing.rs index db167c42e..70c4e7f69 100644 --- a/tarpc/src/server/testing.rs +++ b/tarpc/src/server/testing.rs @@ -92,7 +92,7 @@ impl FakeChannel>, Response> { let (request_cancellation, _) = cancellations(); self.stream.push_back(Ok(TrackedRequest { request: Request { - context: context::Context { + context: context::SharedContext { deadline: Instant::now(), trace_context: Default::default(), }, diff --git a/tarpc/src/transport/channel.rs b/tarpc/src/transport/channel.rs index e064e6813..5cb897569 100644 --- a/tarpc/src/transport/channel.rs +++ b/tarpc/src/transport/channel.rs @@ -209,8 +209,8 @@ mod tests { let client = client::new(client::Config::default(), client_channel).spawn(); - let response1 = client.call(&mut context::current(), "123".into()).await; - let response2 = client.call(&mut context::current(), "abc".into()).await; + let response1 = client.call(&mut context::ClientContext::current(), "123".into()).await; + let response2 = client.call(&mut context::ClientContext::current(), "abc".into()).await; trace!("response1: {:?}, response2: {:?}", response1, response2); diff --git a/tarpc/tests/dataservice.rs b/tarpc/tests/dataservice.rs index e051b434e..e4cbf338d 100644 --- a/tarpc/tests/dataservice.rs +++ b/tarpc/tests/dataservice.rs @@ -22,7 +22,7 @@ pub trait ColorProtocol { struct ColorServer; impl ColorProtocol for ColorServer { - async fn get_opposite_color(self, _: &mut context::Context, color: TestData) -> TestData { + async fn get_opposite_color(self, _: &mut context::ServerContext, color: TestData) -> TestData { match color { TestData::White => TestData::Black, TestData::Black => TestData::White, @@ -53,7 +53,7 @@ async fn test_call() -> anyhow::Result<()> { let client = ColorProtocolClient::new(client::Config::default(), transport).spawn(); let color = client - .get_opposite_color(&mut context::current(), TestData::White) + .get_opposite_color(&mut context::ClientContext::current(), TestData::White) .await?; assert_eq!(color, TestData::Black); diff --git a/tarpc/tests/service_functional.rs b/tarpc/tests/service_functional.rs index f3adda2fb..46ce7bd47 100644 --- a/tarpc/tests/service_functional.rs +++ b/tarpc/tests/service_functional.rs @@ -22,11 +22,11 @@ trait Service { struct Server; impl Service for Server { - async fn add(self, _: &mut context::Context, x: i32, y: i32) -> i32 { + async fn add(self, _: &mut context::ServerContext, x: i32, y: i32) -> i32 { x + y } - async fn hey(self, _: &mut context::Context, name: String) -> String { + async fn hey(self, _: &mut context::ServerContext, name: String) -> String { format!("Hey, {name}.") } } @@ -43,7 +43,7 @@ async fn sequential() { })) .for_each(|response| response), ); - assert_eq!(client.call(&mut context::current(), 1).await.unwrap(), 2); + assert_eq!(client.call(&mut context::ClientContext::current(), 1).await.unwrap(), 2); } #[tokio::test] @@ -57,7 +57,7 @@ async fn dropped_channel_aborts_in_flight_requests() -> anyhow::Result<()> { struct LoopServer; impl Loop for LoopServer { - async fn r#loop(self, _: &mut context::Context) { + async fn r#loop(self, _: &mut context::ServerContext) { loop { futures::pending!(); } @@ -73,7 +73,7 @@ async fn dropped_channel_aborts_in_flight_requests() -> anyhow::Result<()> { tokio::spawn(async move { let client = LoopClient::new(client::Config::default(), tx).spawn(); - let mut ctx = context::current(); + let mut ctx = context::ClientContext::current(); ctx.deadline = Instant::now() + Duration::from_secs(60 * 60); let _ = client.r#loop(&mut ctx).await; }); @@ -114,9 +114,9 @@ async fn serde_tcp() -> anyhow::Result<()> { let transport = serde_transport::tcp::connect(addr, Json::default).await?; let client = ServiceClient::new(client::Config::default(), transport).spawn(); - assert_matches!(client.add(&mut context::current(), 1, 2).await, Ok(3)); + assert_matches!(client.add(&mut context::ClientContext::current(), 1, 2).await, Ok(3)); assert_matches!( - client.hey(&mut context::current(), "Tim".to_string()).await, + client.hey(&mut context::ClientContext::current(), "Tim".to_string()).await, Ok(ref s) if s == "Hey, Tim." ); @@ -147,8 +147,8 @@ async fn serde_uds() -> anyhow::Result<()> { let client = ServiceClient::new(client::Config::default(), transport).spawn(); // Save results using socket so we can clean the socket even if our test assertions fail - let res1 = client.add(&mut context::current(), 1, 2).await; - let res2 = client.hey(&mut context::current(), "Tim".to_string()).await; + let res1 = client.add(&mut context::ClientContext::current(), 1, 2).await; + let res2 = client.hey(&mut context::ClientContext::current(), "Tim".to_string()).await; assert_matches!(res1, Ok(3)); assert_matches!(res2, Ok(ref s) if s == "Hey, Tim."); @@ -171,7 +171,7 @@ async fn concurrent() -> anyhow::Result<()> { let client = ServiceClient::new(client::Config::default(), tx).spawn(); - let mut context = context::current(); + let mut context = context::ClientContext::current(); let req1 = client.add(&mut context, 1, 2); assert_matches!(req1.await, Ok(3)); @@ -200,9 +200,9 @@ async fn concurrent_join() -> anyhow::Result<()> { let client = ServiceClient::new(client::Config::default(), tx).spawn(); - let mut context1 = context::current(); - let mut context2 = context::current(); - let mut context3 = context::current(); + let mut context1 = context::ClientContext::current(); + let mut context2 = context::ClientContext::current(); + let mut context3 = context::ClientContext::current(); let req1 = client.add(&mut context1, 1, 2); let req2 = client.add(&mut context2, 3, 4); @@ -234,8 +234,8 @@ async fn concurrent_join_all() -> anyhow::Result<()> { let client = ServiceClient::new(client::Config::default(), tx).spawn(); - let mut context1 = context::current(); - let mut context2 = context::current(); + let mut context1 = context::ClientContext::current(); + let mut context2 = context::ClientContext::current(); let req1 = client.add(&mut context1, 1, 2); let req2 = client.add(&mut context2, 3, 4); @@ -257,7 +257,7 @@ async fn counter() -> anyhow::Result<()> { struct CountService(u32); impl Counter for &mut CountService { - async fn count(self, _: &mut context::Context) -> u32 { + async fn count(self, _: &mut context::ServerContext) -> u32 { self.0 += 1; self.0 } @@ -274,8 +274,8 @@ async fn counter() -> anyhow::Result<()> { }); let client = CounterClient::new(client::Config::default(), tx).spawn(); - assert_matches!(client.count(&mut context::current()).await, Ok(1)); - assert_matches!(client.count(&mut context::current()).await, Ok(2)); + assert_matches!(client.count(&mut context::ClientContext::current()).await, Ok(1)); + assert_matches!(client.count(&mut context::ClientContext::current()).await, Ok(2)); Ok(()) } From d1afa2cbf7d3db88a9550186e3b6ab874af892ef Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C3=81kos=20Vandra-Meyer?= Date: Sun, 23 Nov 2025 20:37:23 +0100 Subject: [PATCH 05/23] allow transports to see and manipulate client and server contexts. --- example-service/src/client.rs | 7 +- example-service/src/server.rs | 12 ++-- plugins/Cargo.toml | 1 + plugins/src/lib.rs | 10 ++- tarpc/examples/compression.rs | 17 ++--- tarpc/examples/custom_transport.rs | 7 +- tarpc/examples/pubsub.rs | 14 ++-- tarpc/examples/readme.rs | 11 +-- tarpc/examples/tls_over_tcp.rs | 5 +- tarpc/examples/tracing.rs | 16 +++-- tarpc/src/client.rs | 45 ++++++------ tarpc/src/client/in_flight_requests.rs | 17 +++-- tarpc/src/context.rs | 5 +- tarpc/src/lib.rs | 43 +++++++++--- tarpc/src/server.rs | 94 +++++++++++++++----------- tarpc/src/server/incoming.rs | 7 +- tarpc/src/server/testing.rs | 6 +- tarpc/src/transport/channel.rs | 43 ++++++++---- tarpc/tests/dataservice.rs | 7 +- tarpc/tests/service_functional.rs | 53 +++++++++++---- 20 files changed, 268 insertions(+), 152 deletions(-) diff --git a/example-service/src/client.rs b/example-service/src/client.rs index dc7104bfd..2984ae49c 100644 --- a/example-service/src/client.rs +++ b/example-service/src/client.rs @@ -7,7 +7,8 @@ use clap::Parser; use service::{WorldClient, init_tracing}; use std::{net::SocketAddr, time::Duration}; -use tarpc::{client, context, tokio_serde::formats::Json}; +use futures::{future, SinkExt}; +use tarpc::{client, tokio_serde::formats::Json}; use tokio::time::sleep; use tracing::Instrument; use tarpc::context::ClientContext; @@ -30,9 +31,11 @@ async fn main() -> anyhow::Result<()> { let mut transport = tarpc::serde_transport::tcp::connect(flags.server_addr, Json::default); transport.config_mut().max_frame_length(usize::MAX); + let transport = transport.await?.with(|msg: tarpc::ClientMessage| future::ok(msg.map_context(|ctx| ctx.shared_context))); + // WorldClient is generated by the service attribute. It has a constructor `new` that takes a // config and any Transport as input. - let client = WorldClient::new(client::Config::default(), transport.await?).spawn(); + let client = WorldClient::new(client::Config::default(), transport).spawn(); let hello = async move { let mut context = ClientContext::current(); diff --git a/example-service/src/server.rs b/example-service/src/server.rs index 0845783c7..00b3eb1fb 100644 --- a/example-service/src/server.rs +++ b/example-service/src/server.rs @@ -15,12 +15,9 @@ use std::{ net::{IpAddr, Ipv6Addr, SocketAddr}, time::Duration, }; -use tarpc::{ - context, - server::{self, Channel, incoming::Incoming}, - tokio_serde::formats::Json, -}; +use tarpc::{context, server::{self, Channel, incoming::Incoming}, tokio_serde::formats::Json, ClientMessage}; use tokio::time; +use tarpc::context::{ServerContext, SharedContext}; #[derive(Parser)] struct Flags { @@ -62,13 +59,14 @@ async fn main() -> anyhow::Result<()> { listener // Ignore accept errors. .filter_map(|r| future::ready(r.ok())) + .map(|t| t.map_ok(|msg: ClientMessage| msg.map_context(|ctx| ServerContext::new(ctx)))) .map(server::BaseChannel::with_defaults) // Limit channels to 1 per IP. - .max_channels_per_key(1, |t| t.transport().peer_addr().unwrap().ip()) + .max_channels_per_key(1, |t| t.transport().get_ref().peer_addr().unwrap().ip()) // serve is generated by the service attribute. It takes as input any type implementing // the generated World trait. .map(|channel| { - let server = HelloServer(channel.transport().peer_addr().unwrap()); + let server = HelloServer(channel.transport().get_ref().peer_addr().unwrap()); channel.execute(server.serve()).for_each(spawn) }) // Max 10 channels. diff --git a/plugins/Cargo.toml b/plugins/Cargo.toml index 8be746c26..eeab84924 100644 --- a/plugins/Cargo.toml +++ b/plugins/Cargo.toml @@ -30,5 +30,6 @@ proc-macro = true [dev-dependencies] assert-type-eq = "0.1.0" futures = "0.3" +futures-util = "0.3.31" serde = { version = "1.0", features = ["derive"] } tarpc = { path = "../tarpc", features = ["serde1"] } diff --git a/plugins/src/lib.rs b/plugins/src/lib.rs index 886b85b48..bc52cf849 100644 --- a/plugins/src/lib.rs +++ b/plugins/src/lib.rs @@ -376,6 +376,7 @@ fn collect_cfg_attrs(rpcs: &[RpcMethod]) -> Vec> { /// /// ```no_run /// use tarpc::{client, transport, service, server::{self, Channel}, context::ServerContext}; +/// use futures_util::{TryStreamExt, sink::SinkExt}; /// /// #[service] /// pub trait Calculator { @@ -394,6 +395,13 @@ fn collect_cfg_attrs(rpcs: &[RpcMethod]) -> Vec> { /// // This could be any transport. /// let (client_side, server_side) = transport::channel::unbounded(); /// +/// let client_side = client_side.with(|msg: tarpc::ClientMessage| async move { +/// Ok(msg.map_context(|ctx| ctx.shared_context)) +/// }); +/// let server_side = server_side.map_ok(|msg: tarpc::ClientMessage| +/// msg.map_context(tarpc::context::ServerContext::new) +/// ); +/// /// // A client can be made like so: /// let client = CalculatorClient::new(client::Config::default(), client_side); /// @@ -738,7 +746,7 @@ impl ServiceGenerator<'_> { ::tarpc::client::RequestDispatch<#request_ident, #response_ident, T> > where - T: ::tarpc::Transport<::tarpc::ClientMessage<#request_ident>, ::tarpc::Response<#response_ident>> + T: ::tarpc::Transport<::tarpc::ClientMessage<::tarpc::context::ClientContext, #request_ident>, ::tarpc::Response<#response_ident>> { let new_client = ::tarpc::client::new(config, transport); ::tarpc::client::NewClient { diff --git a/tarpc/examples/compression.rs b/tarpc/examples/compression.rs index 663236731..c8c13d1db 100644 --- a/tarpc/examples/compression.rs +++ b/tarpc/examples/compression.rs @@ -9,12 +9,8 @@ use futures::{Sink, SinkExt, Stream, StreamExt, TryStreamExt, prelude::*}; use serde::{Deserialize, Serialize}; use serde_bytes::ByteBuf; use std::{io, io::Read, io::Write}; -use tarpc::{ - client, context, - serde_transport::tcp, - server::{BaseChannel, Channel}, - tokio_serde::formats::Bincode, -}; +use tarpc::{client, context, serde_transport::tcp, server::{BaseChannel, Channel}, tokio_serde::formats::Bincode, ClientMessage}; +use tarpc::context::{ClientContext, ServerContext, SharedContext}; /// Type of compression that should be enabled on the request. The transport is free to ignore this. #[derive(Debug, PartialEq, Eq, Clone, Copy, Deserialize, Serialize)] @@ -120,17 +116,22 @@ async fn spawn(fut: impl Future + Send + 'static) { #[tokio::main] async fn main() -> anyhow::Result<()> { let mut incoming = tcp::listen("localhost:0", Bincode::default).await?; + let addr = incoming.local_addr(); tokio::spawn(async move { let transport = incoming.next().await.unwrap().unwrap(); - BaseChannel::with_defaults(add_compression(transport)) + let transport = add_compression(transport); + let transport = transport.map_ok(|msg: ClientMessage| msg.map_context(|ctx| ServerContext::new(ctx))); + BaseChannel::with_defaults(transport) .execute(HelloServer.serve()) .for_each(spawn) .await; }); let transport = tcp::connect(addr, Bincode::default).await?; - let client = WorldClient::new(client::Config::default(), add_compression(transport)).spawn(); + let transport = add_compression(transport); + let transport = transport.with(|msg: ClientMessage| future::ok(msg.map_context(|ctx| ctx.shared_context))); + let client = WorldClient::new(client::Config::default(), transport).spawn(); println!( "{}", diff --git a/tarpc/examples/custom_transport.rs b/tarpc/examples/custom_transport.rs index 1c682173d..6abf78a58 100644 --- a/tarpc/examples/custom_transport.rs +++ b/tarpc/examples/custom_transport.rs @@ -5,8 +5,8 @@ // https://opensource.org/licenses/MIT. use futures::prelude::*; -use tarpc::context::{ClientContext, ServerContext}; -use tarpc::serde_transport as transport; +use tarpc::context::{ClientContext, ServerContext, SharedContext}; +use tarpc::{serde_transport as transport, ClientMessage}; use tarpc::server::{BaseChannel, Channel}; use tarpc::tokio_serde::formats::Bincode; use tarpc::tokio_util::codec::length_delimited::LengthDelimitedCodec; @@ -23,7 +23,6 @@ struct Service; impl PingService for Service { async fn ping(self, _: &mut ServerContext) {} } - #[tokio::main] async fn main() -> anyhow::Result<()> { let bind_addr = "/tmp/tarpc_on_unix_example.sock"; @@ -40,6 +39,7 @@ async fn main() -> anyhow::Result<()> { let (conn, _addr) = listener.accept().await.unwrap(); let framed = codec_builder.new_framed(conn); let transport = transport::new(framed, Bincode::default()); + let transport = transport.map_ok(|c: ClientMessage| c.map_context(ServerContext::new)); let fut = BaseChannel::with_defaults(transport) .execute(Service.serve()) @@ -50,6 +50,7 @@ async fn main() -> anyhow::Result<()> { let conn = UnixStream::connect(bind_addr).await?; let transport = transport::new(codec_builder.new_framed(conn), Bincode::default()); + let transport = transport.with(|msg: ClientMessage| future::ok(msg.map_context(|ctx| ctx.shared_context))); PingServiceClient::new(Default::default(), transport) .spawn() .ping(&mut ClientContext::current()) diff --git a/tarpc/examples/pubsub.rs b/tarpc/examples/pubsub.rs index 83c1371b9..bf95a2e15 100644 --- a/tarpc/examples/pubsub.rs +++ b/tarpc/examples/pubsub.rs @@ -48,15 +48,11 @@ use std::{ sync::{Arc, Mutex, RwLock}, }; use subscriber::Subscriber as _; -use tarpc::{ - client, context, - serde_transport::tcp, - server::{self, Channel}, - tokio_serde::formats::Json, -}; +use tarpc::{client, context, serde_transport::tcp, server::{self, Channel}, tokio_serde::formats::Json, ClientMessage}; use tokio::net::ToSocketAddrs; use tracing::info; use tracing_subscriber::prelude::*; +use tarpc::context::{ServerContext, SharedContext}; pub mod subscriber { #[tarpc::service] @@ -104,6 +100,7 @@ impl Subscriber { ) -> anyhow::Result { let publisher = tcp::connect(publisher_addr, Json::default).await?; let local_addr = publisher.local_addr()?; + let publisher = publisher.map_ok(|msg: ClientMessage| msg.map_context(|ctx| ServerContext::new(ctx))); let mut handler = server::BaseChannel::with_defaults(publisher).requests(); let subscriber = Subscriber { local_addr, topics }; // The first request is for the topics being subscribed to. @@ -164,6 +161,8 @@ impl Publisher { let publisher = connecting_publishers.next().await.unwrap().unwrap(); info!(publisher.peer_addr = ?publisher.peer_addr(), "publisher connected."); + let publisher = publisher.map_ok(|msg: ClientMessage| msg.map_context(|ctx| ServerContext::new(ctx))); + server::BaseChannel::with_defaults(publisher) .execute(self.serve()) .for_each(spawn) @@ -183,6 +182,7 @@ impl Publisher { tokio::spawn(async move { while let Some(conn) = connecting_subscribers.next().await { let subscriber_addr = conn.peer_addr().unwrap(); + let conn = conn.with(|msg: tarpc::ClientMessage| future::ok(msg.map_context(|ctx| ctx.shared_context))); let tarpc::client::NewClient { client: subscriber, @@ -341,7 +341,7 @@ async fn main() -> anyhow::Result<()> { let publisher = publisher::PublisherClient::new( client::Config::default(), - tcp::connect(addrs.publisher, Json::default).await?, + tcp::connect(addrs.publisher, Json::default).await?.with(|msg: tarpc::ClientMessage| future::ok(msg.map_context(|ctx| ctx.shared_context))) ) .spawn(); diff --git a/tarpc/examples/readme.rs b/tarpc/examples/readme.rs index 60daf4e45..884e298f3 100644 --- a/tarpc/examples/readme.rs +++ b/tarpc/examples/readme.rs @@ -5,10 +5,8 @@ // https://opensource.org/licenses/MIT. use futures::prelude::*; -use tarpc::{ - client, context, - server::{self, Channel}, -}; +use tarpc::{client, context, server::{self, Channel}, transport, ClientMessage}; +use tarpc::context::{ClientContext, ServerContext, SharedContext}; /// This is the service definition. It looks a lot like a trait definition. /// It defines one RPC, hello, which takes one arg, name, and returns a String. @@ -34,7 +32,10 @@ async fn spawn(fut: impl Future + Send + 'static) { #[tokio::main] async fn main() -> anyhow::Result<()> { - let (client_transport, server_transport) = tarpc::transport::channel::unbounded(); + let (client_transport, server_transport) = transport::channel::unbounded_mapped( + |msg: ClientMessage| msg.map_context(|ctx| ctx.shared_context), + |msg: ClientMessage| msg.map_context(ServerContext::new), + ); let server = server::BaseChannel::with_defaults(server_transport); tokio::spawn(server.execute(HelloServer.serve()).for_each(spawn)); diff --git a/tarpc/examples/tls_over_tcp.rs b/tarpc/examples/tls_over_tcp.rs index cc3c1690b..e7307b98d 100644 --- a/tarpc/examples/tls_over_tcp.rs +++ b/tarpc/examples/tls_over_tcp.rs @@ -17,8 +17,7 @@ use tokio_rustls::rustls::{ server::{WebPkiClientVerifier, danger::ClientCertVerifier}, }; use tokio_rustls::{TlsAcceptor, TlsConnector}; - -use tarpc::context::{ClientContext, ServerContext}; +use tarpc::context::{ClientContext, ServerContext, SharedContext}; use tarpc::serde_transport as transport; use tarpc::server::{BaseChannel, Channel}; use tarpc::tokio_serde::formats::Bincode; @@ -115,6 +114,7 @@ async fn main() -> anyhow::Result<()> { let framed = codec_builder.new_framed(tls_stream); let transport = transport::new(framed, Bincode::default()); + let transport = transport.map_ok(|c: tarpc::ClientMessage| c.map_context(|ctx| ServerContext::new(ctx))); let fut = BaseChannel::with_defaults(transport) .execute(Service.serve()) @@ -144,6 +144,7 @@ async fn main() -> anyhow::Result<()> { let stream = connector.connect(domain, stream).await?; let transport = transport::new(codec_builder.new_framed(stream), Bincode::default()); + let transport = transport.with(|msg: tarpc::ClientMessage| future::ok(msg.map_context(|ctx| ctx.shared_context))); let answer = PingServiceClient::new(Default::default(), transport) .spawn() .ping(&mut ClientContext::current()) diff --git a/tarpc/examples/tracing.rs b/tarpc/examples/tracing.rs index be1b539c1..66a92738d 100644 --- a/tarpc/examples/tracing.rs +++ b/tarpc/examples/tracing.rs @@ -35,6 +35,7 @@ use tarpc::{ }; use tokio::net::TcpStream; use tracing_subscriber::prelude::*; +use tarpc::context::{ClientContext, ServerContext, SharedContext}; pub mod add { #[tarpc::service] @@ -124,7 +125,7 @@ where } fn make_stub( - backends: [impl Transport>, Response> + Send + Sync + 'static; N], + backends: [impl Transport>, Response> + Send + Sync + 'static; N], ) -> retry::Retry< impl Fn(&Result, u32) -> bool + Clone, load_balance::RoundRobin, Resp>>, @@ -173,23 +174,28 @@ async fn main() -> anyhow::Result<()> { .serving(AddServer.serve()); let add_server = add_listener1 .chain(add_listener2) + .map(|t| t.map_ok(|msg: ClientMessage| msg.map_context(|ctx| ServerContext::new(ctx)))) .map(BaseChannel::with_defaults); tokio::spawn(spawn_incoming(add_server.execute(server))); + let map_context = |msg: ClientMessage| future::ok(msg.map_context(|ctx| ctx.shared_context)); + let add_client = add::AddClient::from(make_stub([ - tarpc::serde_transport::tcp::connect(addr1, Json::default).await?, - tarpc::serde_transport::tcp::connect(addr2, Json::default).await?, + tarpc::serde_transport::tcp::connect(addr1, Json::default).await?.with(map_context), + tarpc::serde_transport::tcp::connect(addr2, Json::default).await?.with(map_context), ])); let double_listener = tarpc::serde_transport::tcp::listen("localhost:0", Json::default) .await? - .filter_map(|r| future::ready(r.ok())); - let addr = double_listener.get_ref().local_addr(); + .filter_map(|r| future::ready(r.ok())) + .map(|t| t.map_ok(|msg: ClientMessage| msg.map_context(|ctx| ServerContext::new(ctx)))); + let addr = double_listener.get_ref().get_ref().local_addr(); let double_server = double_listener.map(BaseChannel::with_defaults).take(1); let server = DoubleServer { add_client }.serve(); tokio::spawn(spawn_incoming(double_server.execute(server))); let to_double_server = tarpc::serde_transport::tcp::connect(addr, Json::default).await?; + let to_double_server = to_double_server.with(|msg: ClientMessage| future::ok(msg.map_context(|ctx| ctx.shared_context))); let double_client = double::DoubleClient::new(client::Config::default(), to_double_server).spawn(); diff --git a/tarpc/src/client.rs b/tarpc/src/client.rs index 8d3b9f4a7..f2cf73e24 100644 --- a/tarpc/src/client.rs +++ b/tarpc/src/client.rs @@ -31,6 +31,7 @@ use std::{ }; use tokio::sync::{mpsc, oneshot}; use tracing::Span; +use crate::context::ClientContext; /// Settings that control the behavior of the client. #[derive(Clone, Debug)] @@ -128,7 +129,7 @@ where otel.kind = "client", otel.name = %request.name()) )] - pub async fn call(&self, ctx: &mut context::SharedContext, request: Req) -> Result { + pub async fn call(&self, ctx: &mut context::ClientContext, request: Req) -> Result { let span = Span::current(); ctx.trace_context = trace::Context::try_from(&span).unwrap_or_else(|_| { tracing::trace!( @@ -153,7 +154,7 @@ where }; self.to_dispatch .send(DispatchRequest { - ctx: ctx.clone(), + ctx: ctx.shared_context.clone(), span, request_id, request, @@ -239,7 +240,7 @@ pub fn new( transport: C, ) -> NewClient, RequestDispatch> where - C: Transport, Response>, + C: Transport, Response>, { let (to_dispatch, pending_requests) = mpsc::channel(config.pending_request_buffer); let (cancellation, canceled_requests) = cancellations(); @@ -287,7 +288,7 @@ pub struct RequestDispatch { impl RequestDispatch where - C: Transport, Response>, + C: Transport, Response>, { fn in_flight_requests<'a>( self: &'a mut Pin<&mut Self>, @@ -308,7 +309,7 @@ where .map_err(|e| ChannelError::Ready(Arc::new(e))) } - fn start_send(self: &mut Pin<&mut Self>, message: ClientMessage) -> Result<(), C::Error> { + fn start_send(self: &mut Pin<&mut Self>, message: ClientMessage) -> Result<(), C::Error> { self.transport_pin_mut().start_send(message) } @@ -457,7 +458,7 @@ where fn poll_next_cancellation( mut self: Pin<&mut Self>, cx: &mut Context<'_>, - ) -> Poll>>> { + ) -> Poll>>> { ready!(self.ensure_writeable(cx)?); loop { @@ -510,18 +511,20 @@ where // poll_next_request only returns Ready if there is room to buffer another request. // Therefore, we can call write_request without fear of erroring due to a full // buffer. + + let trace_context = ctx.trace_context; + let deadline = ctx.deadline; + + let client_context = context::ClientContext::new(ctx); + let request = ClientMessage::Request(Request { id: request_id, message: request, - context: ctx.clone(), + context: client_context, }); - //TODO: Feels like we could avoid either saving the request context in insert_request - // or submitting the context in start_request. - let full_context = context::ClientContext::new(ctx); - self.in_flight_requests() - .insert_request(request_id, full_context, span.clone(), response_completion) + .insert_request(request_id, trace_context, deadline, span.clone(), response_completion) .expect("Request IDs should be unique"); match self.start_send(request) { Ok(()) => tracing::debug!("SendRequest"), @@ -543,14 +546,14 @@ where self: &mut Pin<&mut Self>, cx: &mut Context<'_>, ) -> Poll>>> { - let (context, span, request_id) = match ready!(self.as_mut().poll_next_cancellation(cx)?) { + let (trace_context, span, request_id) = match ready!(self.as_mut().poll_next_cancellation(cx)?) { Some(triple) => triple, None => return Poll::Ready(None), }; let _entered = span.enter(); let cancel = ClientMessage::Cancel { - trace_context: context.trace_context, + trace_context, request_id, }; self.start_send(cancel) @@ -640,7 +643,7 @@ where impl Future for RequestDispatch where - C: Transport, Response>, + C: Transport, Response>, { type Output = Result<(), ChannelError>; @@ -710,13 +713,15 @@ mod tests { #[tokio::test] async fn response_completes_request_future() { - let (mut dispatch, mut _channel, mut server_channel) = set_up(); + let (mut dispatch, _channel, mut server_channel) = set_up(); let cx = &mut Context::from_waker(noop_waker_ref()); let (tx, mut rx) = oneshot::channel(); + let context = ClientContext::current(); + dispatch .in_flight_requests - .insert_request(0, ClientContext::current(), Span::current(), tx) + .insert_request(0, context.trace_context, context.deadline, Span::current(), tx) .unwrap(); server_channel .send(Response { @@ -1052,12 +1057,12 @@ mod tests { RequestDispatch< String, String, - UnboundedChannel, ClientMessage>, + UnboundedChannel, ClientMessage>, >, >, >, Channel, - UnboundedChannel, Response>, + UnboundedChannel, Response>, ) { let _ = tracing_subscriber::fmt().with_test_writer().try_init(); @@ -1135,7 +1140,7 @@ mod tests { } async fn send_response( - channel: &mut UnboundedChannel, Response>, + channel: &mut UnboundedChannel, Response>, response: Response, ) { channel.send(response).await.unwrap(); diff --git a/tarpc/src/client/in_flight_requests.rs b/tarpc/src/client/in_flight_requests.rs index a368a5a48..0ffb50c63 100644 --- a/tarpc/src/client/in_flight_requests.rs +++ b/tarpc/src/client/in_flight_requests.rs @@ -1,15 +1,13 @@ -use crate::{ - context, - util::{Compact, TimeUntil}, -}; +use crate::{trace, util::{Compact, TimeUntil}}; use fnv::FnvHashMap; use std::{ collections::hash_map, task::{Context, Poll}, }; +use std::time::Instant; use tokio::sync::oneshot; use tokio_util::time::delay_queue::{self, DelayQueue}; -use tracing::Span; +use tracing::{Span}; /// Requests already written to the wire that haven't yet received responses. #[derive(Debug)] @@ -29,7 +27,7 @@ impl Default for InFlightRequests { #[derive(Debug)] struct RequestData { - ctx: context::ClientContext, + ctx: trace::Context, span: Span, response_completion: oneshot::Sender, /// The key to remove the timer for the request's deadline. @@ -56,13 +54,14 @@ impl InFlightRequests { pub fn insert_request( &mut self, request_id: u64, - ctx: context::ClientContext, + ctx: trace::Context, + deadline: Instant, span: Span, response_completion: oneshot::Sender, ) -> Result<(), AlreadyExistsError> { match self.request_data.entry(request_id) { hash_map::Entry::Vacant(vacant) => { - let timeout = ctx.deadline.time_until(); + let timeout = deadline.time_until(); let deadline_key = self.deadlines.insert(request_id, timeout); vacant.insert(RequestData { ctx, @@ -106,7 +105,7 @@ impl InFlightRequests { /// Cancels a request without completing (typically used when a request handle was dropped /// before the request completed). - pub fn cancel_request(&mut self, request_id: u64) -> Option<(context::ClientContext, Span)> { + pub fn cancel_request(&mut self, request_id: u64) -> Option<(trace::Context, Span)> { if let Some(request_data) = self.request_data.remove(&request_id) { self.request_data.compact(0.1); self.deadlines.remove(&request_data.deadline_key); diff --git a/tarpc/src/context.rs b/tarpc/src/context.rs index a96c49095..e72ab130f 100644 --- a/tarpc/src/context.rs +++ b/tarpc/src/context.rs @@ -10,10 +10,7 @@ use crate::trace::{self, TraceId}; use opentelemetry::trace::TraceContextExt; use static_assertions::assert_impl_all; -use std::{ - convert::TryFrom, - time::{Duration, Instant}, -}; +use std::{convert::TryFrom, time::{Duration, Instant}}; use std::ops::{Deref, DerefMut}; use tracing_opentelemetry::OpenTelemetrySpanExt; diff --git a/tarpc/src/lib.rs b/tarpc/src/lib.rs index a83efae02..c097372bc 100644 --- a/tarpc/src/lib.rs +++ b/tarpc/src/lib.rs @@ -142,7 +142,10 @@ //! # prelude::*, //! # }; //! # use tarpc::{ +//! # ClientMessage, //! # client, context, +//! # context::{ClientContext, ServerContext, SharedContext}, +//! # transport::channel, //! # server::{self, Channel}, //! # }; //! # // This is the service definition. It looks a lot like a trait definition. @@ -167,7 +170,10 @@ //! # #[cfg(feature = "tokio1")] //! #[tokio::main] //! async fn main() -> anyhow::Result<()> { -//! let (client_transport, server_transport) = tarpc::transport::channel::unbounded(); +//! let (client_transport, server_transport) = channel::unbounded_mapped( +//! |msg: ClientMessage| msg.map_context(|ctx| ctx.shared_context), +//! |msg: ClientMessage| msg.map_context(ServerContext::new), +//! ); //! //! let server = server::BaseChannel::with_defaults(server_transport); //! tokio::spawn( @@ -198,7 +204,7 @@ //! Use `cargo doc` as you normally would to see the documentation created for all //! items expanded by a `service!` invocation. -#![deny(missing_docs)] +#![deny(missing_docs, warnings, unused, dead_code)] #![allow(clippy::type_complexity)] #![cfg_attr(docsrs, feature(doc_cfg))] @@ -252,16 +258,18 @@ pub(crate) mod util; pub use crate::transport::sealed::Transport; use std::{any::Any, error::Error, io, sync::Arc, time::Instant}; +use std::ops::Deref; +use crate::context::{SharedContext}; /// A message from a client to a server. #[derive(Debug)] #[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))] #[non_exhaustive] -pub enum ClientMessage { +pub enum ClientMessage { /// A request initiated by a user. The server responds to a request by invoking a /// service-provided request handler. The handler completes with a [`response`](Response), which /// the server sends back to the client. - Request(Request), + Request(Request), /// A command to cancel an in-flight request, automatically sent by the client when a response /// future is dropped. /// @@ -279,16 +287,35 @@ pub enum ClientMessage { }, } +impl ClientMessage { + /// Creates a new ClientMessage by mapping the context using the provided function. + pub fn map_context(self, f: F) -> ClientMessage where F: FnOnce(Ctx) -> Ctx2 { + match self { + ClientMessage::Request(Request { context, id, message }) => { + ClientMessage::Request(Request { + context: f(context), + id, + message, + }) + } + ClientMessage::Cancel { trace_context, request_id } => { + ClientMessage::Cancel { trace_context, request_id } + } + } + } +} + + /// A request from a client to a server. #[derive(Debug)] #[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))] -pub struct Request { +pub struct Request { /// Trace context, deadline, and other cross-cutting concerns. - pub context: context::SharedContext, + pub context: Ctx, /// Uniquely identifies the request across all requests sent over a single channel. pub id: u64, /// The request body. - pub message: T, + pub message: Req, } /// Implemented by the request types generated by tarpc::service. @@ -491,7 +518,7 @@ impl ServerError { } } -impl Request { +impl Request where Ctx: Deref { /// Returns the deadline for this request. pub fn deadline(&self) -> &Instant { &self.context.deadline diff --git a/tarpc/src/server.rs b/tarpc/src/server.rs index 3b01d207d..34efc1be6 100644 --- a/tarpc/src/server.rs +++ b/tarpc/src/server.rs @@ -60,7 +60,7 @@ impl Config { /// Returns a channel backed by `transport` and configured with `self`. pub fn channel(self, transport: T) -> BaseChannel where - T: Transport, ClientMessage>, + T: Transport, ClientMessage>, { BaseChannel::new(self, transport) } @@ -154,7 +154,7 @@ pub struct BaseChannel { impl BaseChannel where - T: Transport, ClientMessage>, + T: Transport, ClientMessage>, { /// Creates a new channel backed by `transport` and configured with `config`. pub fn new(config: Config, transport: T) -> Self { @@ -200,7 +200,7 @@ where fn start_request( mut self: Pin<&mut Self>, - mut request: Request, + mut request: Request, ) -> Result, AlreadyExistsError> { let span = info_span!( "RPC", @@ -256,7 +256,7 @@ impl fmt::Debug for BaseChannel { #[derive(Debug)] pub struct TrackedRequest { /// The request sent by the client. - pub request: Request, + pub request: Request, /// A registration to abort a future when the [`Channel`] that produced this request stops /// tracking it. pub abort_registration: AbortRegistration, @@ -341,7 +341,9 @@ where /// /// ```rust /// use tarpc::{ + /// ClientMessage, /// context, + /// context::{ClientContext, SharedContext, ServerContext}, /// client::{self, NewClient}, /// server::{self, BaseChannel, Channel, serve}, /// transport, @@ -350,7 +352,10 @@ where /// /// #[tokio::main] /// async fn main() { - /// let (tx, rx) = transport::channel::unbounded(); + /// let (tx, rx) = transport::channel::unbounded_mapped( + /// |msg: ClientMessage| msg.map_context(|ctx| ctx.shared_context), + /// |msg: ClientMessage| msg.map_context(ServerContext::new), + /// ); /// let server = BaseChannel::new(server::Config::default(), rx); /// let NewClient { client, dispatch } = client::new(client::Config::default(), tx); /// tokio::spawn(dispatch); @@ -385,7 +390,7 @@ where /// # Example /// /// ```rust - /// use tarpc::{context, client, server::{self, BaseChannel, Channel, serve}, transport}; + /// use tarpc::{ClientMessage, context, client, server::{self, BaseChannel, Channel, serve}, transport, context::{ClientContext, SharedContext, ServerContext}}; /// use futures::prelude::*; /// use tracing_subscriber::prelude::*; /// @@ -394,7 +399,10 @@ where /// # #[cfg(feature = "tokio1")] /// #[tokio::main] /// async fn main() { - /// let (tx, rx) = transport::channel::unbounded(); + /// let (tx, rx) = transport::channel::unbounded_mapped( + /// |msg: ClientMessage| msg.map_context(|ctx| ctx.shared_context), + /// |msg: ClientMessage| msg.map_context(ServerContext::new), + /// ); /// let client = client::new(client::Config::default(), tx).spawn(); /// let channel = BaseChannel::with_defaults(rx); /// tokio::spawn( @@ -420,7 +428,7 @@ where impl Stream for BaseChannel where - T: Transport, ClientMessage>, + T: Transport, ClientMessage>, { type Item = Result, ChannelError>; @@ -527,7 +535,7 @@ where impl Sink> for BaseChannel where - T: Transport, ClientMessage>, + T: Transport, ClientMessage>, T::Error: Error, { type Error = ChannelError; @@ -580,7 +588,7 @@ impl AsRef for BaseChannel { impl Channel for BaseChannel where - T: Transport, ClientMessage>, + T: Transport, ClientMessage>, { type Req = Req; type Resp = Resp; @@ -736,7 +744,8 @@ where /// # Example /// /// ```rust - /// use tarpc::{context, client, server::{self, BaseChannel, Channel, serve}, transport}; + /// use tarpc::{context, client, server::{self, BaseChannel, Channel, serve}, transport, ClientMessage}; + /// use tarpc::context::{ClientContext, SharedContext, ServerContext}; /// use futures::prelude::*; /// /// # #[cfg(not(feature = "tokio1"))] @@ -744,7 +753,11 @@ where /// # #[cfg(feature = "tokio1")] /// #[tokio::main] /// async fn main() { - /// let (tx, rx) = transport::channel::unbounded(); + /// let (tx, rx) = transport::channel::unbounded_mapped( + /// |msg: ClientMessage| msg.map_context(|ctx| ctx.shared_context), + /// |msg: ClientMessage| msg.map_context(ServerContext::new), + /// ); + /// /// let requests = BaseChannel::new(server::Config::default(), rx).requests(); /// let client = client::new(client::Config::default(), tx).spawn(); /// tokio::spawn( @@ -807,7 +820,7 @@ impl Drop for ResponseGuard { /// be sent to the Channel to clean up associated request state. #[derive(Debug)] pub struct InFlightRequest { - request: Request, + request: Request, abort_registration: AbortRegistration, response_guard: ResponseGuard, span: Span, @@ -816,7 +829,7 @@ pub struct InFlightRequest { impl InFlightRequest { /// Returns a reference to the request. - pub fn get(&self) -> &Request { + pub fn get(&self) -> &Request { &self.request } @@ -839,7 +852,9 @@ impl InFlightRequest { /// /// ```rust /// use tarpc::{ + /// ClientMessage, /// context, + /// context::{ClientContext, SharedContext, ServerContext}, /// client::{self, NewClient}, /// server::{self, BaseChannel, Channel, serve}, /// transport, @@ -848,7 +863,10 @@ impl InFlightRequest { /// /// #[tokio::main] /// async fn main() { - /// let (tx, rx) = transport::channel::unbounded(); + /// let (tx, rx) = transport::channel::unbounded_mapped( + /// |msg: ClientMessage| msg.map_context(|ctx| ctx.shared_context), + /// |msg: ClientMessage| msg.map_context(ServerContext::new), + /// ); /// let server = BaseChannel::new(server::Config::default(), rx); /// let NewClient { client, dispatch } = client::new(client::Config::default(), tx); /// tokio::spawn(dispatch); @@ -876,7 +894,7 @@ impl InFlightRequest { span, request: Request { - context, + mut context, message, id: request_id, }, @@ -885,7 +903,7 @@ impl InFlightRequest { let mut full_context = context::ServerContext::new(context); let _ = Abortable::new( async move { - let message = serve.serve(&mut full_context, message).await; + let message = serve.serve(&mut context, message).await; tracing::debug!("CompleteRequest"); let response = Response { request_id, @@ -979,11 +997,11 @@ mod tests { task::Poll, time::{Duration, Instant}, }; - use tracing_subscriber::filter::FilterExt; + use crate::context::ServerContext; fn test_channel() -> ( - Pin, Response>>>>, - UnboundedChannel, ClientMessage>, + Pin, Response>>>>, + UnboundedChannel, ClientMessage>, ) { let (tx, rx) = crate::transport::channel::unbounded(); (Box::pin(BaseChannel::new(Config::default(), rx)), tx) @@ -993,11 +1011,11 @@ mod tests { Pin< Box< Requests< - BaseChannel, Response>>, + BaseChannel, Response>>, >, >, >, - UnboundedChannel, ClientMessage>, + UnboundedChannel, ClientMessage>, ) { let (tx, rx) = crate::transport::channel::unbounded(); ( @@ -1012,11 +1030,11 @@ mod tests { Pin< Box< Requests< - BaseChannel, Response>>, + BaseChannel, Response>>, >, >, >, - channel::Channel, ClientMessage>, + channel::Channel, ClientMessage>, ) { let (tx, rx) = crate::transport::channel::bounded(capacity); // Add 1 because capacity 0 is not supported (but is supported by transport::channel::bounded). @@ -1026,9 +1044,9 @@ mod tests { (Box::pin(BaseChannel::new(config, rx).requests()), tx) } - fn fake_request(req: Req) -> ClientMessage { + fn fake_request(req: Req) -> ClientMessage { ClientMessage::Request(Request { - context: context::SharedContext::current(), + context: context::ServerContext::current(), id: 0, message: req, }) @@ -1131,14 +1149,14 @@ mod tests { .as_mut() .start_request(Request { id: 0, - context: context::SharedContext::current(), + context: context::ServerContext::current(), message: (), }) .unwrap(); assert_matches!( channel.as_mut().start_request(Request { id: 0, - context: context::SharedContext::current(), + context: context::ServerContext::current(), message: () }), Err(AlreadyExistsError) @@ -1154,7 +1172,7 @@ mod tests { .as_mut() .start_request(Request { id: 0, - context: context::SharedContext::current(), + context: context::ServerContext::current(), message: (), }) .unwrap(); @@ -1162,7 +1180,7 @@ mod tests { .as_mut() .start_request(Request { id: 1, - context: context::SharedContext::current(), + context: context::ServerContext::current(), message: (), }) .unwrap(); @@ -1185,7 +1203,7 @@ mod tests { .as_mut() .start_request(Request { id: 0, - context: context::SharedContext::current(), + context: context::ServerContext::current(), message: (), }) .unwrap(); @@ -1214,7 +1232,7 @@ mod tests { .as_mut() .start_request(Request { id: 0, - context: context::SharedContext::current(), + context: context::ServerContext::current(), message: (), }) .unwrap(); @@ -1256,7 +1274,7 @@ mod tests { .as_mut() .start_request(Request { id: 0, - context: context::SharedContext::current(), + context: context::ServerContext::current(), message: (), }) .unwrap(); @@ -1279,7 +1297,7 @@ mod tests { .as_mut() .start_request(Request { id: 0, - context: context::SharedContext::current(), + context: context::ServerContext::current(), message: (), }) .unwrap(); @@ -1344,7 +1362,7 @@ mod tests { .channel_pin_mut() .start_request(Request { id: 0, - context: context::SharedContext::current(), + context: context::ServerContext::current(), message: (), }) .unwrap(); @@ -1374,7 +1392,7 @@ mod tests { .channel_pin_mut() .start_request(Request { id: 1, - context: context::SharedContext::current(), + context: context::ServerContext::current(), message: (), }) .unwrap(); @@ -1395,7 +1413,7 @@ mod tests { .channel_pin_mut() .start_request(Request { id: 0, - context: context::SharedContext::current(), + context: context::ServerContext::current(), message: (), }) .unwrap(); @@ -1414,7 +1432,7 @@ mod tests { .channel_pin_mut() .start_request(Request { id: 1, - context: context::SharedContext::current(), + context: context::ServerContext::current(), message: (), }) .unwrap(); diff --git a/tarpc/src/server/incoming.rs b/tarpc/src/server/incoming.rs index cb01021f5..ad91f0c19 100644 --- a/tarpc/src/server/incoming.rs +++ b/tarpc/src/server/incoming.rs @@ -48,7 +48,9 @@ where /// # Example /// ```rust /// use tarpc::{ +/// ClientMessage, /// context, +/// context::{ClientContext, ServerContext, SharedContext}, /// client::{self, NewClient}, /// server::{self, BaseChannel, Channel, incoming::{Incoming, spawn_incoming}, serve}, /// transport, @@ -57,7 +59,10 @@ where /// /// #[tokio::main] /// async fn main() { -/// let (tx, rx) = transport::channel::unbounded(); +/// let (tx, rx) = transport::channel::unbounded_mapped( +/// |msg: ClientMessage| msg.map_context(|ctx| ctx.shared_context), +/// |msg: ClientMessage| msg.map_context(ServerContext::new), +/// ); /// let NewClient { client, dispatch } = client::new(client::Config::default(), tx); /// tokio::spawn(dispatch); /// diff --git a/tarpc/src/server/testing.rs b/tarpc/src/server/testing.rs index 70c4e7f69..ac2201933 100644 --- a/tarpc/src/server/testing.rs +++ b/tarpc/src/server/testing.rs @@ -13,7 +13,7 @@ use crate::{ use futures::{Sink, Stream, task::*}; use pin_project::pin_project; use std::{collections::VecDeque, io, pin::Pin, time::Instant}; -use tracing::Span; +use tracing::{Span}; #[pin_project] pub(crate) struct FakeChannel { @@ -92,10 +92,10 @@ impl FakeChannel>, Response> { let (request_cancellation, _) = cancellations(); self.stream.push_back(Ok(TrackedRequest { request: Request { - context: context::SharedContext { + context: context::ServerContext::new(context::SharedContext { deadline: Instant::now(), trace_context: Default::default(), - }, + }), id, message, }, diff --git a/tarpc/src/transport/channel.rs b/tarpc/src/transport/channel.rs index 5cb897569..a319ef046 100644 --- a/tarpc/src/transport/channel.rs +++ b/tarpc/src/transport/channel.rs @@ -6,10 +6,11 @@ //! Transports backed by in-memory channels. -use futures::{Sink, Stream, task::*}; +use futures::{Sink, Stream, task::*, SinkExt, TryStreamExt}; use pin_project::pin_project; -use std::{error::Error, pin::Pin}; +use std::{error::Error, future, pin::Pin}; use tokio::sync::mpsc; +use crate::Transport; /// Errors that occur in the sending or receiving of messages over a channel. #[derive(thiserror::Error, Debug)] @@ -39,6 +40,23 @@ pub fn unbounded() -> ( ) } +/// Returns two mapped unbounded channel peers. Each [`Stream`] yields items sent through the other's +/// [`Sink`]. +pub fn unbounded_mapped(mut f: F, mut g: G) -> ( + impl Transport, + impl Transport, +) where + F: FnMut(ClientSinkItem) -> SerializedSinkItem + Send + 'static, + G: FnMut(SerializedSinkItem) -> ServerSinkItem + Send + 'static, +{ + let (client, server) = unbounded(); + + let client = client.with(move |msg: ClientSinkItem| future::ready(Ok(f(msg)))); + let server = server.map_ok(move |msg: SerializedSinkItem| g(msg)); + + (client, server) +} + /// A bi-directional channel backed by an [`UnboundedSender`](mpsc::UnboundedSender) /// and [`UnboundedReceiver`](mpsc::UnboundedReceiver). #[derive(Debug)] @@ -161,20 +179,15 @@ impl Sink for Channel { #[cfg(all(test, feature = "tokio1"))] mod tests { - use crate::{ - ServerError, - client::{self, RpcError}, - context, - server::{BaseChannel, incoming::Incoming, serve}, - transport::{ - self, - channel::{Channel, UnboundedChannel}, - }, - }; + use crate::{ServerError, client::{self, RpcError}, context, server::{BaseChannel, incoming::Incoming, serve}, transport::{ + self, + channel::{Channel, UnboundedChannel}, + }, ClientMessage}; use assert_matches::assert_matches; use futures::{prelude::*, stream}; use std::io; use tracing::trace; + use crate::context::{ClientContext, ServerContext, SharedContext}; #[test] fn ensure_is_transport() { @@ -187,7 +200,11 @@ mod tests { async fn integration() -> anyhow::Result<()> { let _ = tracing_subscriber::fmt::try_init(); - let (client_channel, server_channel) = transport::channel::unbounded(); + let (client_channel, server_channel) = transport::channel::unbounded_mapped( + |msg: ClientMessage| msg.map_context(|ctx| ctx.shared_context), + |msg: ClientMessage| msg.map_context(ServerContext::new), + ); + tokio::spawn( stream::once(future::ready(server_channel)) .map(BaseChannel::with_defaults) diff --git a/tarpc/tests/dataservice.rs b/tarpc/tests/dataservice.rs index e4cbf338d..73f6656d9 100644 --- a/tarpc/tests/dataservice.rs +++ b/tarpc/tests/dataservice.rs @@ -1,10 +1,11 @@ use futures::prelude::*; -use tarpc::serde_transport; +use tarpc::{serde_transport, ClientMessage}; use tarpc::{ client, context, server::{BaseChannel, incoming::Incoming}, }; use tokio_serde::formats::Json; +use tarpc::context::{ClientContext, ServerContext, SharedContext}; #[tarpc::derive_serde] #[derive(Debug, PartialEq, Eq)] @@ -43,13 +44,15 @@ async fn test_call() -> anyhow::Result<()> { transport .take(1) .filter_map(|r| async { r.ok() }) + .map(|t| t.map_ok(|msg: ClientMessage| msg.map_context(|ctx| ServerContext::new(ctx)))) .map(BaseChannel::with_defaults) .execute(ColorServer.serve()) .map(|channel| channel.for_each(spawn)) .for_each(spawn), ); - let transport = serde_transport::tcp::connect(addr, Json::default).await?; + let transport = serde_transport::tcp::connect(addr, Json::default).await?.with(|msg: ClientMessage| future::ok(msg.map_context(|ctx| ctx.shared_context))); + let client = ColorProtocolClient::new(client::Config::default(), transport).spawn(); let color = client diff --git a/tarpc/tests/service_functional.rs b/tarpc/tests/service_functional.rs index 46ce7bd47..30e4c0743 100644 --- a/tarpc/tests/service_functional.rs +++ b/tarpc/tests/service_functional.rs @@ -4,13 +4,9 @@ use futures::{ prelude::*, }; use std::time::{Duration, Instant}; -use tarpc::{ - client::{self}, - context, - server::{BaseChannel, Channel, incoming::Incoming}, - transport::channel, -}; +use tarpc::{client::{self}, context, server::{BaseChannel, Channel, incoming::Incoming}, transport, transport::channel, ClientMessage}; use tokio::join; +use tarpc::context::{ClientContext, ServerContext, SharedContext}; #[tarpc_plugins::service] trait Service { @@ -33,7 +29,11 @@ impl Service for Server { #[tokio::test] async fn sequential() { - let (tx, rx) = tarpc::transport::channel::unbounded(); + let (tx, rx) = transport::channel::unbounded_mapped( + |msg: ClientMessage| msg.map_context(|ctx| ctx.shared_context), + |msg: ClientMessage| msg.map_context(ServerContext::new), + ); + let client = client::new(client::Config::default(), tx).spawn(); let channel = BaseChannel::with_defaults(rx); tokio::spawn( @@ -66,7 +66,11 @@ async fn dropped_channel_aborts_in_flight_requests() -> anyhow::Result<()> { let _ = tracing_subscriber::fmt::try_init(); - let (tx, rx) = channel::unbounded(); + let (tx, rx) = transport::channel::unbounded_mapped( + |msg: ClientMessage| msg.map_context(|ctx| ctx.shared_context), + |msg: ClientMessage| msg.map_context(ServerContext::new), + ); + // Set up a client that initiates a long-lived request. // The request will complete in error when the server drops the connection. @@ -105,6 +109,7 @@ async fn serde_tcp() -> anyhow::Result<()> { transport .take(1) .filter_map(|r| async { r.ok() }) + .map(|t| t.map_ok(|msg: tarpc::ClientMessage| msg.map_context(|ctx| tarpc::context::ServerContext::new(ctx)))) .map(BaseChannel::with_defaults) .execute(Server.serve()) .map(|channel| channel.for_each(spawn)) @@ -112,6 +117,7 @@ async fn serde_tcp() -> anyhow::Result<()> { ); let transport = serde_transport::tcp::connect(addr, Json::default).await?; + let transport = transport.with(|msg: tarpc::ClientMessage| future::ok(msg.map_context(|ctx| ctx.shared_context))); let client = ServiceClient::new(client::Config::default(), transport).spawn(); assert_matches!(client.add(&mut context::ClientContext::current(), 1, 2).await, Ok(3)); @@ -137,6 +143,7 @@ async fn serde_uds() -> anyhow::Result<()> { transport .take(1) .filter_map(|r| async { r.ok() }) + .map(|t| t.map_ok(|msg: tarpc::ClientMessage| msg.map_context(|ctx| tarpc::context::ServerContext::new(ctx)))) .map(BaseChannel::with_defaults) .execute(Server.serve()) .map(|channel| channel.for_each(spawn)) @@ -144,6 +151,7 @@ async fn serde_uds() -> anyhow::Result<()> { ); let transport = serde_transport::unix::connect(&sock, Json::default).await?; + let transport = transport.with(|msg: tarpc::ClientMessage| future::ok(msg.map_context(|ctx| ctx.shared_context))); let client = ServiceClient::new(client::Config::default(), transport).spawn(); // Save results using socket so we can clean the socket even if our test assertions fail @@ -160,7 +168,11 @@ async fn serde_uds() -> anyhow::Result<()> { async fn concurrent() -> anyhow::Result<()> { let _ = tracing_subscriber::fmt::try_init(); - let (tx, rx) = channel::unbounded(); + let (tx, rx) = transport::channel::unbounded_mapped( + |msg: ClientMessage| msg.map_context(|ctx| ctx.shared_context), + |msg: ClientMessage| msg.map_context(ServerContext::new), + ); + tokio::spawn( stream::once(ready(rx)) .map(BaseChannel::with_defaults) @@ -169,6 +181,7 @@ async fn concurrent() -> anyhow::Result<()> { .for_each(spawn), ); + let client = ServiceClient::new(client::Config::default(), tx).spawn(); let mut context = context::ClientContext::current(); @@ -189,7 +202,11 @@ async fn concurrent() -> anyhow::Result<()> { async fn concurrent_join() -> anyhow::Result<()> { let _ = tracing_subscriber::fmt::try_init(); - let (tx, rx) = channel::unbounded(); + let (tx, rx) = transport::channel::unbounded_mapped( + |msg: ClientMessage| msg.map_context(|ctx| ctx.shared_context), + |msg: ClientMessage| msg.map_context(ServerContext::new), + ); + tokio::spawn( stream::once(ready(rx)) .map(BaseChannel::with_defaults) @@ -225,7 +242,11 @@ async fn spawn(fut: impl Future + Send + 'static) { async fn concurrent_join_all() -> anyhow::Result<()> { let _ = tracing_subscriber::fmt::try_init(); - let (tx, rx) = channel::unbounded(); + let (tx, rx) = transport::channel::unbounded_mapped( + |msg: ClientMessage| msg.map_context(|ctx| ctx.shared_context), + |msg: ClientMessage| msg.map_context(ServerContext::new), + ); + tokio::spawn( BaseChannel::with_defaults(rx) .execute(Server.serve()) @@ -263,14 +284,18 @@ async fn counter() -> anyhow::Result<()> { } } - let (tx, rx) = channel::unbounded(); - tokio::spawn(async { + let (tx, rx) = channel::unbounded_mapped( + |msg: ClientMessage| msg.map_context(|ctx| ctx.shared_context), + |msg: ClientMessage| msg.map_context(ServerContext::new), + ); + + tokio::task::spawn(async move { let mut requests = BaseChannel::with_defaults(rx).requests(); let mut counter = CountService(0); while let Some(Ok(request)) = requests.next().await { request.execute(counter.serve()).await; - } + }; }); let client = CounterClient::new(client::Config::default(), tx).spawn(); From 97d0a37bafa5788531820b653c679930420a7092 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C3=81kos=20Vandra-Meyer?= Date: Sun, 23 Nov 2025 21:07:22 +0100 Subject: [PATCH 06/23] run cargo fmt --- example-service/src/client.rs | 10 ++- example-service/src/server.rs | 14 +++- plugins/tests/service.rs | 7 +- tarpc/examples/compression.rs | 19 ++++- tarpc/examples/custom_transport.rs | 9 ++- tarpc/examples/pubsub.rs | 59 ++++++++++---- tarpc/examples/readme.rs | 10 ++- tarpc/examples/tls_over_tcp.rs | 18 +++-- tarpc/examples/tracing.rs | 40 +++++++--- tarpc/src/client.rs | 46 ++++++++--- tarpc/src/client/in_flight_requests.rs | 9 ++- tarpc/src/client/stub.rs | 25 ++++-- tarpc/src/client/stub/load_balance.rs | 12 ++- tarpc/src/client/stub/mock.rs | 6 +- tarpc/src/context.rs | 14 ++-- tarpc/src/lib.rs | 41 ++++++---- tarpc/src/server.rs | 72 ++++++++++++++---- tarpc/src/server/request_hook/after.rs | 12 ++- tarpc/src/server/request_hook/before.rs | 18 ++++- .../server/request_hook/before_and_after.rs | 6 +- tarpc/src/server/testing.rs | 2 +- tarpc/src/transport/channel.rs | 36 ++++++--- tarpc/tests/dataservice.rs | 16 +++- tarpc/tests/service_functional.rs | 76 +++++++++++++++---- 24 files changed, 432 insertions(+), 145 deletions(-) diff --git a/example-service/src/client.rs b/example-service/src/client.rs index 2984ae49c..e425c9eb2 100644 --- a/example-service/src/client.rs +++ b/example-service/src/client.rs @@ -5,13 +5,13 @@ // https://opensource.org/licenses/MIT. use clap::Parser; +use futures::{SinkExt, future}; use service::{WorldClient, init_tracing}; use std::{net::SocketAddr, time::Duration}; -use futures::{future, SinkExt}; +use tarpc::context::ClientContext; use tarpc::{client, tokio_serde::formats::Json}; use tokio::time::sleep; use tracing::Instrument; -use tarpc::context::ClientContext; #[derive(Parser)] struct Flags { @@ -31,7 +31,11 @@ async fn main() -> anyhow::Result<()> { let mut transport = tarpc::serde_transport::tcp::connect(flags.server_addr, Json::default); transport.config_mut().max_frame_length(usize::MAX); - let transport = transport.await?.with(|msg: tarpc::ClientMessage| future::ok(msg.map_context(|ctx| ctx.shared_context))); + let transport = transport + .await? + .with(|msg: tarpc::ClientMessage| { + future::ok(msg.map_context(|ctx| ctx.shared_context)) + }); // WorldClient is generated by the service attribute. It has a constructor `new` that takes a // config and any Transport as input. diff --git a/example-service/src/server.rs b/example-service/src/server.rs index 00b3eb1fb..7e29da291 100644 --- a/example-service/src/server.rs +++ b/example-service/src/server.rs @@ -15,9 +15,13 @@ use std::{ net::{IpAddr, Ipv6Addr, SocketAddr}, time::Duration, }; -use tarpc::{context, server::{self, Channel, incoming::Incoming}, tokio_serde::formats::Json, ClientMessage}; -use tokio::time; use tarpc::context::{ServerContext, SharedContext}; +use tarpc::{ + ClientMessage, context, + server::{self, Channel, incoming::Incoming}, + tokio_serde::formats::Json, +}; +use tokio::time; #[derive(Parser)] struct Flags { @@ -59,7 +63,11 @@ async fn main() -> anyhow::Result<()> { listener // Ignore accept errors. .filter_map(|r| future::ready(r.ok())) - .map(|t| t.map_ok(|msg: ClientMessage| msg.map_context(|ctx| ServerContext::new(ctx)))) + .map(|t| { + t.map_ok(|msg: ClientMessage| { + msg.map_context(|ctx| ServerContext::new(ctx)) + }) + }) .map(server::BaseChannel::with_defaults) // Limit channels to 1 per IP. .max_channels_per_key(1, |t| t.transport().get_ref().peer_addr().unwrap().ip()) diff --git a/plugins/tests/service.rs b/plugins/tests/service.rs index b03f3470f..756766621 100644 --- a/plugins/tests/service.rs +++ b/plugins/tests/service.rs @@ -12,7 +12,12 @@ fn att_service_trait() { } impl Foo for () { - async fn two_part(self, _: &mut context::ServerContext, s: String, i: i32) -> (String, i32) { + async fn two_part( + self, + _: &mut context::ServerContext, + s: String, + i: i32, + ) -> (String, i32) { (s, i) } diff --git a/tarpc/examples/compression.rs b/tarpc/examples/compression.rs index c8c13d1db..46300999e 100644 --- a/tarpc/examples/compression.rs +++ b/tarpc/examples/compression.rs @@ -9,8 +9,13 @@ use futures::{Sink, SinkExt, Stream, StreamExt, TryStreamExt, prelude::*}; use serde::{Deserialize, Serialize}; use serde_bytes::ByteBuf; use std::{io, io::Read, io::Write}; -use tarpc::{client, context, serde_transport::tcp, server::{BaseChannel, Channel}, tokio_serde::formats::Bincode, ClientMessage}; use tarpc::context::{ClientContext, ServerContext, SharedContext}; +use tarpc::{ + ClientMessage, client, context, + serde_transport::tcp, + server::{BaseChannel, Channel}, + tokio_serde::formats::Bincode, +}; /// Type of compression that should be enabled on the request. The transport is free to ignore this. #[derive(Debug, PartialEq, Eq, Clone, Copy, Deserialize, Serialize)] @@ -121,7 +126,9 @@ async fn main() -> anyhow::Result<()> { tokio::spawn(async move { let transport = incoming.next().await.unwrap().unwrap(); let transport = add_compression(transport); - let transport = transport.map_ok(|msg: ClientMessage| msg.map_context(|ctx| ServerContext::new(ctx))); + let transport = transport.map_ok(|msg: ClientMessage| { + msg.map_context(|ctx| ServerContext::new(ctx)) + }); BaseChannel::with_defaults(transport) .execute(HelloServer.serve()) .for_each(spawn) @@ -130,12 +137,16 @@ async fn main() -> anyhow::Result<()> { let transport = tcp::connect(addr, Bincode::default).await?; let transport = add_compression(transport); - let transport = transport.with(|msg: ClientMessage| future::ok(msg.map_context(|ctx| ctx.shared_context))); + let transport = transport.with(|msg: ClientMessage| { + future::ok(msg.map_context(|ctx| ctx.shared_context)) + }); let client = WorldClient::new(client::Config::default(), transport).spawn(); println!( "{}", - client.hello(&mut context::ClientContext::current(), "friend".into()).await? + client + .hello(&mut context::ClientContext::current(), "friend".into()) + .await? ); Ok(()) } diff --git a/tarpc/examples/custom_transport.rs b/tarpc/examples/custom_transport.rs index 6abf78a58..a1b1e4410 100644 --- a/tarpc/examples/custom_transport.rs +++ b/tarpc/examples/custom_transport.rs @@ -6,10 +6,10 @@ use futures::prelude::*; use tarpc::context::{ClientContext, ServerContext, SharedContext}; -use tarpc::{serde_transport as transport, ClientMessage}; use tarpc::server::{BaseChannel, Channel}; use tarpc::tokio_serde::formats::Bincode; use tarpc::tokio_util::codec::length_delimited::LengthDelimitedCodec; +use tarpc::{ClientMessage, serde_transport as transport}; use tokio::net::{UnixListener, UnixStream}; #[tarpc::service] @@ -39,7 +39,8 @@ async fn main() -> anyhow::Result<()> { let (conn, _addr) = listener.accept().await.unwrap(); let framed = codec_builder.new_framed(conn); let transport = transport::new(framed, Bincode::default()); - let transport = transport.map_ok(|c: ClientMessage| c.map_context(ServerContext::new)); + let transport = transport + .map_ok(|c: ClientMessage| c.map_context(ServerContext::new)); let fut = BaseChannel::with_defaults(transport) .execute(Service.serve()) @@ -50,7 +51,9 @@ async fn main() -> anyhow::Result<()> { let conn = UnixStream::connect(bind_addr).await?; let transport = transport::new(codec_builder.new_framed(conn), Bincode::default()); - let transport = transport.with(|msg: ClientMessage| future::ok(msg.map_context(|ctx| ctx.shared_context))); + let transport = transport.with(|msg: ClientMessage| { + future::ok(msg.map_context(|ctx| ctx.shared_context)) + }); PingServiceClient::new(Default::default(), transport) .spawn() .ping(&mut ClientContext::current()) diff --git a/tarpc/examples/pubsub.rs b/tarpc/examples/pubsub.rs index bf95a2e15..16195ef3f 100644 --- a/tarpc/examples/pubsub.rs +++ b/tarpc/examples/pubsub.rs @@ -48,11 +48,16 @@ use std::{ sync::{Arc, Mutex, RwLock}, }; use subscriber::Subscriber as _; -use tarpc::{client, context, serde_transport::tcp, server::{self, Channel}, tokio_serde::formats::Json, ClientMessage}; +use tarpc::context::{ClientContext, ServerContext, SharedContext}; +use tarpc::{ + ClientMessage, client, context, + serde_transport::tcp, + server::{self, Channel}, + tokio_serde::formats::Json, +}; use tokio::net::ToSocketAddrs; use tracing::info; use tracing_subscriber::prelude::*; -use tarpc::context::{ServerContext, SharedContext}; pub mod subscriber { #[tarpc::service] @@ -100,7 +105,9 @@ impl Subscriber { ) -> anyhow::Result { let publisher = tcp::connect(publisher_addr, Json::default).await?; let local_addr = publisher.local_addr()?; - let publisher = publisher.map_ok(|msg: ClientMessage| msg.map_context(|ctx| ServerContext::new(ctx))); + let publisher = publisher.map_ok(|msg: ClientMessage| { + msg.map_context(|ctx| ServerContext::new(ctx)) + }); let mut handler = server::BaseChannel::with_defaults(publisher).requests(); let subscriber = Subscriber { local_addr, topics }; // The first request is for the topics being subscribed to. @@ -161,7 +168,9 @@ impl Publisher { let publisher = connecting_publishers.next().await.unwrap().unwrap(); info!(publisher.peer_addr = ?publisher.peer_addr(), "publisher connected."); - let publisher = publisher.map_ok(|msg: ClientMessage| msg.map_context(|ctx| ServerContext::new(ctx))); + let publisher = publisher.map_ok(|msg: ClientMessage| { + msg.map_context(|ctx| ServerContext::new(ctx)) + }); server::BaseChannel::with_defaults(publisher) .execute(self.serve()) @@ -182,7 +191,11 @@ impl Publisher { tokio::spawn(async move { while let Some(conn) = connecting_subscribers.next().await { let subscriber_addr = conn.peer_addr().unwrap(); - let conn = conn.with(|msg: tarpc::ClientMessage| future::ok(msg.map_context(|ctx| ctx.shared_context))); + let conn = conn.with( + |msg: tarpc::ClientMessage| { + future::ok(msg.map_context(|ctx| ctx.shared_context)) + }, + ); let tarpc::client::NewClient { client: subscriber, @@ -210,7 +223,10 @@ impl Publisher { subscriber: subscriber::SubscriberClient, ) { // Populate the topics - if let Ok(topics) = subscriber.topics(&mut context::ClientContext::current()).await { + if let Ok(topics) = subscriber + .topics(&mut context::ClientContext::current()) + .await + { self.clients.lock().unwrap().insert( subscriber_addr, Subscription { @@ -271,10 +287,15 @@ impl publisher::Publisher for Publisher { }; let mut publications = Vec::new(); - for client in subscribers.values_mut() { publications.push(async { - client.receive(&mut context::ClientContext::current(), topic.clone(), message.clone()).await + client + .receive( + &mut context::ClientContext::current(), + topic.clone(), + message.clone(), + ) + .await }); } // Ignore failing subscribers. In a real pubsub, you'd want to continually retry until @@ -341,31 +362,43 @@ async fn main() -> anyhow::Result<()> { let publisher = publisher::PublisherClient::new( client::Config::default(), - tcp::connect(addrs.publisher, Json::default).await?.with(|msg: tarpc::ClientMessage| future::ok(msg.map_context(|ctx| ctx.shared_context))) + tcp::connect(addrs.publisher, Json::default).await?.with( + |msg: tarpc::ClientMessage| { + future::ok(msg.map_context(|ctx| ctx.shared_context)) + }, + ), ) .spawn(); publisher - .publish(&mut context::current(), "calculus".into(), "sqrt(2)".into()) + .publish( + &mut ClientContext::current(), + "calculus".into(), + "sqrt(2)".into(), + ) .await?; publisher .publish( - &mut context::current(), + &mut ClientContext::current(), "cool shorts".into(), "hello to all".into(), ) .await?; publisher - .publish(&mut context::current(), "history".into(), "napoleon".to_string()) + .publish( + &mut ClientContext::current(), + "history".into(), + "napoleon".to_string(), + ) .await?; drop(_subscriber0); publisher .publish( - &mut context::current(), + &mut ClientContext::current(), "cool shorts".into(), "hello to who?".into(), ) diff --git a/tarpc/examples/readme.rs b/tarpc/examples/readme.rs index 884e298f3..44ae497ff 100644 --- a/tarpc/examples/readme.rs +++ b/tarpc/examples/readme.rs @@ -5,8 +5,12 @@ // https://opensource.org/licenses/MIT. use futures::prelude::*; -use tarpc::{client, context, server::{self, Channel}, transport, ClientMessage}; use tarpc::context::{ClientContext, ServerContext, SharedContext}; +use tarpc::{ + ClientMessage, client, context, + server::{self, Channel}, + transport, +}; /// This is the service definition. It looks a lot like a trait definition. /// It defines one RPC, hello, which takes one arg, name, and returns a String. @@ -47,7 +51,9 @@ async fn main() -> anyhow::Result<()> { // The client has an RPC method for each RPC defined in the annotated trait. It takes the same // args as defined, with the addition of a Context, which is always the first arg. The Context // specifies a deadline and trace information which can be helpful in debugging requests. - let hello = client.hello(&mut context::ClientContext::current(), "Stim".to_string()).await?; + let hello = client + .hello(&mut context::ClientContext::current(), "Stim".to_string()) + .await?; println!("{hello}"); diff --git a/tarpc/examples/tls_over_tcp.rs b/tarpc/examples/tls_over_tcp.rs index e7307b98d..bac4a8048 100644 --- a/tarpc/examples/tls_over_tcp.rs +++ b/tarpc/examples/tls_over_tcp.rs @@ -10,6 +10,11 @@ use std::io::{self, BufReader, Cursor}; use std::net::{IpAddr, Ipv4Addr}; use std::sync::Arc; +use tarpc::context::{ClientContext, ServerContext, SharedContext}; +use tarpc::serde_transport as transport; +use tarpc::server::{BaseChannel, Channel}; +use tarpc::tokio_serde::formats::Bincode; +use tarpc::tokio_util::codec::length_delimited::LengthDelimitedCodec; use tokio::net::TcpListener; use tokio::net::TcpStream; use tokio_rustls::rustls::{ @@ -17,11 +22,6 @@ use tokio_rustls::rustls::{ server::{WebPkiClientVerifier, danger::ClientCertVerifier}, }; use tokio_rustls::{TlsAcceptor, TlsConnector}; -use tarpc::context::{ClientContext, ServerContext, SharedContext}; -use tarpc::serde_transport as transport; -use tarpc::server::{BaseChannel, Channel}; -use tarpc::tokio_serde::formats::Bincode; -use tarpc::tokio_util::codec::length_delimited::LengthDelimitedCodec; #[tarpc::service] pub trait PingService { @@ -114,7 +114,9 @@ async fn main() -> anyhow::Result<()> { let framed = codec_builder.new_framed(tls_stream); let transport = transport::new(framed, Bincode::default()); - let transport = transport.map_ok(|c: tarpc::ClientMessage| c.map_context(|ctx| ServerContext::new(ctx))); + let transport = transport.map_ok(|c: tarpc::ClientMessage| { + c.map_context(|ctx| ServerContext::new(ctx)) + }); let fut = BaseChannel::with_defaults(transport) .execute(Service.serve()) @@ -144,7 +146,9 @@ async fn main() -> anyhow::Result<()> { let stream = connector.connect(domain, stream).await?; let transport = transport::new(codec_builder.new_framed(stream), Bincode::default()); - let transport = transport.with(|msg: tarpc::ClientMessage| future::ok(msg.map_context(|ctx| ctx.shared_context))); + let transport = transport.with(|msg: tarpc::ClientMessage| { + future::ok(msg.map_context(|ctx| ctx.shared_context)) + }); let answer = PingServiceClient::new(Default::default(), transport) .spawn() .ping(&mut ClientContext::current()) diff --git a/tarpc/examples/tracing.rs b/tarpc/examples/tracing.rs index 66a92738d..52b068bc8 100644 --- a/tarpc/examples/tracing.rs +++ b/tarpc/examples/tracing.rs @@ -19,6 +19,7 @@ use std::{ atomic::{AtomicBool, Ordering}, }, }; +use tarpc::context::{ClientContext, ServerContext, SharedContext}; use tarpc::{ ClientMessage, RequestName, Response, ServerError, Transport, client::{ @@ -35,7 +36,6 @@ use tarpc::{ }; use tokio::net::TcpStream; use tracing_subscriber::prelude::*; -use tarpc::context::{ClientContext, ServerContext, SharedContext}; pub mod add { #[tarpc::service] @@ -125,7 +125,8 @@ where } fn make_stub( - backends: [impl Transport>, Response> + Send + Sync + 'static; N], + backends: [impl Transport>, Response> + Send + Sync + 'static; + N], ) -> retry::Retry< impl Fn(&Result, u32) -> bool + Clone, load_balance::RoundRobin, Resp>>, @@ -174,33 +175,54 @@ async fn main() -> anyhow::Result<()> { .serving(AddServer.serve()); let add_server = add_listener1 .chain(add_listener2) - .map(|t| t.map_ok(|msg: ClientMessage| msg.map_context(|ctx| ServerContext::new(ctx)))) + .map(|t| { + t.map_ok(|msg: ClientMessage| { + msg.map_context(|ctx| ServerContext::new(ctx)) + }) + }) .map(BaseChannel::with_defaults); tokio::spawn(spawn_incoming(add_server.execute(server))); - let map_context = |msg: ClientMessage| future::ok(msg.map_context(|ctx| ctx.shared_context)); + let map_context = |msg: ClientMessage| { + future::ok(msg.map_context(|ctx| ctx.shared_context)) + }; let add_client = add::AddClient::from(make_stub([ - tarpc::serde_transport::tcp::connect(addr1, Json::default).await?.with(map_context), - tarpc::serde_transport::tcp::connect(addr2, Json::default).await?.with(map_context), + tarpc::serde_transport::tcp::connect(addr1, Json::default) + .await? + .with(map_context), + tarpc::serde_transport::tcp::connect(addr2, Json::default) + .await? + .with(map_context), ])); let double_listener = tarpc::serde_transport::tcp::listen("localhost:0", Json::default) .await? .filter_map(|r| future::ready(r.ok())) - .map(|t| t.map_ok(|msg: ClientMessage| msg.map_context(|ctx| ServerContext::new(ctx)))); + .map(|t| { + t.map_ok(|msg: ClientMessage| { + msg.map_context(|ctx| ServerContext::new(ctx)) + }) + }); let addr = double_listener.get_ref().get_ref().local_addr(); let double_server = double_listener.map(BaseChannel::with_defaults).take(1); let server = DoubleServer { add_client }.serve(); tokio::spawn(spawn_incoming(double_server.execute(server))); let to_double_server = tarpc::serde_transport::tcp::connect(addr, Json::default).await?; - let to_double_server = to_double_server.with(|msg: ClientMessage| future::ok(msg.map_context(|ctx| ctx.shared_context))); + let to_double_server = to_double_server.with(|msg: ClientMessage| { + future::ok(msg.map_context(|ctx| ctx.shared_context)) + }); let double_client = double::DoubleClient::new(client::Config::default(), to_double_server).spawn(); for _ in 1..=5 { - tracing::info!("{:?}", double_client.double(&mut context::ClientContext::current(), 1).await?); + tracing::info!( + "{:?}", + double_client + .double(&mut context::ClientContext::current(), 1) + .await? + ); } tracer_provider.shutdown()?; diff --git a/tarpc/src/client.rs b/tarpc/src/client.rs index f2cf73e24..ebcb69db1 100644 --- a/tarpc/src/client.rs +++ b/tarpc/src/client.rs @@ -9,6 +9,7 @@ mod in_flight_requests; pub mod stub; +use crate::context::ClientContext; use crate::{ ChannelError, ClientMessage, Request, RequestName, Response, ServerError, Transport, cancellations::{CanceledRequests, RequestCancellation, cancellations}, @@ -31,7 +32,6 @@ use std::{ }; use tokio::sync::{mpsc, oneshot}; use tracing::Span; -use crate::context::ClientContext; /// Settings that control the behavior of the client. #[derive(Clone, Debug)] @@ -129,7 +129,11 @@ where otel.kind = "client", otel.name = %request.name()) )] - pub async fn call(&self, ctx: &mut context::ClientContext, request: Req) -> Result { + pub async fn call( + &self, + ctx: &mut context::ClientContext, + request: Req, + ) -> Result { let span = Span::current(); ctx.trace_context = trace::Context::try_from(&span).unwrap_or_else(|_| { tracing::trace!( @@ -309,7 +313,10 @@ where .map_err(|e| ChannelError::Ready(Arc::new(e))) } - fn start_send(self: &mut Pin<&mut Self>, message: ClientMessage) -> Result<(), C::Error> { + fn start_send( + self: &mut Pin<&mut Self>, + message: ClientMessage, + ) -> Result<(), C::Error> { self.transport_pin_mut().start_send(message) } @@ -524,7 +531,13 @@ where }); self.in_flight_requests() - .insert_request(request_id, trace_context, deadline, span.clone(), response_completion) + .insert_request( + request_id, + trace_context, + deadline, + span.clone(), + response_completion, + ) .expect("Request IDs should be unique"); match self.start_send(request) { Ok(()) => tracing::debug!("SendRequest"), @@ -546,10 +559,11 @@ where self: &mut Pin<&mut Self>, cx: &mut Context<'_>, ) -> Poll>>> { - let (trace_context, span, request_id) = match ready!(self.as_mut().poll_next_cancellation(cx)?) { - Some(triple) => triple, - None => return Poll::Ready(None), - }; + let (trace_context, span, request_id) = + match ready!(self.as_mut().poll_next_cancellation(cx)?) { + Some(triple) => triple, + None => return Poll::Ready(None), + }; let _entered = span.enter(); let cancel = ClientMessage::Cancel { @@ -674,7 +688,7 @@ where /// the lifecycle of the request. #[derive(Debug)] struct DispatchRequest { - pub ctx: context::Context, + pub ctx: context::SharedContext, pub span: Span, pub request_id: u64, pub request: Req, @@ -686,10 +700,10 @@ mod tests { use super::{ Channel, DispatchRequest, RequestDispatch, ResponseGuard, RpcError, cancellations, }; + use crate::context::{ClientContext, SharedContext}; use crate::{ ChannelError, ClientMessage, Response, client::{Config, in_flight_requests::InFlightRequests}, - context::{self, current}, transport::{self, channel::UnboundedChannel}, }; use assert_matches::assert_matches; @@ -721,7 +735,13 @@ mod tests { dispatch .in_flight_requests - .insert_request(0, context.trace_context, context.deadline, Span::current(), tx) + .insert_request( + 0, + context.trace_context, + context.deadline, + Span::current(), + tx, + ) .unwrap(); server_channel .send(Response { @@ -888,7 +908,9 @@ mod tests { let (dispatch, channel, _server_channel) = set_up(); drop(dispatch); // error on send - let resp = channel.call(&mut ClientContext::current(), "hi".to_string()).await; + let resp = channel + .call(&mut ClientContext::current(), "hi".to_string()) + .await; assert_matches!(resp, Err(RpcError::Shutdown)); } diff --git a/tarpc/src/client/in_flight_requests.rs b/tarpc/src/client/in_flight_requests.rs index 0ffb50c63..7a554de27 100644 --- a/tarpc/src/client/in_flight_requests.rs +++ b/tarpc/src/client/in_flight_requests.rs @@ -1,13 +1,16 @@ -use crate::{trace, util::{Compact, TimeUntil}}; +use crate::{ + trace, + util::{Compact, TimeUntil}, +}; use fnv::FnvHashMap; +use std::time::Instant; use std::{ collections::hash_map, task::{Context, Poll}, }; -use std::time::Instant; use tokio::sync::oneshot; use tokio_util::time::delay_queue::{self, DelayQueue}; -use tracing::{Span}; +use tracing::Span; /// Requests already written to the wire that haven't yet received responses. #[derive(Debug)] diff --git a/tarpc/src/client/stub.rs b/tarpc/src/client/stub.rs index c7dc12008..b99f8e42c 100644 --- a/tarpc/src/client/stub.rs +++ b/tarpc/src/client/stub.rs @@ -24,8 +24,11 @@ pub trait Stub { type Resp; /// Calls a remote service. - async fn call(&self, ctx: &mut context::ClientContext, request: Self::Req) - -> Result; + async fn call( + &self, + ctx: &mut context::ClientContext, + request: Self::Req, + ) -> Result; } impl Stub for Channel @@ -35,7 +38,11 @@ where type Req = Req; type Resp = Resp; - async fn call(&self, ctx: &mut context::ClientContext, request: Req) -> Result { + async fn call( + &self, + ctx: &mut context::ClientContext, + request: Req, + ) -> Result { Self::call(self, ctx, request).await } } @@ -46,10 +53,18 @@ where { type Req = S::Req; type Resp = S::Resp; - async fn call(&self, ctx: &mut context::ClientContext, req: Self::Req) -> Result { + async fn call( + &self, + ctx: &mut context::ClientContext, + req: Self::Req, + ) -> Result { let mut server_ctx = context::ServerContext::new(ctx.shared_context.clone()); - let res = self.clone().serve(&mut server_ctx, req).await.map_err(RpcError::Server); + let res = self + .clone() + .serve(&mut server_ctx, req) + .await + .map_err(RpcError::Server); ctx.shared_context = server_ctx.shared_context; diff --git a/tarpc/src/client/stub/load_balance.rs b/tarpc/src/client/stub/load_balance.rs index bf70ebe2a..62c8bf677 100644 --- a/tarpc/src/client/stub/load_balance.rs +++ b/tarpc/src/client/stub/load_balance.rs @@ -200,13 +200,19 @@ mod consistent_hash { )?; for _ in 0..2 { - let resp = stub.call(&mut context::ClientContext::current(), 'a').await?; + let resp = stub + .call(&mut context::ClientContext::current(), 'a') + .await?; assert_eq!(resp, 1); - let resp = stub.call(&mut context::ClientContext::current(), 'b').await?; + let resp = stub + .call(&mut context::ClientContext::current(), 'b') + .await?; assert_eq!(resp, 2); - let resp = stub.call(&mut context::ClientContext::current(), 'c').await?; + let resp = stub + .call(&mut context::ClientContext::current(), 'c') + .await?; assert_eq!(resp, 3); } diff --git a/tarpc/src/client/stub/mock.rs b/tarpc/src/client/stub/mock.rs index 451544433..bebd8fc99 100644 --- a/tarpc/src/client/stub/mock.rs +++ b/tarpc/src/client/stub/mock.rs @@ -30,7 +30,11 @@ where type Req = Req; type Resp = Resp; - async fn call(&self, _: &mut context::ClientContext, request: Self::Req) -> Result { + async fn call( + &self, + _: &mut context::ClientContext, + request: Self::Req, + ) -> Result { self.responses .get(&request) .cloned() diff --git a/tarpc/src/context.rs b/tarpc/src/context.rs index e72ab130f..bbbc3721d 100644 --- a/tarpc/src/context.rs +++ b/tarpc/src/context.rs @@ -10,8 +10,11 @@ use crate::trace::{self, TraceId}; use opentelemetry::trace::TraceContextExt; use static_assertions::assert_impl_all; -use std::{convert::TryFrom, time::{Duration, Instant}}; use std::ops::{Deref, DerefMut}; +use std::{ + convert::TryFrom, + time::{Duration, Instant}, +}; use tracing_opentelemetry::OpenTelemetrySpanExt; /// A request context that carries request-scoped information like deadlines and trace information. @@ -51,9 +54,7 @@ pub struct ServerContext { impl ServerContext { /// Creates a new ServerContext from the given SharedContext with no extensions. pub fn new(shared_context: SharedContext) -> Self { - Self { - shared_context, - } + Self { shared_context } } /// Creates a new ServerContext for the current shared context with no extensions. @@ -85,15 +86,12 @@ impl DerefMut for ServerContext { pub struct ClientContext { /// Shared context sent from client to server which contains information used by both sides. pub shared_context: SharedContext, - } impl ClientContext { /// Creates a new ServerContext from the given SharedContext with no extensions. pub fn new(shared_context: SharedContext) -> Self { - Self { - shared_context, - } + Self { shared_context } } /// Creates a new ServerContext for the current shared context with no extensions. diff --git a/tarpc/src/lib.rs b/tarpc/src/lib.rs index c097372bc..cf3423eb5 100644 --- a/tarpc/src/lib.rs +++ b/tarpc/src/lib.rs @@ -257,9 +257,9 @@ pub(crate) mod util; pub use crate::transport::sealed::Transport; -use std::{any::Any, error::Error, io, sync::Arc, time::Instant}; +use crate::context::SharedContext; use std::ops::Deref; -use crate::context::{SharedContext}; +use std::{any::Any, error::Error, io, sync::Arc, time::Instant}; /// A message from a client to a server. #[derive(Debug)] @@ -289,23 +289,31 @@ pub enum ClientMessage { impl ClientMessage { /// Creates a new ClientMessage by mapping the context using the provided function. - pub fn map_context(self, f: F) -> ClientMessage where F: FnOnce(Ctx) -> Ctx2 { + pub fn map_context(self, f: F) -> ClientMessage + where + F: FnOnce(Ctx) -> Ctx2, + { match self { - ClientMessage::Request(Request { context, id, message }) => { - ClientMessage::Request(Request { - context: f(context), - id, - message, - }) - } - ClientMessage::Cancel { trace_context, request_id } => { - ClientMessage::Cancel { trace_context, request_id } - } + ClientMessage::Request(Request { + context, + id, + message, + }) => ClientMessage::Request(Request { + context: f(context), + id, + message, + }), + ClientMessage::Cancel { + trace_context, + request_id, + } => ClientMessage::Cancel { + trace_context, + request_id, + }, } } } - /// A request from a client to a server. #[derive(Debug)] #[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))] @@ -518,7 +526,10 @@ impl ServerError { } } -impl Request where Ctx: Deref { +impl Request +where + Ctx: Deref, +{ /// Returns the deadline for this request. pub fn deadline(&self) -> &Instant { &self.context.deadline diff --git a/tarpc/src/server.rs b/tarpc/src/server.rs index 34efc1be6..559d80ba8 100644 --- a/tarpc/src/server.rs +++ b/tarpc/src/server.rs @@ -6,6 +6,7 @@ //! Provides a server that concurrently handles many connections sending multiplexed requests. +use crate::context::ServerContext; use crate::{ ChannelError, ClientMessage, Request, RequestName, Response, ServerError, Transport, cancellations::{CanceledRequests, RequestCancellation, cancellations}, @@ -76,7 +77,11 @@ pub trait Serve { type Resp; /// Responds to a single request. - async fn serve(self, ctx: &mut context::ServerContext, req: Self::Req) -> Result; + async fn serve( + self, + ctx: &mut context::ServerContext, + req: Self::Req, + ) -> Result; } /// A Serve wrapper around a Fn. @@ -104,7 +109,10 @@ impl Copy for ServeFn where F: Copy {} /// Result>`. pub fn serve(f: F) -> ServeFn where - for<'a> F: FnOnce(&'a mut context::ServerContext, Req) -> Pin> + 'a + Send>>, + for<'a> F: FnOnce( + &'a mut context::ServerContext, + Req, + ) -> Pin> + 'a + Send>>, { ServeFn { f, @@ -115,7 +123,10 @@ where impl Serve for ServeFn where Req: RequestName, - for<'a> F: FnOnce(&'a mut context::ServerContext, Req) -> Pin> + 'a + Send>>, + for<'a> F: FnOnce( + &'a mut context::ServerContext, + Req, + ) -> Pin> + 'a + Send>>, { type Req = Req; type Resp = Resp; @@ -900,7 +911,6 @@ impl InFlightRequest { }, } = self; span.record("otel.name", message.name()); - let mut full_context = context::ServerContext::new(context); let _ = Abortable::new( async move { let message = serve.serve(&mut context, message).await; @@ -980,6 +990,7 @@ mod tests { request_hook::{AfterRequest, BeforeRequest, RequestHook}, serve, }; + use crate::context::ServerContext; use crate::{ ClientMessage, Request, Response, ServerError, context, trace, transport::channel::{self, UnboundedChannel}, @@ -997,10 +1008,17 @@ mod tests { task::Poll, time::{Duration, Instant}, }; - use crate::context::ServerContext; fn test_channel() -> ( - Pin, Response>>>>, + Pin< + Box< + BaseChannel< + Req, + Resp, + UnboundedChannel, Response>, + >, + >, + >, UnboundedChannel, ClientMessage>, ) { let (tx, rx) = crate::transport::channel::unbounded(); @@ -1011,7 +1029,11 @@ mod tests { Pin< Box< Requests< - BaseChannel, Response>>, + BaseChannel< + Req, + Resp, + UnboundedChannel, Response>, + >, >, >, >, @@ -1030,7 +1052,11 @@ mod tests { Pin< Box< Requests< - BaseChannel, Response>>, + BaseChannel< + Req, + Resp, + channel::Channel, Response>, + >, >, >, >, @@ -1061,7 +1087,10 @@ mod tests { #[tokio::test] async fn test_serve() { let serve = serve(|_, i| async move { Ok(i) }.boxed()); - assert_matches!(serve.serve(&mut context::ServerContext::current(), 7).await, Ok(7)); + assert_matches!( + serve.serve(&mut context::ServerContext::current(), 7).await, + Ok(7) + ); } #[tokio::test] @@ -1081,10 +1110,13 @@ mod tests { let some_time = Instant::now() + Duration::from_secs(37); let some_other_time = Instant::now() + Duration::from_secs(83); - let serve = serve(move |ctx: &mut context::ServerContext, i| async move { - assert_eq!(ctx.deadline, some_time); - Ok(i) - }.boxed()); + let serve = serve(move |ctx: &mut context::ServerContext, i| { + async move { + assert_eq!(ctx.deadline, some_time); + Ok(i) + } + .boxed() + }); let deadline_hook = serve.before(SetDeadline(some_time)); let mut ctx = context::ServerContext::current(); ctx.deadline = some_other_time; @@ -1117,7 +1149,11 @@ mod tests { } } impl AfterRequest for PrintLatency { - async fn after(&mut self, _: &mut context::ServerContext, _: &mut Result) { + async fn after( + &mut self, + _: &mut context::ServerContext, + _: &mut Result, + ) { tracing::debug!("Elapsed: {:?}", self.start.elapsed()); } } @@ -1136,7 +1172,9 @@ mod tests { let deadline_hook = serve.before(|_: &mut context::ServerContext, _: &i32| async { Err(ServerError::new(io::ErrorKind::Other, "oops".into())) }); - let resp: Result = deadline_hook.serve(&mut context::ServerContext::current(), 7).await; + let resp: Result = deadline_hook + .serve(&mut context::ServerContext::current(), 7) + .await; assert_matches!(resp, Err(_)); Ok(()) } @@ -1341,7 +1379,9 @@ mod tests { Poll::Ready(Some(Ok(request))) => request, result => panic!("Unexpected result: {result:?}"), }; - request.execute(serve(|_, _| async { Ok(()) }.boxed())).await; + request + .execute(serve(|_, _| async { Ok(()) }.boxed())) + .await; assert!( requests .as_mut() diff --git a/tarpc/src/server/request_hook/after.rs b/tarpc/src/server/request_hook/after.rs index d9e676ca4..64d65807f 100644 --- a/tarpc/src/server/request_hook/after.rs +++ b/tarpc/src/server/request_hook/after.rs @@ -15,7 +15,11 @@ pub trait AfterRequest { /// The function that is called after request execution. /// /// The hook can modify the request context and the response. - async fn after(&mut self, ctx: &mut context::ServerContext, resp: &mut Result); + async fn after( + &mut self, + ctx: &mut context::ServerContext, + resp: &mut Result, + ); } impl AfterRequest for F @@ -23,7 +27,11 @@ where F: FnMut(&mut context::ServerContext, &mut Result) -> Fut, Fut: Future, { - async fn after(&mut self, ctx: &mut context::ServerContext, resp: &mut Result) { + async fn after( + &mut self, + ctx: &mut context::ServerContext, + resp: &mut Result, + ) { self(ctx, resp).await } } diff --git a/tarpc/src/server/request_hook/before.rs b/tarpc/src/server/request_hook/before.rs index 4a1b2ad8a..1f647227f 100644 --- a/tarpc/src/server/request_hook/before.rs +++ b/tarpc/src/server/request_hook/before.rs @@ -19,7 +19,11 @@ pub trait BeforeRequest { /// /// This function can also modify the request context. This could be used, for example, to /// enforce a maximum deadline on all requests. - async fn before(&mut self, ctx: &mut context::ServerContext, req: &Req) -> Result<(), ServerError>; + async fn before( + &mut self, + ctx: &mut context::ServerContext, + req: &Req, + ) -> Result<(), ServerError>; } /// A list of hooks that run in order before request execution. @@ -59,7 +63,11 @@ where F: FnMut(&mut context::ServerContext, &Req) -> Fut, Fut: Future>, { - async fn before(&mut self, ctx: &mut context::ServerContext, req: &Req) -> Result<(), ServerError> { + async fn before( + &mut self, + ctx: &mut context::ServerContext, + req: &Req, + ) -> Result<(), ServerError> { self(ctx, req).await } } @@ -141,7 +149,11 @@ pub struct BeforeRequestNil; impl, Rest: BeforeRequest> BeforeRequest for BeforeRequestCons { - async fn before(&mut self, ctx: &mut context::ServerContext, req: &Req) -> Result<(), ServerError> { + async fn before( + &mut self, + ctx: &mut context::ServerContext, + req: &Req, + ) -> Result<(), ServerError> { let BeforeRequestCons(first, rest) = self; first.before(ctx, req).await?; rest.before(ctx, req).await?; diff --git a/tarpc/src/server/request_hook/before_and_after.rs b/tarpc/src/server/request_hook/before_and_after.rs index af37427af..dff0abe0b 100644 --- a/tarpc/src/server/request_hook/before_and_after.rs +++ b/tarpc/src/server/request_hook/before_and_after.rs @@ -46,7 +46,11 @@ where type Req = Req; type Resp = Resp; - async fn serve(self, ctx: &mut context::ServerContext, req: Req) -> Result { + async fn serve( + self, + ctx: &mut context::ServerContext, + req: Req, + ) -> Result { let HookThenServeThenHook { serve, mut hook, .. } = self; diff --git a/tarpc/src/server/testing.rs b/tarpc/src/server/testing.rs index ac2201933..709167751 100644 --- a/tarpc/src/server/testing.rs +++ b/tarpc/src/server/testing.rs @@ -13,7 +13,7 @@ use crate::{ use futures::{Sink, Stream, task::*}; use pin_project::pin_project; use std::{collections::VecDeque, io, pin::Pin, time::Instant}; -use tracing::{Span}; +use tracing::Span; #[pin_project] pub(crate) struct FakeChannel { diff --git a/tarpc/src/transport/channel.rs b/tarpc/src/transport/channel.rs index a319ef046..9607b5ef0 100644 --- a/tarpc/src/transport/channel.rs +++ b/tarpc/src/transport/channel.rs @@ -6,11 +6,11 @@ //! Transports backed by in-memory channels. -use futures::{Sink, Stream, task::*, SinkExt, TryStreamExt}; +use crate::Transport; +use futures::{Sink, SinkExt, Stream, TryStreamExt, task::*}; use pin_project::pin_project; use std::{error::Error, future, pin::Pin}; use tokio::sync::mpsc; -use crate::Transport; /// Errors that occur in the sending or receiving of messages over a channel. #[derive(thiserror::Error, Debug)] @@ -42,10 +42,14 @@ pub fn unbounded() -> ( /// Returns two mapped unbounded channel peers. Each [`Stream`] yields items sent through the other's /// [`Sink`]. -pub fn unbounded_mapped(mut f: F, mut g: G) -> ( +pub fn unbounded_mapped( + mut f: F, + mut g: G, +) -> ( impl Transport, impl Transport, -) where +) +where F: FnMut(ClientSinkItem) -> SerializedSinkItem + Send + 'static, G: FnMut(SerializedSinkItem) -> ServerSinkItem + Send + 'static, { @@ -179,15 +183,21 @@ impl Sink for Channel { #[cfg(all(test, feature = "tokio1"))] mod tests { - use crate::{ServerError, client::{self, RpcError}, context, server::{BaseChannel, incoming::Incoming, serve}, transport::{ - self, - channel::{Channel, UnboundedChannel}, - }, ClientMessage}; + use crate::context::{ClientContext, ServerContext, SharedContext}; + use crate::{ + ClientMessage, ServerError, + client::{self, RpcError}, + context, + server::{BaseChannel, incoming::Incoming, serve}, + transport::{ + self, + channel::{Channel, UnboundedChannel}, + }, + }; use assert_matches::assert_matches; use futures::{prelude::*, stream}; use std::io; use tracing::trace; - use crate::context::{ClientContext, ServerContext, SharedContext}; #[test] fn ensure_is_transport() { @@ -226,8 +236,12 @@ mod tests { let client = client::new(client::Config::default(), client_channel).spawn(); - let response1 = client.call(&mut context::ClientContext::current(), "123".into()).await; - let response2 = client.call(&mut context::ClientContext::current(), "abc".into()).await; + let response1 = client + .call(&mut context::ClientContext::current(), "123".into()) + .await; + let response2 = client + .call(&mut context::ClientContext::current(), "abc".into()) + .await; trace!("response1: {:?}, response2: {:?}", response1, response2); diff --git a/tarpc/tests/dataservice.rs b/tarpc/tests/dataservice.rs index 73f6656d9..0ee12183d 100644 --- a/tarpc/tests/dataservice.rs +++ b/tarpc/tests/dataservice.rs @@ -1,11 +1,11 @@ use futures::prelude::*; -use tarpc::{serde_transport, ClientMessage}; +use tarpc::context::{ClientContext, ServerContext, SharedContext}; +use tarpc::{ClientMessage, serde_transport}; use tarpc::{ client, context, server::{BaseChannel, incoming::Incoming}, }; use tokio_serde::formats::Json; -use tarpc::context::{ClientContext, ServerContext, SharedContext}; #[tarpc::derive_serde] #[derive(Debug, PartialEq, Eq)] @@ -44,14 +44,22 @@ async fn test_call() -> anyhow::Result<()> { transport .take(1) .filter_map(|r| async { r.ok() }) - .map(|t| t.map_ok(|msg: ClientMessage| msg.map_context(|ctx| ServerContext::new(ctx)))) + .map(|t| { + t.map_ok(|msg: ClientMessage| { + msg.map_context(|ctx| ServerContext::new(ctx)) + }) + }) .map(BaseChannel::with_defaults) .execute(ColorServer.serve()) .map(|channel| channel.for_each(spawn)) .for_each(spawn), ); - let transport = serde_transport::tcp::connect(addr, Json::default).await?.with(|msg: ClientMessage| future::ok(msg.map_context(|ctx| ctx.shared_context))); + let transport = serde_transport::tcp::connect(addr, Json::default) + .await? + .with(|msg: ClientMessage| { + future::ok(msg.map_context(|ctx| ctx.shared_context)) + }); let client = ColorProtocolClient::new(client::Config::default(), transport).spawn(); diff --git a/tarpc/tests/service_functional.rs b/tarpc/tests/service_functional.rs index 30e4c0743..b6ba72026 100644 --- a/tarpc/tests/service_functional.rs +++ b/tarpc/tests/service_functional.rs @@ -4,9 +4,16 @@ use futures::{ prelude::*, }; use std::time::{Duration, Instant}; -use tarpc::{client::{self}, context, server::{BaseChannel, Channel, incoming::Incoming}, transport, transport::channel, ClientMessage}; -use tokio::join; use tarpc::context::{ClientContext, ServerContext, SharedContext}; +use tarpc::{ + ClientMessage, + client::{self}, + context, + server::{BaseChannel, Channel, incoming::Incoming}, + transport, + transport::channel, +}; +use tokio::join; #[tarpc_plugins::service] trait Service { @@ -43,7 +50,13 @@ async fn sequential() { })) .for_each(|response| response), ); - assert_eq!(client.call(&mut context::ClientContext::current(), 1).await.unwrap(), 2); + assert_eq!( + client + .call(&mut context::ClientContext::current(), 1) + .await + .unwrap(), + 2 + ); } #[tokio::test] @@ -71,7 +84,6 @@ async fn dropped_channel_aborts_in_flight_requests() -> anyhow::Result<()> { |msg: ClientMessage| msg.map_context(ServerContext::new), ); - // Set up a client that initiates a long-lived request. // The request will complete in error when the server drops the connection. tokio::spawn(async move { @@ -109,7 +121,13 @@ async fn serde_tcp() -> anyhow::Result<()> { transport .take(1) .filter_map(|r| async { r.ok() }) - .map(|t| t.map_ok(|msg: tarpc::ClientMessage| msg.map_context(|ctx| tarpc::context::ServerContext::new(ctx)))) + .map(|t| { + t.map_ok( + |msg: tarpc::ClientMessage| { + msg.map_context(|ctx| tarpc::context::ServerContext::new(ctx)) + }, + ) + }) .map(BaseChannel::with_defaults) .execute(Server.serve()) .map(|channel| channel.for_each(spawn)) @@ -117,10 +135,19 @@ async fn serde_tcp() -> anyhow::Result<()> { ); let transport = serde_transport::tcp::connect(addr, Json::default).await?; - let transport = transport.with(|msg: tarpc::ClientMessage| future::ok(msg.map_context(|ctx| ctx.shared_context))); + let transport = transport.with( + |msg: tarpc::ClientMessage| { + future::ok(msg.map_context(|ctx| ctx.shared_context)) + }, + ); let client = ServiceClient::new(client::Config::default(), transport).spawn(); - assert_matches!(client.add(&mut context::ClientContext::current(), 1, 2).await, Ok(3)); + assert_matches!( + client + .add(&mut context::ClientContext::current(), 1, 2) + .await, + Ok(3) + ); assert_matches!( client.hey(&mut context::ClientContext::current(), "Tim".to_string()).await, Ok(ref s) if s == "Hey, Tim." @@ -143,7 +170,13 @@ async fn serde_uds() -> anyhow::Result<()> { transport .take(1) .filter_map(|r| async { r.ok() }) - .map(|t| t.map_ok(|msg: tarpc::ClientMessage| msg.map_context(|ctx| tarpc::context::ServerContext::new(ctx)))) + .map(|t| { + t.map_ok( + |msg: tarpc::ClientMessage| { + msg.map_context(|ctx| tarpc::context::ServerContext::new(ctx)) + }, + ) + }) .map(BaseChannel::with_defaults) .execute(Server.serve()) .map(|channel| channel.for_each(spawn)) @@ -151,12 +184,20 @@ async fn serde_uds() -> anyhow::Result<()> { ); let transport = serde_transport::unix::connect(&sock, Json::default).await?; - let transport = transport.with(|msg: tarpc::ClientMessage| future::ok(msg.map_context(|ctx| ctx.shared_context))); + let transport = transport.with( + |msg: tarpc::ClientMessage| { + future::ok(msg.map_context(|ctx| ctx.shared_context)) + }, + ); let client = ServiceClient::new(client::Config::default(), transport).spawn(); // Save results using socket so we can clean the socket even if our test assertions fail - let res1 = client.add(&mut context::ClientContext::current(), 1, 2).await; - let res2 = client.hey(&mut context::ClientContext::current(), "Tim".to_string()).await; + let res1 = client + .add(&mut context::ClientContext::current(), 1, 2) + .await; + let res2 = client + .hey(&mut context::ClientContext::current(), "Tim".to_string()) + .await; assert_matches!(res1, Ok(3)); assert_matches!(res2, Ok(ref s) if s == "Hey, Tim."); @@ -181,7 +222,6 @@ async fn concurrent() -> anyhow::Result<()> { .for_each(spawn), ); - let client = ServiceClient::new(client::Config::default(), tx).spawn(); let mut context = context::ClientContext::current(); @@ -295,12 +335,18 @@ async fn counter() -> anyhow::Result<()> { while let Some(Ok(request)) = requests.next().await { request.execute(counter.serve()).await; - }; + } }); let client = CounterClient::new(client::Config::default(), tx).spawn(); - assert_matches!(client.count(&mut context::ClientContext::current()).await, Ok(1)); - assert_matches!(client.count(&mut context::ClientContext::current()).await, Ok(2)); + assert_matches!( + client.count(&mut context::ClientContext::current()).await, + Ok(1) + ); + assert_matches!( + client.count(&mut context::ClientContext::current()).await, + Ok(2) + ); Ok(()) } From 15b84e4f14ddcafffe3a375b428a9a528a85076b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C3=81kos=20Vandra-Meyer?= Date: Sun, 23 Nov 2025 21:07:59 +0100 Subject: [PATCH 07/23] run cargo clippy --- example-service/src/server.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/example-service/src/server.rs b/example-service/src/server.rs index 7e29da291..c1cf618b9 100644 --- a/example-service/src/server.rs +++ b/example-service/src/server.rs @@ -65,7 +65,7 @@ async fn main() -> anyhow::Result<()> { .filter_map(|r| future::ready(r.ok())) .map(|t| { t.map_ok(|msg: ClientMessage| { - msg.map_context(|ctx| ServerContext::new(ctx)) + msg.map_context(ServerContext::new) }) }) .map(server::BaseChannel::with_defaults) From 117ae5713324a8b2887c80388294e2cd76df17d4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C3=81kos=20Vandra-Meyer?= Date: Mon, 24 Nov 2025 17:54:56 +0100 Subject: [PATCH 08/23] simplify api --- example-service/src/client.rs | 5 +-- example-service/src/server.rs | 7 +--- plugins/src/lib.rs | 9 +---- tarpc/examples/compression.rs | 9 ++--- tarpc/examples/custom_transport.rs | 8 ++-- tarpc/examples/pubsub.rs | 24 ++++-------- tarpc/examples/readme.rs | 6 +-- tarpc/examples/tls_over_tcp.rs | 9 ++--- tarpc/examples/tracing.rs | 23 ++++-------- tarpc/src/lib.rs | 6 +-- tarpc/src/server.rs | 21 ++--------- tarpc/src/server/incoming.rs | 5 +-- tarpc/src/transport/channel.rs | 39 ++++++++++++++----- tarpc/tests/dataservice.rs | 11 ++---- tarpc/tests/service_functional.rs | 60 ++++++------------------------ 15 files changed, 81 insertions(+), 161 deletions(-) diff --git a/example-service/src/client.rs b/example-service/src/client.rs index e425c9eb2..e2c327dfb 100644 --- a/example-service/src/client.rs +++ b/example-service/src/client.rs @@ -9,6 +9,7 @@ use futures::{SinkExt, future}; use service::{WorldClient, init_tracing}; use std::{net::SocketAddr, time::Duration}; use tarpc::context::ClientContext; +use tarpc::transport::channel::map_client_context_to_shared; use tarpc::{client, tokio_serde::formats::Json}; use tokio::time::sleep; use tracing::Instrument; @@ -33,9 +34,7 @@ async fn main() -> anyhow::Result<()> { let transport = transport .await? - .with(|msg: tarpc::ClientMessage| { - future::ok(msg.map_context(|ctx| ctx.shared_context)) - }); + .with(|msg| future::ok(map_client_context_to_shared(msg))); // WorldClient is generated by the service attribute. It has a constructor `new` that takes a // config and any Transport as input. diff --git a/example-service/src/server.rs b/example-service/src/server.rs index c1cf618b9..7871c5f86 100644 --- a/example-service/src/server.rs +++ b/example-service/src/server.rs @@ -22,6 +22,7 @@ use tarpc::{ tokio_serde::formats::Json, }; use tokio::time; +use tarpc::transport::channel::map_shared_context_to_server; #[derive(Parser)] struct Flags { @@ -63,11 +64,7 @@ async fn main() -> anyhow::Result<()> { listener // Ignore accept errors. .filter_map(|r| future::ready(r.ok())) - .map(|t| { - t.map_ok(|msg: ClientMessage| { - msg.map_context(ServerContext::new) - }) - }) + .map(|t| t.map_ok(map_shared_context_to_server)) .map(server::BaseChannel::with_defaults) // Limit channels to 1 per IP. .max_channels_per_key(1, |t| t.transport().get_ref().peer_addr().unwrap().ip()) diff --git a/plugins/src/lib.rs b/plugins/src/lib.rs index bc52cf849..21e3cd35b 100644 --- a/plugins/src/lib.rs +++ b/plugins/src/lib.rs @@ -393,14 +393,7 @@ fn collect_cfg_attrs(rpcs: &[RpcMethod]) -> Vec> { /// let resp = CalculatorResponse::Add(12); /// /// // This could be any transport. -/// let (client_side, server_side) = transport::channel::unbounded(); -/// -/// let client_side = client_side.with(|msg: tarpc::ClientMessage| async move { -/// Ok(msg.map_context(|ctx| ctx.shared_context)) -/// }); -/// let server_side = server_side.map_ok(|msg: tarpc::ClientMessage| -/// msg.map_context(tarpc::context::ServerContext::new) -/// ); +/// let (client_side, server_side) = transport::channel::unbounded_for_client_server_context(); /// /// // A client can be made like so: /// let client = CalculatorClient::new(client::Config::default(), client_side); diff --git a/tarpc/examples/compression.rs b/tarpc/examples/compression.rs index 46300999e..4e09d625e 100644 --- a/tarpc/examples/compression.rs +++ b/tarpc/examples/compression.rs @@ -10,6 +10,7 @@ use serde::{Deserialize, Serialize}; use serde_bytes::ByteBuf; use std::{io, io::Read, io::Write}; use tarpc::context::{ClientContext, ServerContext, SharedContext}; +use tarpc::transport::channel::{map_client_context_to_shared, map_shared_context_to_server}; use tarpc::{ ClientMessage, client, context, serde_transport::tcp, @@ -126,9 +127,7 @@ async fn main() -> anyhow::Result<()> { tokio::spawn(async move { let transport = incoming.next().await.unwrap().unwrap(); let transport = add_compression(transport); - let transport = transport.map_ok(|msg: ClientMessage| { - msg.map_context(|ctx| ServerContext::new(ctx)) - }); + let transport = transport.map_ok(map_shared_context_to_server); BaseChannel::with_defaults(transport) .execute(HelloServer.serve()) .for_each(spawn) @@ -137,9 +136,7 @@ async fn main() -> anyhow::Result<()> { let transport = tcp::connect(addr, Bincode::default).await?; let transport = add_compression(transport); - let transport = transport.with(|msg: ClientMessage| { - future::ok(msg.map_context(|ctx| ctx.shared_context)) - }); + let transport = transport.with(|msg| future::ok(map_client_context_to_shared(msg))); let client = WorldClient::new(client::Config::default(), transport).spawn(); println!( diff --git a/tarpc/examples/custom_transport.rs b/tarpc/examples/custom_transport.rs index a1b1e4410..350828743 100644 --- a/tarpc/examples/custom_transport.rs +++ b/tarpc/examples/custom_transport.rs @@ -9,6 +9,7 @@ use tarpc::context::{ClientContext, ServerContext, SharedContext}; use tarpc::server::{BaseChannel, Channel}; use tarpc::tokio_serde::formats::Bincode; use tarpc::tokio_util::codec::length_delimited::LengthDelimitedCodec; +use tarpc::transport::channel::{map_client_context_to_shared, map_shared_context_to_server}; use tarpc::{ClientMessage, serde_transport as transport}; use tokio::net::{UnixListener, UnixStream}; @@ -39,8 +40,7 @@ async fn main() -> anyhow::Result<()> { let (conn, _addr) = listener.accept().await.unwrap(); let framed = codec_builder.new_framed(conn); let transport = transport::new(framed, Bincode::default()); - let transport = transport - .map_ok(|c: ClientMessage| c.map_context(ServerContext::new)); + let transport = transport.map_ok(map_shared_context_to_server); let fut = BaseChannel::with_defaults(transport) .execute(Service.serve()) @@ -51,9 +51,7 @@ async fn main() -> anyhow::Result<()> { let conn = UnixStream::connect(bind_addr).await?; let transport = transport::new(codec_builder.new_framed(conn), Bincode::default()); - let transport = transport.with(|msg: ClientMessage| { - future::ok(msg.map_context(|ctx| ctx.shared_context)) - }); + let transport = transport.with(|msg| future::ok(map_client_context_to_shared(msg))); PingServiceClient::new(Default::default(), transport) .spawn() .ping(&mut ClientContext::current()) diff --git a/tarpc/examples/pubsub.rs b/tarpc/examples/pubsub.rs index 16195ef3f..6a6ce5723 100644 --- a/tarpc/examples/pubsub.rs +++ b/tarpc/examples/pubsub.rs @@ -49,6 +49,7 @@ use std::{ }; use subscriber::Subscriber as _; use tarpc::context::{ClientContext, ServerContext, SharedContext}; +use tarpc::transport::channel::{map_client_context_to_shared, map_shared_context_to_server}; use tarpc::{ ClientMessage, client, context, serde_transport::tcp, @@ -105,9 +106,7 @@ impl Subscriber { ) -> anyhow::Result { let publisher = tcp::connect(publisher_addr, Json::default).await?; let local_addr = publisher.local_addr()?; - let publisher = publisher.map_ok(|msg: ClientMessage| { - msg.map_context(|ctx| ServerContext::new(ctx)) - }); + let publisher = publisher.map_ok(map_shared_context_to_server); let mut handler = server::BaseChannel::with_defaults(publisher).requests(); let subscriber = Subscriber { local_addr, topics }; // The first request is for the topics being subscribed to. @@ -168,9 +167,7 @@ impl Publisher { let publisher = connecting_publishers.next().await.unwrap().unwrap(); info!(publisher.peer_addr = ?publisher.peer_addr(), "publisher connected."); - let publisher = publisher.map_ok(|msg: ClientMessage| { - msg.map_context(|ctx| ServerContext::new(ctx)) - }); + let publisher = publisher.map_ok(map_shared_context_to_server); server::BaseChannel::with_defaults(publisher) .execute(self.serve()) @@ -191,12 +188,7 @@ impl Publisher { tokio::spawn(async move { while let Some(conn) = connecting_subscribers.next().await { let subscriber_addr = conn.peer_addr().unwrap(); - let conn = conn.with( - |msg: tarpc::ClientMessage| { - future::ok(msg.map_context(|ctx| ctx.shared_context)) - }, - ); - + let conn = conn.with(|msg| future::ok(map_client_context_to_shared(msg))); let tarpc::client::NewClient { client: subscriber, dispatch, @@ -362,11 +354,9 @@ async fn main() -> anyhow::Result<()> { let publisher = publisher::PublisherClient::new( client::Config::default(), - tcp::connect(addrs.publisher, Json::default).await?.with( - |msg: tarpc::ClientMessage| { - future::ok(msg.map_context(|ctx| ctx.shared_context)) - }, - ), + tcp::connect(addrs.publisher, Json::default) + .await? + .with(|msg| future::ok(map_client_context_to_shared(msg))), ) .spawn(); diff --git a/tarpc/examples/readme.rs b/tarpc/examples/readme.rs index 44ae497ff..b20d4ab91 100644 --- a/tarpc/examples/readme.rs +++ b/tarpc/examples/readme.rs @@ -36,10 +36,8 @@ async fn spawn(fut: impl Future + Send + 'static) { #[tokio::main] async fn main() -> anyhow::Result<()> { - let (client_transport, server_transport) = transport::channel::unbounded_mapped( - |msg: ClientMessage| msg.map_context(|ctx| ctx.shared_context), - |msg: ClientMessage| msg.map_context(ServerContext::new), - ); + let (client_transport, server_transport) = + transport::channel::unbounded_for_client_server_context(); let server = server::BaseChannel::with_defaults(server_transport); tokio::spawn(server.execute(HelloServer.serve()).for_each(spawn)); diff --git a/tarpc/examples/tls_over_tcp.rs b/tarpc/examples/tls_over_tcp.rs index bac4a8048..56583b05c 100644 --- a/tarpc/examples/tls_over_tcp.rs +++ b/tarpc/examples/tls_over_tcp.rs @@ -15,6 +15,7 @@ use tarpc::serde_transport as transport; use tarpc::server::{BaseChannel, Channel}; use tarpc::tokio_serde::formats::Bincode; use tarpc::tokio_util::codec::length_delimited::LengthDelimitedCodec; +use tarpc::transport::channel::{map_client_context_to_shared, map_shared_context_to_server}; use tokio::net::TcpListener; use tokio::net::TcpStream; use tokio_rustls::rustls::{ @@ -114,9 +115,7 @@ async fn main() -> anyhow::Result<()> { let framed = codec_builder.new_framed(tls_stream); let transport = transport::new(framed, Bincode::default()); - let transport = transport.map_ok(|c: tarpc::ClientMessage| { - c.map_context(|ctx| ServerContext::new(ctx)) - }); + let transport = transport.map_ok(map_shared_context_to_server); let fut = BaseChannel::with_defaults(transport) .execute(Service.serve()) @@ -146,9 +145,7 @@ async fn main() -> anyhow::Result<()> { let stream = connector.connect(domain, stream).await?; let transport = transport::new(codec_builder.new_framed(stream), Bincode::default()); - let transport = transport.with(|msg: tarpc::ClientMessage| { - future::ok(msg.map_context(|ctx| ctx.shared_context)) - }); + let transport = transport.with(|msg| future::ok(map_client_context_to_shared(msg))); let answer = PingServiceClient::new(Default::default(), transport) .spawn() .ping(&mut ClientContext::current()) diff --git a/tarpc/examples/tracing.rs b/tarpc/examples/tracing.rs index 52b068bc8..acf631be8 100644 --- a/tarpc/examples/tracing.rs +++ b/tarpc/examples/tracing.rs @@ -20,6 +20,7 @@ use std::{ }, }; use tarpc::context::{ClientContext, ServerContext, SharedContext}; +use tarpc::transport::channel::{map_client_context_to_shared, map_shared_context_to_server}; use tarpc::{ ClientMessage, RequestName, Response, ServerError, Transport, client::{ @@ -175,17 +176,12 @@ async fn main() -> anyhow::Result<()> { .serving(AddServer.serve()); let add_server = add_listener1 .chain(add_listener2) - .map(|t| { - t.map_ok(|msg: ClientMessage| { - msg.map_context(|ctx| ServerContext::new(ctx)) - }) - }) + .map(|t| t.map_ok(map_shared_context_to_server)) .map(BaseChannel::with_defaults); tokio::spawn(spawn_incoming(add_server.execute(server))); - let map_context = |msg: ClientMessage| { - future::ok(msg.map_context(|ctx| ctx.shared_context)) - }; + let map_context = + |msg: ClientMessage| future::ok(map_client_context_to_shared(msg)); let add_client = add::AddClient::from(make_stub([ tarpc::serde_transport::tcp::connect(addr1, Json::default) @@ -199,20 +195,15 @@ async fn main() -> anyhow::Result<()> { let double_listener = tarpc::serde_transport::tcp::listen("localhost:0", Json::default) .await? .filter_map(|r| future::ready(r.ok())) - .map(|t| { - t.map_ok(|msg: ClientMessage| { - msg.map_context(|ctx| ServerContext::new(ctx)) - }) - }); + .map(|t| t.map_ok(map_shared_context_to_server)); let addr = double_listener.get_ref().get_ref().local_addr(); let double_server = double_listener.map(BaseChannel::with_defaults).take(1); let server = DoubleServer { add_client }.serve(); tokio::spawn(spawn_incoming(double_server.execute(server))); let to_double_server = tarpc::serde_transport::tcp::connect(addr, Json::default).await?; - let to_double_server = to_double_server.with(|msg: ClientMessage| { - future::ok(msg.map_context(|ctx| ctx.shared_context)) - }); + let to_double_server = + to_double_server.with(|msg| future::ok(map_client_context_to_shared(msg))); let double_client = double::DoubleClient::new(client::Config::default(), to_double_server).spawn(); diff --git a/tarpc/src/lib.rs b/tarpc/src/lib.rs index cf3423eb5..6f7e08fc0 100644 --- a/tarpc/src/lib.rs +++ b/tarpc/src/lib.rs @@ -170,11 +170,7 @@ //! # #[cfg(feature = "tokio1")] //! #[tokio::main] //! async fn main() -> anyhow::Result<()> { -//! let (client_transport, server_transport) = channel::unbounded_mapped( -//! |msg: ClientMessage| msg.map_context(|ctx| ctx.shared_context), -//! |msg: ClientMessage| msg.map_context(ServerContext::new), -//! ); -//! +//! let (client_transport, server_transport) = channel::unbounded_for_client_server_context(); //! let server = server::BaseChannel::with_defaults(server_transport); //! tokio::spawn( //! server.execute(HelloServer.serve()) diff --git a/tarpc/src/server.rs b/tarpc/src/server.rs index 559d80ba8..87ac89681 100644 --- a/tarpc/src/server.rs +++ b/tarpc/src/server.rs @@ -363,10 +363,7 @@ where /// /// #[tokio::main] /// async fn main() { - /// let (tx, rx) = transport::channel::unbounded_mapped( - /// |msg: ClientMessage| msg.map_context(|ctx| ctx.shared_context), - /// |msg: ClientMessage| msg.map_context(ServerContext::new), - /// ); + /// let (tx, rx) = transport::channel::unbounded_for_client_server_context(); /// let server = BaseChannel::new(server::Config::default(), rx); /// let NewClient { client, dispatch } = client::new(client::Config::default(), tx); /// tokio::spawn(dispatch); @@ -410,10 +407,7 @@ where /// # #[cfg(feature = "tokio1")] /// #[tokio::main] /// async fn main() { - /// let (tx, rx) = transport::channel::unbounded_mapped( - /// |msg: ClientMessage| msg.map_context(|ctx| ctx.shared_context), - /// |msg: ClientMessage| msg.map_context(ServerContext::new), - /// ); + /// let (tx, rx) = transport::channel::unbounded_for_client_server_context(); /// let client = client::new(client::Config::default(), tx).spawn(); /// let channel = BaseChannel::with_defaults(rx); /// tokio::spawn( @@ -764,11 +758,7 @@ where /// # #[cfg(feature = "tokio1")] /// #[tokio::main] /// async fn main() { - /// let (tx, rx) = transport::channel::unbounded_mapped( - /// |msg: ClientMessage| msg.map_context(|ctx| ctx.shared_context), - /// |msg: ClientMessage| msg.map_context(ServerContext::new), - /// ); - /// + /// let (tx, rx) = transport::channel::unbounded_for_client_server_context(); /// let requests = BaseChannel::new(server::Config::default(), rx).requests(); /// let client = client::new(client::Config::default(), tx).spawn(); /// tokio::spawn( @@ -874,10 +864,7 @@ impl InFlightRequest { /// /// #[tokio::main] /// async fn main() { - /// let (tx, rx) = transport::channel::unbounded_mapped( - /// |msg: ClientMessage| msg.map_context(|ctx| ctx.shared_context), - /// |msg: ClientMessage| msg.map_context(ServerContext::new), - /// ); + /// let (tx, rx) = transport::channel::unbounded_for_client_server_context(); /// let server = BaseChannel::new(server::Config::default(), rx); /// let NewClient { client, dispatch } = client::new(client::Config::default(), tx); /// tokio::spawn(dispatch); diff --git a/tarpc/src/server/incoming.rs b/tarpc/src/server/incoming.rs index ad91f0c19..1868cbe47 100644 --- a/tarpc/src/server/incoming.rs +++ b/tarpc/src/server/incoming.rs @@ -59,10 +59,7 @@ where /// /// #[tokio::main] /// async fn main() { -/// let (tx, rx) = transport::channel::unbounded_mapped( -/// |msg: ClientMessage| msg.map_context(|ctx| ctx.shared_context), -/// |msg: ClientMessage| msg.map_context(ServerContext::new), -/// ); +/// let (tx, rx) = transport::channel::unbounded_for_client_server_context(); /// let NewClient { client, dispatch } = client::new(client::Config::default(), tx); /// tokio::spawn(dispatch); /// diff --git a/tarpc/src/transport/channel.rs b/tarpc/src/transport/channel.rs index 9607b5ef0..65e987d02 100644 --- a/tarpc/src/transport/channel.rs +++ b/tarpc/src/transport/channel.rs @@ -6,7 +6,8 @@ //! Transports backed by in-memory channels. -use crate::Transport; +use crate::context::{ClientContext, ServerContext, SharedContext}; +use crate::{ClientMessage, Transport}; use futures::{Sink, SinkExt, Stream, TryStreamExt, task::*}; use pin_project::pin_project; use std::{error::Error, future, pin::Pin}; @@ -50,8 +51,8 @@ pub fn unbounded_mapped, ) where - F: FnMut(ClientSinkItem) -> SerializedSinkItem + Send + 'static, - G: FnMut(SerializedSinkItem) -> ServerSinkItem + Send + 'static, + F: FnMut(ClientSinkItem) -> SerializedSinkItem, + G: FnMut(SerializedSinkItem) -> ServerSinkItem, { let (client, server) = unbounded(); @@ -61,6 +62,29 @@ where (client, server) } +/// Convenience functino to return two mapped unbounded channel peers for a basechannel and a client implementation. Each [`Stream`] yields items sent through the other's +/// [`Sink`]. +pub fn unbounded_for_client_server_context() -> ( + impl Transport, Resp>, + impl Transport>, +) { + unbounded_mapped(map_client_context_to_shared, map_shared_context_to_server) +} + +/// Convenience function to map a ClientMessage with ClientContext to one with SharedContext. +pub fn map_client_context_to_shared( + msg: ClientMessage, +) -> ClientMessage { + msg.map_context(|ctx| ctx.shared_context) +} + +/// Convenience function to map a ClientMessage with SharedContext to one with ServerContext. +pub fn map_shared_context_to_server( + msg: ClientMessage, +) -> ClientMessage { + msg.map_context(ServerContext::new) +} + /// A bi-directional channel backed by an [`UnboundedSender`](mpsc::UnboundedSender) /// and [`UnboundedReceiver`](mpsc::UnboundedReceiver). #[derive(Debug)] @@ -183,9 +207,8 @@ impl Sink for Channel { #[cfg(all(test, feature = "tokio1"))] mod tests { - use crate::context::{ClientContext, ServerContext, SharedContext}; use crate::{ - ClientMessage, ServerError, + ServerError, client::{self, RpcError}, context, server::{BaseChannel, incoming::Incoming, serve}, @@ -210,10 +233,8 @@ mod tests { async fn integration() -> anyhow::Result<()> { let _ = tracing_subscriber::fmt::try_init(); - let (client_channel, server_channel) = transport::channel::unbounded_mapped( - |msg: ClientMessage| msg.map_context(|ctx| ctx.shared_context), - |msg: ClientMessage| msg.map_context(ServerContext::new), - ); + let (client_channel, server_channel) = + transport::channel::unbounded_for_client_server_context(); tokio::spawn( stream::once(future::ready(server_channel)) diff --git a/tarpc/tests/dataservice.rs b/tarpc/tests/dataservice.rs index 0ee12183d..1ac04af13 100644 --- a/tarpc/tests/dataservice.rs +++ b/tarpc/tests/dataservice.rs @@ -1,5 +1,6 @@ use futures::prelude::*; use tarpc::context::{ClientContext, ServerContext, SharedContext}; +use tarpc::transport::channel::{map_client_context_to_shared, map_shared_context_to_server}; use tarpc::{ClientMessage, serde_transport}; use tarpc::{ client, context, @@ -44,11 +45,7 @@ async fn test_call() -> anyhow::Result<()> { transport .take(1) .filter_map(|r| async { r.ok() }) - .map(|t| { - t.map_ok(|msg: ClientMessage| { - msg.map_context(|ctx| ServerContext::new(ctx)) - }) - }) + .map(|t| t.map_ok(map_shared_context_to_server)) .map(BaseChannel::with_defaults) .execute(ColorServer.serve()) .map(|channel| channel.for_each(spawn)) @@ -57,9 +54,7 @@ async fn test_call() -> anyhow::Result<()> { let transport = serde_transport::tcp::connect(addr, Json::default) .await? - .with(|msg: ClientMessage| { - future::ok(msg.map_context(|ctx| ctx.shared_context)) - }); + .with(|msg| future::ok(map_client_context_to_shared(msg))); let client = ColorProtocolClient::new(client::Config::default(), transport).spawn(); diff --git a/tarpc/tests/service_functional.rs b/tarpc/tests/service_functional.rs index b6ba72026..ebebef660 100644 --- a/tarpc/tests/service_functional.rs +++ b/tarpc/tests/service_functional.rs @@ -5,6 +5,7 @@ use futures::{ }; use std::time::{Duration, Instant}; use tarpc::context::{ClientContext, ServerContext, SharedContext}; +use tarpc::transport::channel::{map_client_context_to_shared, map_shared_context_to_server}; use tarpc::{ ClientMessage, client::{self}, @@ -36,10 +37,7 @@ impl Service for Server { #[tokio::test] async fn sequential() { - let (tx, rx) = transport::channel::unbounded_mapped( - |msg: ClientMessage| msg.map_context(|ctx| ctx.shared_context), - |msg: ClientMessage| msg.map_context(ServerContext::new), - ); + let (tx, rx) = transport::channel::unbounded_for_client_server_context(); let client = client::new(client::Config::default(), tx).spawn(); let channel = BaseChannel::with_defaults(rx); @@ -79,10 +77,7 @@ async fn dropped_channel_aborts_in_flight_requests() -> anyhow::Result<()> { let _ = tracing_subscriber::fmt::try_init(); - let (tx, rx) = transport::channel::unbounded_mapped( - |msg: ClientMessage| msg.map_context(|ctx| ctx.shared_context), - |msg: ClientMessage| msg.map_context(ServerContext::new), - ); + let (tx, rx) = transport::channel::unbounded_for_client_server_context(); // Set up a client that initiates a long-lived request. // The request will complete in error when the server drops the connection. @@ -121,13 +116,7 @@ async fn serde_tcp() -> anyhow::Result<()> { transport .take(1) .filter_map(|r| async { r.ok() }) - .map(|t| { - t.map_ok( - |msg: tarpc::ClientMessage| { - msg.map_context(|ctx| tarpc::context::ServerContext::new(ctx)) - }, - ) - }) + .map(|t| t.map_ok(map_shared_context_to_server)) .map(BaseChannel::with_defaults) .execute(Server.serve()) .map(|channel| channel.for_each(spawn)) @@ -135,11 +124,7 @@ async fn serde_tcp() -> anyhow::Result<()> { ); let transport = serde_transport::tcp::connect(addr, Json::default).await?; - let transport = transport.with( - |msg: tarpc::ClientMessage| { - future::ok(msg.map_context(|ctx| ctx.shared_context)) - }, - ); + let transport = transport.with(|msg| future::ok(map_client_context_to_shared(msg))); let client = ServiceClient::new(client::Config::default(), transport).spawn(); assert_matches!( @@ -170,13 +155,7 @@ async fn serde_uds() -> anyhow::Result<()> { transport .take(1) .filter_map(|r| async { r.ok() }) - .map(|t| { - t.map_ok( - |msg: tarpc::ClientMessage| { - msg.map_context(|ctx| tarpc::context::ServerContext::new(ctx)) - }, - ) - }) + .map(|t| t.map_ok(map_shared_context_to_server)) .map(BaseChannel::with_defaults) .execute(Server.serve()) .map(|channel| channel.for_each(spawn)) @@ -184,11 +163,8 @@ async fn serde_uds() -> anyhow::Result<()> { ); let transport = serde_transport::unix::connect(&sock, Json::default).await?; - let transport = transport.with( - |msg: tarpc::ClientMessage| { - future::ok(msg.map_context(|ctx| ctx.shared_context)) - }, - ); + let transport = transport.with(|msg| future::ok(map_client_context_to_shared(msg))); + let client = ServiceClient::new(client::Config::default(), transport).spawn(); // Save results using socket so we can clean the socket even if our test assertions fail @@ -209,10 +185,7 @@ async fn serde_uds() -> anyhow::Result<()> { async fn concurrent() -> anyhow::Result<()> { let _ = tracing_subscriber::fmt::try_init(); - let (tx, rx) = transport::channel::unbounded_mapped( - |msg: ClientMessage| msg.map_context(|ctx| ctx.shared_context), - |msg: ClientMessage| msg.map_context(ServerContext::new), - ); + let (tx, rx) = transport::channel::unbounded_for_client_server_context(); tokio::spawn( stream::once(ready(rx)) @@ -242,10 +215,7 @@ async fn concurrent() -> anyhow::Result<()> { async fn concurrent_join() -> anyhow::Result<()> { let _ = tracing_subscriber::fmt::try_init(); - let (tx, rx) = transport::channel::unbounded_mapped( - |msg: ClientMessage| msg.map_context(|ctx| ctx.shared_context), - |msg: ClientMessage| msg.map_context(ServerContext::new), - ); + let (tx, rx) = transport::channel::unbounded_for_client_server_context(); tokio::spawn( stream::once(ready(rx)) @@ -282,10 +252,7 @@ async fn spawn(fut: impl Future + Send + 'static) { async fn concurrent_join_all() -> anyhow::Result<()> { let _ = tracing_subscriber::fmt::try_init(); - let (tx, rx) = transport::channel::unbounded_mapped( - |msg: ClientMessage| msg.map_context(|ctx| ctx.shared_context), - |msg: ClientMessage| msg.map_context(ServerContext::new), - ); + let (tx, rx) = transport::channel::unbounded_for_client_server_context(); tokio::spawn( BaseChannel::with_defaults(rx) @@ -324,10 +291,7 @@ async fn counter() -> anyhow::Result<()> { } } - let (tx, rx) = channel::unbounded_mapped( - |msg: ClientMessage| msg.map_context(|ctx| ctx.shared_context), - |msg: ClientMessage| msg.map_context(ServerContext::new), - ); + let (tx, rx) = channel::unbounded_for_client_server_context(); tokio::task::spawn(async move { let mut requests = BaseChannel::with_defaults(rx).requests(); From b2eb13b72a7c80d43acc1fee40fe6eadbdec9185 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C3=81kos=20Vandra-Meyer?= Date: Tue, 25 Nov 2025 13:02:33 +0100 Subject: [PATCH 09/23] allow transport to access server context on response as well --- example-service/src/client.rs | 6 +- example-service/src/server.rs | 9 +- plugins/src/lib.rs | 2 +- tarpc/examples/compression.rs | 9 +- tarpc/examples/custom_transport.rs | 8 +- tarpc/examples/pubsub.rs | 12 +- tarpc/examples/tls_over_tcp.rs | 6 +- tarpc/examples/tracing.rs | 26 ++--- tarpc/src/client.rs | 27 +++-- tarpc/src/lib.rs | 18 ++- tarpc/src/server.rs | 77 +++++++----- .../src/server/limits/requests_per_channel.rs | 27 +++-- tarpc/src/server/testing.rs | 18 ++- tarpc/src/transport/channel.rs | 110 ++++++++++++++++-- tarpc/tests/dataservice.rs | 9 +- tarpc/tests/service_functional.rs | 11 +- 16 files changed, 256 insertions(+), 119 deletions(-) diff --git a/example-service/src/client.rs b/example-service/src/client.rs index e2c327dfb..71c9704ea 100644 --- a/example-service/src/client.rs +++ b/example-service/src/client.rs @@ -9,7 +9,7 @@ use futures::{SinkExt, future}; use service::{WorldClient, init_tracing}; use std::{net::SocketAddr, time::Duration}; use tarpc::context::ClientContext; -use tarpc::transport::channel::map_client_context_to_shared; +use tarpc::transport::channel::{map_transport_to_client, map_transport_to_server}; use tarpc::{client, tokio_serde::formats::Json}; use tokio::time::sleep; use tracing::Instrument; @@ -32,9 +32,7 @@ async fn main() -> anyhow::Result<()> { let mut transport = tarpc::serde_transport::tcp::connect(flags.server_addr, Json::default); transport.config_mut().max_frame_length(usize::MAX); - let transport = transport - .await? - .with(|msg| future::ok(map_client_context_to_shared(msg))); + let transport = map_transport_to_client(transport.await?); // WorldClient is generated by the service attribute. It has a constructor `new` that takes a // config and any Transport as input. diff --git a/example-service/src/server.rs b/example-service/src/server.rs index 7871c5f86..fe61904b9 100644 --- a/example-service/src/server.rs +++ b/example-service/src/server.rs @@ -11,18 +11,19 @@ use rand::{ thread_rng, }; use service::{World, init_tracing}; +use std::ops::Deref; use std::{ net::{IpAddr, Ipv6Addr, SocketAddr}, time::Duration, }; use tarpc::context::{ServerContext, SharedContext}; +use tarpc::transport::channel::{map_transport_to_server}; use tarpc::{ ClientMessage, context, server::{self, Channel, incoming::Incoming}, tokio_serde::formats::Json, }; use tokio::time; -use tarpc::transport::channel::map_shared_context_to_server; #[derive(Parser)] struct Flags { @@ -64,14 +65,14 @@ async fn main() -> anyhow::Result<()> { listener // Ignore accept errors. .filter_map(|r| future::ready(r.ok())) - .map(|t| t.map_ok(map_shared_context_to_server)) + .map(map_transport_to_server) .map(server::BaseChannel::with_defaults) // Limit channels to 1 per IP. - .max_channels_per_key(1, |t| t.transport().get_ref().peer_addr().unwrap().ip()) + .max_channels_per_key(1, |t| t.transport().get_ref().get_ref().get_ref().peer_addr().unwrap().ip()) // serve is generated by the service attribute. It takes as input any type implementing // the generated World trait. .map(|channel| { - let server = HelloServer(channel.transport().get_ref().peer_addr().unwrap()); + let server = HelloServer(channel.transport().get_ref().get_ref().get_ref().peer_addr().unwrap()); channel.execute(server.serve()).for_each(spawn) }) // Max 10 channels. diff --git a/plugins/src/lib.rs b/plugins/src/lib.rs index 21e3cd35b..71d7d3c80 100644 --- a/plugins/src/lib.rs +++ b/plugins/src/lib.rs @@ -739,7 +739,7 @@ impl ServiceGenerator<'_> { ::tarpc::client::RequestDispatch<#request_ident, #response_ident, T> > where - T: ::tarpc::Transport<::tarpc::ClientMessage<::tarpc::context::ClientContext, #request_ident>, ::tarpc::Response<#response_ident>> + T: ::tarpc::Transport<::tarpc::ClientMessage<::tarpc::context::ClientContext, #request_ident>, ::tarpc::Response<::tarpc::context::ClientContext, #response_ident>> { let new_client = ::tarpc::client::new(config, transport); ::tarpc::client::NewClient { diff --git a/tarpc/examples/compression.rs b/tarpc/examples/compression.rs index 4e09d625e..0801ce9f4 100644 --- a/tarpc/examples/compression.rs +++ b/tarpc/examples/compression.rs @@ -9,10 +9,9 @@ use futures::{Sink, SinkExt, Stream, StreamExt, TryStreamExt, prelude::*}; use serde::{Deserialize, Serialize}; use serde_bytes::ByteBuf; use std::{io, io::Read, io::Write}; -use tarpc::context::{ClientContext, ServerContext, SharedContext}; -use tarpc::transport::channel::{map_client_context_to_shared, map_shared_context_to_server}; +use tarpc::transport::channel::{map_transport_to_client, map_transport_to_server}; use tarpc::{ - ClientMessage, client, context, + client, context, serde_transport::tcp, server::{BaseChannel, Channel}, tokio_serde::formats::Bincode, @@ -127,7 +126,7 @@ async fn main() -> anyhow::Result<()> { tokio::spawn(async move { let transport = incoming.next().await.unwrap().unwrap(); let transport = add_compression(transport); - let transport = transport.map_ok(map_shared_context_to_server); + let transport = map_transport_to_server(transport); BaseChannel::with_defaults(transport) .execute(HelloServer.serve()) .for_each(spawn) @@ -136,7 +135,7 @@ async fn main() -> anyhow::Result<()> { let transport = tcp::connect(addr, Bincode::default).await?; let transport = add_compression(transport); - let transport = transport.with(|msg| future::ok(map_client_context_to_shared(msg))); + let transport = map_transport_to_client(transport); let client = WorldClient::new(client::Config::default(), transport).spawn(); println!( diff --git a/tarpc/examples/custom_transport.rs b/tarpc/examples/custom_transport.rs index 350828743..7c23a1fa7 100644 --- a/tarpc/examples/custom_transport.rs +++ b/tarpc/examples/custom_transport.rs @@ -6,11 +6,11 @@ use futures::prelude::*; use tarpc::context::{ClientContext, ServerContext, SharedContext}; +use tarpc::serde_transport as transport; use tarpc::server::{BaseChannel, Channel}; use tarpc::tokio_serde::formats::Bincode; use tarpc::tokio_util::codec::length_delimited::LengthDelimitedCodec; -use tarpc::transport::channel::{map_client_context_to_shared, map_shared_context_to_server}; -use tarpc::{ClientMessage, serde_transport as transport}; +use tarpc::transport::channel::{map_transport_to_client, map_transport_to_server}; use tokio::net::{UnixListener, UnixStream}; #[tarpc::service] @@ -40,7 +40,7 @@ async fn main() -> anyhow::Result<()> { let (conn, _addr) = listener.accept().await.unwrap(); let framed = codec_builder.new_framed(conn); let transport = transport::new(framed, Bincode::default()); - let transport = transport.map_ok(map_shared_context_to_server); + let transport = map_transport_to_server(transport); let fut = BaseChannel::with_defaults(transport) .execute(Service.serve()) @@ -51,7 +51,7 @@ async fn main() -> anyhow::Result<()> { let conn = UnixStream::connect(bind_addr).await?; let transport = transport::new(codec_builder.new_framed(conn), Bincode::default()); - let transport = transport.with(|msg| future::ok(map_client_context_to_shared(msg))); + let transport = map_transport_to_client(transport); PingServiceClient::new(Default::default(), transport) .spawn() .ping(&mut ClientContext::current()) diff --git a/tarpc/examples/pubsub.rs b/tarpc/examples/pubsub.rs index 6a6ce5723..8094c490d 100644 --- a/tarpc/examples/pubsub.rs +++ b/tarpc/examples/pubsub.rs @@ -49,7 +49,7 @@ use std::{ }; use subscriber::Subscriber as _; use tarpc::context::{ClientContext, ServerContext, SharedContext}; -use tarpc::transport::channel::{map_client_context_to_shared, map_shared_context_to_server}; +use tarpc::transport::channel::{map_transport_to_client, map_transport_to_server}; use tarpc::{ ClientMessage, client, context, serde_transport::tcp, @@ -106,7 +106,7 @@ impl Subscriber { ) -> anyhow::Result { let publisher = tcp::connect(publisher_addr, Json::default).await?; let local_addr = publisher.local_addr()?; - let publisher = publisher.map_ok(map_shared_context_to_server); + let publisher = map_transport_to_server(publisher); let mut handler = server::BaseChannel::with_defaults(publisher).requests(); let subscriber = Subscriber { local_addr, topics }; // The first request is for the topics being subscribed to. @@ -167,7 +167,7 @@ impl Publisher { let publisher = connecting_publishers.next().await.unwrap().unwrap(); info!(publisher.peer_addr = ?publisher.peer_addr(), "publisher connected."); - let publisher = publisher.map_ok(map_shared_context_to_server); + let publisher = map_transport_to_server(publisher); server::BaseChannel::with_defaults(publisher) .execute(self.serve()) @@ -188,7 +188,7 @@ impl Publisher { tokio::spawn(async move { while let Some(conn) = connecting_subscribers.next().await { let subscriber_addr = conn.peer_addr().unwrap(); - let conn = conn.with(|msg| future::ok(map_client_context_to_shared(msg))); + let conn = map_transport_to_client(conn); let tarpc::client::NewClient { client: subscriber, dispatch, @@ -354,9 +354,7 @@ async fn main() -> anyhow::Result<()> { let publisher = publisher::PublisherClient::new( client::Config::default(), - tcp::connect(addrs.publisher, Json::default) - .await? - .with(|msg| future::ok(map_client_context_to_shared(msg))), + map_transport_to_client(tcp::connect(addrs.publisher, Json::default).await?), ) .spawn(); diff --git a/tarpc/examples/tls_over_tcp.rs b/tarpc/examples/tls_over_tcp.rs index 56583b05c..2d90650a5 100644 --- a/tarpc/examples/tls_over_tcp.rs +++ b/tarpc/examples/tls_over_tcp.rs @@ -15,7 +15,7 @@ use tarpc::serde_transport as transport; use tarpc::server::{BaseChannel, Channel}; use tarpc::tokio_serde::formats::Bincode; use tarpc::tokio_util::codec::length_delimited::LengthDelimitedCodec; -use tarpc::transport::channel::{map_client_context_to_shared, map_shared_context_to_server}; +use tarpc::transport::channel::{map_transport_to_client, map_transport_to_server}; use tokio::net::TcpListener; use tokio::net::TcpStream; use tokio_rustls::rustls::{ @@ -115,7 +115,7 @@ async fn main() -> anyhow::Result<()> { let framed = codec_builder.new_framed(tls_stream); let transport = transport::new(framed, Bincode::default()); - let transport = transport.map_ok(map_shared_context_to_server); + let transport = map_transport_to_server(transport); let fut = BaseChannel::with_defaults(transport) .execute(Service.serve()) @@ -145,7 +145,7 @@ async fn main() -> anyhow::Result<()> { let stream = connector.connect(domain, stream).await?; let transport = transport::new(codec_builder.new_framed(stream), Bincode::default()); - let transport = transport.with(|msg| future::ok(map_client_context_to_shared(msg))); + let transport = map_transport_to_client(transport); let answer = PingServiceClient::new(Default::default(), transport) .spawn() .ping(&mut ClientContext::current()) diff --git a/tarpc/examples/tracing.rs b/tarpc/examples/tracing.rs index acf631be8..0930aae1d 100644 --- a/tarpc/examples/tracing.rs +++ b/tarpc/examples/tracing.rs @@ -20,7 +20,7 @@ use std::{ }, }; use tarpc::context::{ClientContext, ServerContext, SharedContext}; -use tarpc::transport::channel::{map_client_context_to_shared, map_shared_context_to_server}; +use tarpc::transport::channel::{map_transport_to_client, map_transport_to_server}; use tarpc::{ ClientMessage, RequestName, Response, ServerError, Transport, client::{ @@ -126,8 +126,10 @@ where } fn make_stub( - backends: [impl Transport>, Response> + Send + Sync + 'static; - N], + backends: [impl Transport>, Response> + + Send + + Sync + + 'static; N], ) -> retry::Retry< impl Fn(&Result, u32) -> bool + Clone, load_balance::RoundRobin, Resp>>, @@ -176,34 +178,26 @@ async fn main() -> anyhow::Result<()> { .serving(AddServer.serve()); let add_server = add_listener1 .chain(add_listener2) - .map(|t| t.map_ok(map_shared_context_to_server)) + .map(map_transport_to_server) .map(BaseChannel::with_defaults); tokio::spawn(spawn_incoming(add_server.execute(server))); - let map_context = - |msg: ClientMessage| future::ok(map_client_context_to_shared(msg)); - let add_client = add::AddClient::from(make_stub([ - tarpc::serde_transport::tcp::connect(addr1, Json::default) - .await? - .with(map_context), - tarpc::serde_transport::tcp::connect(addr2, Json::default) - .await? - .with(map_context), + map_transport_to_client(tarpc::serde_transport::tcp::connect(addr1, Json::default).await?), + map_transport_to_client(tarpc::serde_transport::tcp::connect(addr2, Json::default).await?), ])); let double_listener = tarpc::serde_transport::tcp::listen("localhost:0", Json::default) .await? .filter_map(|r| future::ready(r.ok())) - .map(|t| t.map_ok(map_shared_context_to_server)); + .map(map_transport_to_server); let addr = double_listener.get_ref().get_ref().local_addr(); let double_server = double_listener.map(BaseChannel::with_defaults).take(1); let server = DoubleServer { add_client }.serve(); tokio::spawn(spawn_incoming(double_server.execute(server))); let to_double_server = tarpc::serde_transport::tcp::connect(addr, Json::default).await?; - let to_double_server = - to_double_server.with(|msg| future::ok(map_client_context_to_shared(msg))); + let to_double_server = map_transport_to_client(to_double_server); let double_client = double::DoubleClient::new(client::Config::default(), to_double_server).spawn(); diff --git a/tarpc/src/client.rs b/tarpc/src/client.rs index ebcb69db1..125f3ad4a 100644 --- a/tarpc/src/client.rs +++ b/tarpc/src/client.rs @@ -244,7 +244,7 @@ pub fn new( transport: C, ) -> NewClient, RequestDispatch> where - C: Transport, Response>, + C: Transport, Response>, { let (to_dispatch, pending_requests) = mpsc::channel(config.pending_request_buffer); let (cancellation, canceled_requests) = cancellations(); @@ -292,7 +292,7 @@ pub struct RequestDispatch { impl RequestDispatch where - C: Transport, Response>, + C: Transport, Response>, { fn in_flight_requests<'a>( self: &'a mut Pin<&mut Self>, @@ -577,7 +577,7 @@ where } /// Sends a server response to the client task that initiated the associated request. - fn complete(mut self: Pin<&mut Self>, response: Response) -> bool { + fn complete(mut self: Pin<&mut Self>, response: Response) -> bool { if let Some(span) = self.in_flight_requests().complete_request( response.request_id, response.message.map_err(RpcError::Server), @@ -657,7 +657,7 @@ where impl Future for RequestDispatch where - C: Transport, Response>, + C: Transport, Response>, { type Output = Result<(), ChannelError>; @@ -746,6 +746,7 @@ mod tests { server_channel .send(Response { request_id: 0, + context: ClientContext::current(), message: Ok("Resp".into()), }) .await @@ -775,6 +776,7 @@ mod tests { let (tx, mut response) = oneshot::channel(); tx.send(Ok(Response { request_id: 0, + context: ClientContext::current(), message: Ok("well done"), })) .unwrap(); @@ -825,6 +827,7 @@ mod tests { &mut server_channel, Response { request_id: 0, + context: ClientContext::current(), message: Ok("hello".into()), }, ) @@ -1063,7 +1066,7 @@ mod tests { } impl Stream for AlwaysErrorTransport { - type Item = Result, TransportError>; + type Item = Result, TransportError>; fn poll_next(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll> { if matches!(self.0, TransportError::Read) { Poll::Ready(Some(Err(self.0))) @@ -1079,12 +1082,15 @@ mod tests { RequestDispatch< String, String, - UnboundedChannel, ClientMessage>, + UnboundedChannel< + Response, + ClientMessage, + >, >, >, >, Channel, - UnboundedChannel, Response>, + UnboundedChannel, Response>, ) { let _ = tracing_subscriber::fmt().with_test_writer().try_init(); @@ -1162,8 +1168,11 @@ mod tests { } async fn send_response( - channel: &mut UnboundedChannel, Response>, - response: Response, + channel: &mut UnboundedChannel< + ClientMessage, + Response, + >, + response: Response, ) { channel.send(response).await.unwrap(); } diff --git a/tarpc/src/lib.rs b/tarpc/src/lib.rs index 6f7e08fc0..565fe9f89 100644 --- a/tarpc/src/lib.rs +++ b/tarpc/src/lib.rs @@ -392,13 +392,29 @@ impl RequestName for u64 { /// A response from a server to a client. #[derive(Clone, Debug, PartialEq, Eq, Hash)] #[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))] -pub struct Response { +pub struct Response { /// The ID of the request being responded to. pub request_id: u64, + /// Trace context, deadline, and other cross-cutting concerns. + pub context: Ctx, /// The response body, or an error if the request failed. pub message: Result, } +impl Response { + /// Creates a modified Response by mapping the context using the provided function. + pub fn map_context(self, f: F) -> Response + where + F: FnOnce(Ctx) -> Ctx2, + { + Response { + request_id: self.request_id, + context: f(self.context), + message: self.message, + } + } +} + /// An error indicating the server aborted the request early, e.g., due to request throttling. #[derive(thiserror::Error, Clone, Debug, PartialEq, Eq, Hash)] #[error("{kind:?}: {detail}")] diff --git a/tarpc/src/server.rs b/tarpc/src/server.rs index 87ac89681..649f21022 100644 --- a/tarpc/src/server.rs +++ b/tarpc/src/server.rs @@ -61,7 +61,7 @@ impl Config { /// Returns a channel backed by `transport` and configured with `self`. pub fn channel(self, transport: T) -> BaseChannel where - T: Transport, ClientMessage>, + T: Transport, ClientMessage>, { BaseChannel::new(self, transport) } @@ -165,7 +165,7 @@ pub struct BaseChannel { impl BaseChannel where - T: Transport, ClientMessage>, + T: Transport, ClientMessage>, { /// Creates a new channel backed by `transport` and configured with `config`. pub fn new(config: Config, transport: T) -> Self { @@ -304,7 +304,10 @@ pub struct TrackedRequest { /// created by [`BaseChannel`]. pub trait Channel where - Self: Transport::Resp>, TrackedRequest<::Req>>, + Self: Transport< + Response::Resp>, + TrackedRequest<::Req>, + >, { /// Type of request item. type Req; @@ -378,7 +381,7 @@ where /// assert_eq!(client.call(&mut context, 1).await.unwrap(), 2); /// } /// ``` - fn requests(self) -> Requests + fn requests(self) -> Requests where Self: Sized, { @@ -433,7 +436,7 @@ where impl Stream for BaseChannel where - T: Transport, ClientMessage>, + T: Transport, ClientMessage>, { type Item = Result, ChannelError>; @@ -538,9 +541,9 @@ where } } -impl Sink> for BaseChannel +impl Sink> for BaseChannel where - T: Transport, ClientMessage>, + T: Transport, ClientMessage>, T::Error: Error, { type Error = ChannelError; @@ -552,7 +555,10 @@ where .map_err(|e| ChannelError::Ready(Arc::new(e))) } - fn start_send(mut self: Pin<&mut Self>, response: Response) -> Result<(), Self::Error> { + fn start_send( + mut self: Pin<&mut Self>, + response: Response, + ) -> Result<(), Self::Error> { if let Some(span) = self .in_flight_requests_mut() .remove_request(response.request_id) @@ -593,7 +599,7 @@ impl AsRef for BaseChannel { impl Channel for BaseChannel where - T: Transport, ClientMessage>, + T: Transport, ClientMessage>, { type Req = Req; type Resp = Resp; @@ -615,19 +621,19 @@ where /// A stream of requests coming over a channel. `Requests` also drives the sending of responses, so /// it must be continually polled to ensure progress. #[pin_project] -pub struct Requests +pub struct Requests where C: Channel, { #[pin] channel: C, /// Responses waiting to be written to the wire. - pending_responses: mpsc::Receiver>, + pending_responses: mpsc::Receiver>, /// Handed out to request handlers to fan in responses. - responses_tx: mpsc::Sender>, + responses_tx: mpsc::Sender>, } -impl Requests +impl Requests where C: Channel, { @@ -644,7 +650,7 @@ where /// Returns the inner channel over which messages are sent and received. pub fn pending_responses_mut<'a>( self: &'a mut Pin<&mut Self>, - ) -> &'a mut mpsc::Receiver> { + ) -> &'a mut mpsc::Receiver> { self.as_mut().project().pending_responses } @@ -716,7 +722,7 @@ where fn poll_next_response( mut self: Pin<&mut Self>, cx: &mut Context<'_>, - ) -> Poll, C::Error>>> { + ) -> Poll, C::Error>>> { ready!(self.ensure_writeable(cx)?); match ready!(self.pending_responses_mut().poll_recv(cx)) { @@ -789,7 +795,7 @@ where } } -impl fmt::Debug for Requests +impl fmt::Debug for Requests where C: Channel, { @@ -825,7 +831,7 @@ pub struct InFlightRequest { abort_registration: AbortRegistration, response_guard: ResponseGuard, span: Span, - response_tx: mpsc::Sender>, + response_tx: mpsc::Sender>, } impl InFlightRequest { @@ -904,6 +910,7 @@ impl InFlightRequest { tracing::debug!("CompleteRequest"); let response = Response { request_id, + context, message, }; let _ = response_tx.send(response).await; @@ -927,7 +934,7 @@ fn print_err(e: &(dyn Error + 'static)) -> String { .join(": ") } -impl Stream for Requests +impl Stream for Requests where C: Channel, { @@ -1002,11 +1009,14 @@ mod tests { BaseChannel< Req, Resp, - UnboundedChannel, Response>, + UnboundedChannel< + ClientMessage, + Response, + >, >, >, >, - UnboundedChannel, ClientMessage>, + UnboundedChannel, ClientMessage>, ) { let (tx, rx) = crate::transport::channel::unbounded(); (Box::pin(BaseChannel::new(Config::default(), rx)), tx) @@ -1016,15 +1026,19 @@ mod tests { Pin< Box< Requests< + ServerContext, BaseChannel< Req, Resp, - UnboundedChannel, Response>, + UnboundedChannel< + ClientMessage, + Response, + >, >, >, >, >, - UnboundedChannel, ClientMessage>, + UnboundedChannel, ClientMessage>, ) { let (tx, rx) = crate::transport::channel::unbounded(); ( @@ -1039,15 +1053,19 @@ mod tests { Pin< Box< Requests< + ServerContext, BaseChannel< Req, Resp, - channel::Channel, Response>, + channel::Channel< + ClientMessage, + Response, + >, >, >, >, >, - channel::Channel, ClientMessage>, + channel::Channel, ClientMessage>, ) { let (tx, rx) = crate::transport::channel::bounded(capacity); // Add 1 because capacity 0 is not supported (but is supported by transport::channel::bounded). @@ -1322,7 +1340,7 @@ mod tests { .as_mut() .start_request(Request { id: 0, - context: context::ServerContext::current(), + context: ServerContext::current(), message: (), }) .unwrap(); @@ -1331,6 +1349,7 @@ mod tests { .as_mut() .start_send(Response { request_id: 0, + context: ServerContext::current(), message: Ok(()), }) .unwrap(); @@ -1398,6 +1417,7 @@ mod tests { .channel_pin_mut() .start_send(Response { request_id: 0, + context: ServerContext::current(), message: Ok(()), }) .unwrap(); @@ -1409,6 +1429,7 @@ mod tests { .responses_tx .send(Response { request_id: 1, + context: ServerContext::current(), message: Ok(()), }) .await @@ -1419,7 +1440,7 @@ mod tests { .channel_pin_mut() .start_request(Request { id: 1, - context: context::ServerContext::current(), + context: ServerContext::current(), message: (), }) .unwrap(); @@ -1449,6 +1470,7 @@ mod tests { .channel_pin_mut() .start_send(Response { request_id: 0, + context: ServerContext::current(), message: Ok(()), }) .unwrap(); @@ -1459,7 +1481,7 @@ mod tests { .channel_pin_mut() .start_request(Request { id: 1, - context: context::ServerContext::current(), + context: ServerContext::current(), message: (), }) .unwrap(); @@ -1469,6 +1491,7 @@ mod tests { .responses_tx .send(Response { request_id: 1, + context: ServerContext::current(), message: Ok(()), }) .await diff --git a/tarpc/src/server/limits/requests_per_channel.rs b/tarpc/src/server/limits/requests_per_channel.rs index bd9c103b0..395ded512 100644 --- a/tarpc/src/server/limits/requests_per_channel.rs +++ b/tarpc/src/server/limits/requests_per_channel.rs @@ -4,6 +4,7 @@ // license that can be found in the LICENSE file or at // https://opensource.org/licenses/MIT. +use crate::context::ServerContext; use crate::{ Response, ServerError, server::{Channel, Config}, @@ -67,6 +68,7 @@ where self.as_mut().start_send(Response { request_id: r.request.id, + context: r.request.context, message: Err(ServerError { kind: io::ErrorKind::WouldBlock, detail: "server throttled the request.".into(), @@ -80,7 +82,7 @@ where } } -impl Sink::Resp>> for MaxRequests +impl Sink::Resp>> for MaxRequests where C: Channel, { @@ -92,7 +94,7 @@ where fn start_send( self: Pin<&mut Self>, - item: Response<::Resp>, + item: Response::Resp>, ) -> Result<(), Self::Error> { self.project().inner.start_send(item) } @@ -268,7 +270,8 @@ mod tests { } impl PendingSink<(), ()> { pub fn default() - -> PendingSink>, Response> { + -> PendingSink>, Response> + { PendingSink { ghost: PhantomData } } } @@ -293,7 +296,9 @@ mod tests { Poll::Pending } } - impl Channel for PendingSink>, Response> { + impl Channel + for PendingSink>, Response> + { type Req = Req; type Resp = Resp; type Transport = (); @@ -326,16 +331,16 @@ mod tests { .as_mut() .start_send(Response { request_id: 0, + context: ServerContext::current(), message: Ok(1), }) .unwrap(); assert_eq!(throttler.inner.in_flight_requests.len(), 0); - assert_eq!( - throttler.inner.sink.front(), - Some(&Response { - request_id: 0, - message: Ok(1), - }) - ); + + let result = throttler.inner.sink.front(); + + assert_eq!(result.map(|r| r.request_id), Some(0)); + + assert_eq!(result.map(|r| &r.message), Some(&Ok(1))); } } diff --git a/tarpc/src/server/testing.rs b/tarpc/src/server/testing.rs index 709167751..63b65d697 100644 --- a/tarpc/src/server/testing.rs +++ b/tarpc/src/server/testing.rs @@ -4,6 +4,7 @@ // license that can be found in the LICENSE file or at // https://opensource.org/licenses/MIT. +use crate::context::ServerContext; use crate::{ Request, Response, cancellations::{CanceledRequests, RequestCancellation, cancellations}, @@ -38,14 +39,19 @@ where } } -impl Sink> for FakeChannel> { +impl Sink> + for FakeChannel> +{ type Error = io::Error; fn poll_ready(self: Pin<&mut Self>, cx: &mut Context) -> Poll> { self.project().sink.poll_ready(cx).map_err(|e| match e {}) } - fn start_send(mut self: Pin<&mut Self>, response: Response) -> Result<(), Self::Error> { + fn start_send( + mut self: Pin<&mut Self>, + response: Response, + ) -> Result<(), Self::Error> { self.as_mut() .project() .in_flight_requests @@ -65,7 +71,8 @@ impl Sink> for FakeChannel> { } } -impl Channel for FakeChannel>, Response> +impl Channel + for FakeChannel>, Response> where Req: Unpin, { @@ -86,7 +93,7 @@ where } } -impl FakeChannel>, Response> { +impl FakeChannel>, Response> { pub fn push_req(&mut self, id: u64, message: Req) { let (_, abort_registration) = futures::future::AbortHandle::new_pair(); let (request_cancellation, _) = cancellations(); @@ -111,7 +118,8 @@ impl FakeChannel>, Response> { } impl FakeChannel<(), ()> { - pub fn default() -> FakeChannel>, Response> { + pub fn default() + -> FakeChannel>, Response> { let (request_cancellation, canceled_requests) = cancellations(); FakeChannel { stream: Default::default(), diff --git a/tarpc/src/transport/channel.rs b/tarpc/src/transport/channel.rs index 65e987d02..7615d8fe1 100644 --- a/tarpc/src/transport/channel.rs +++ b/tarpc/src/transport/channel.rs @@ -7,7 +7,9 @@ //! Transports backed by in-memory channels. use crate::context::{ClientContext, ServerContext, SharedContext}; -use crate::{ClientMessage, Transport}; +use crate::{ClientMessage, Response, Transport}; +use futures::future::{Ready}; +use futures::sink::With; use futures::{Sink, SinkExt, Stream, TryStreamExt, task::*}; use pin_project::pin_project; use std::{error::Error, future, pin::Pin}; @@ -43,21 +45,40 @@ pub fn unbounded() -> ( /// Returns two mapped unbounded channel peers. Each [`Stream`] yields items sent through the other's /// [`Sink`]. -pub fn unbounded_mapped( +pub fn unbounded_mapped< + SerializedSinkItem, + SerializedItem, + ClientSinkItem, + ServerSinkItem, + ClientItem, + ServerItem, + F, + G, + H, + I, +>( mut f: F, mut g: G, + mut h: H, + mut i: I, ) -> ( - impl Transport, - impl Transport, + impl Transport, + impl Transport, ) where F: FnMut(ClientSinkItem) -> SerializedSinkItem, G: FnMut(SerializedSinkItem) -> ServerSinkItem, + H: FnMut(SerializedItem) -> ClientItem, + I: FnMut(ServerItem) -> SerializedItem, { let (client, server) = unbounded(); - let client = client.with(move |msg: ClientSinkItem| future::ready(Ok(f(msg)))); - let server = server.map_ok(move |msg: SerializedSinkItem| g(msg)); + let client = client + .with(move |msg: ClientSinkItem| future::ready(Ok(f(msg)))) + .map_ok(move |msg: SerializedItem| h(msg)); + let server = server + .map_ok(move |msg: SerializedSinkItem| g(msg)) + .with(move |msg: ServerItem| future::ready(Ok(i(msg)))); (client, server) } @@ -65,26 +86,93 @@ where /// Convenience functino to return two mapped unbounded channel peers for a basechannel and a client implementation. Each [`Stream`] yields items sent through the other's /// [`Sink`]. pub fn unbounded_for_client_server_context() -> ( - impl Transport, Resp>, - impl Transport>, + impl Transport, Response>, + impl Transport, ClientMessage>, ) { - unbounded_mapped(map_client_context_to_shared, map_shared_context_to_server) + unbounded_mapped( + map_req_client_context_to_shared, + map_req_shared_context_to_server, + map_resp_shared_context_to_client, + map_resp_server_context_to_shared, + ) } /// Convenience function to map a ClientMessage with ClientContext to one with SharedContext. -pub fn map_client_context_to_shared( +fn map_req_client_context_to_shared( msg: ClientMessage, ) -> ClientMessage { msg.map_context(|ctx| ctx.shared_context) } /// Convenience function to map a ClientMessage with SharedContext to one with ServerContext. -pub fn map_shared_context_to_server( +fn map_req_shared_context_to_server( msg: ClientMessage, ) -> ClientMessage { msg.map_context(ServerContext::new) } +/// Convenience function to map a ClientMessage with ClientContext to one with SharedContext. +fn map_resp_server_context_to_shared( + resp: Response, +) -> Response { + resp.map_context(|ctx| ctx.shared_context) +} + +/// Convenience function to map a ClientMessage with SharedContext to one with ServerContext. +fn map_resp_shared_context_to_client( + msg: Response, +) -> Response { + msg.map_context(ClientContext::new) +} + +/// TODO: document +/// Yuck, but impl trait will loose our ability to do t.as_ref() +pub fn map_transport_to_client( + t: T, +) -> futures::stream::MapOk< + With< + T, + ClientMessage, + ClientMessage, + Ready, E>>, + fn(ClientMessage) -> Ready, E>>, + >, + fn(Response) -> Response, +> +where + T: Transport, Response>, + E: From +{ + let f: fn(ClientMessage) -> Ready, E>> = |resp| futures::future::ok(map_req_client_context_to_shared(resp)); + + t.with(f).map_ok(map_resp_shared_context_to_client) +} + +/// TODO: document +/// +/// Yuck, but impl trait will loose our ability to do t.as_ref() +pub fn map_transport_to_server( + t: T, +) -> futures::stream::MapOk< + With< + T, + Response, + Response, + Ready, E>>, + fn(Response) -> Ready, E>>, + >, + fn(ClientMessage) -> ClientMessage, +> +where + T: Transport, ClientMessage>, + E: From +{ + let f: fn(Response) -> Ready, E>> = |resp| futures::future::ok(map_resp_server_context_to_shared(resp)); + + t.with(f) + .map_ok(map_req_shared_context_to_server) +} + /// A bi-directional channel backed by an [`UnboundedSender`](mpsc::UnboundedSender) /// and [`UnboundedReceiver`](mpsc::UnboundedReceiver). #[derive(Debug)] diff --git a/tarpc/tests/dataservice.rs b/tarpc/tests/dataservice.rs index 1ac04af13..54fadf77d 100644 --- a/tarpc/tests/dataservice.rs +++ b/tarpc/tests/dataservice.rs @@ -1,6 +1,6 @@ use futures::prelude::*; use tarpc::context::{ClientContext, ServerContext, SharedContext}; -use tarpc::transport::channel::{map_client_context_to_shared, map_shared_context_to_server}; +use tarpc::transport::channel::{map_transport_to_client, map_transport_to_server}; use tarpc::{ClientMessage, serde_transport}; use tarpc::{ client, context, @@ -45,16 +45,15 @@ async fn test_call() -> anyhow::Result<()> { transport .take(1) .filter_map(|r| async { r.ok() }) - .map(|t| t.map_ok(map_shared_context_to_server)) + .map(map_transport_to_server) .map(BaseChannel::with_defaults) .execute(ColorServer.serve()) .map(|channel| channel.for_each(spawn)) .for_each(spawn), ); - let transport = serde_transport::tcp::connect(addr, Json::default) - .await? - .with(|msg| future::ok(map_client_context_to_shared(msg))); + let transport = serde_transport::tcp::connect(addr, Json::default).await?; + let transport = map_transport_to_client(transport); let client = ColorProtocolClient::new(client::Config::default(), transport).spawn(); diff --git a/tarpc/tests/service_functional.rs b/tarpc/tests/service_functional.rs index ebebef660..b65a66104 100644 --- a/tarpc/tests/service_functional.rs +++ b/tarpc/tests/service_functional.rs @@ -4,8 +4,7 @@ use futures::{ prelude::*, }; use std::time::{Duration, Instant}; -use tarpc::context::{ClientContext, ServerContext, SharedContext}; -use tarpc::transport::channel::{map_client_context_to_shared, map_shared_context_to_server}; +use tarpc::transport::channel::{map_transport_to_client, map_transport_to_server}; use tarpc::{ ClientMessage, client::{self}, @@ -116,7 +115,7 @@ async fn serde_tcp() -> anyhow::Result<()> { transport .take(1) .filter_map(|r| async { r.ok() }) - .map(|t| t.map_ok(map_shared_context_to_server)) + .map(map_transport_to_server) .map(BaseChannel::with_defaults) .execute(Server.serve()) .map(|channel| channel.for_each(spawn)) @@ -124,7 +123,7 @@ async fn serde_tcp() -> anyhow::Result<()> { ); let transport = serde_transport::tcp::connect(addr, Json::default).await?; - let transport = transport.with(|msg| future::ok(map_client_context_to_shared(msg))); + let transport = map_transport_to_client(transport); let client = ServiceClient::new(client::Config::default(), transport).spawn(); assert_matches!( @@ -155,7 +154,7 @@ async fn serde_uds() -> anyhow::Result<()> { transport .take(1) .filter_map(|r| async { r.ok() }) - .map(|t| t.map_ok(map_shared_context_to_server)) + .map(map_transport_to_server) .map(BaseChannel::with_defaults) .execute(Server.serve()) .map(|channel| channel.for_each(spawn)) @@ -163,7 +162,7 @@ async fn serde_uds() -> anyhow::Result<()> { ); let transport = serde_transport::unix::connect(&sock, Json::default).await?; - let transport = transport.with(|msg| future::ok(map_client_context_to_shared(msg))); + let transport = map_transport_to_client(transport); let client = ServiceClient::new(client::Config::default(), transport).spawn(); From 54b8fe8a0fd6b32e305bed89d4dfa529e384ea6b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C3=81kos=20Vandra-Meyer?= Date: Tue, 25 Nov 2025 16:05:33 +0100 Subject: [PATCH 10/23] allow server to mutate shared context --- tarpc/src/client.rs | 90 +++++++++++++------------- tarpc/src/client/in_flight_requests.rs | 19 +++--- 2 files changed, 53 insertions(+), 56 deletions(-) diff --git a/tarpc/src/client.rs b/tarpc/src/client.rs index 125f3ad4a..ee85d842d 100644 --- a/tarpc/src/client.rs +++ b/tarpc/src/client.rs @@ -166,14 +166,19 @@ where }) .await .map_err(|mpsc::error::SendError(_)| RpcError::Shutdown)?; - response_guard.response().await + + let (response_ctx, r) = response_guard.response().await?; + + ctx.shared_context = response_ctx.shared_context; + + Ok(r) } } /// A server response that is completed by request dispatch when the corresponding response /// arrives off the wire. struct ResponseGuard<'a, Resp> { - response: &'a mut oneshot::Receiver>, + response: &'a mut oneshot::Receiver>, cancellation: &'a RequestCancellation, request_id: u64, cancel: bool, @@ -201,7 +206,7 @@ pub enum RpcError { } impl ResponseGuard<'_, Resp> { - async fn response(mut self) -> Result { + async fn response(mut self) -> Result<(ClientContext, Resp), RpcError> { let response = (&mut self.response).await; // Cancel drop logic once a response has been received. self.cancel = false; @@ -280,7 +285,7 @@ pub struct RequestDispatch { /// Requests that were dropped. canceled_requests: CanceledRequests, /// Requests already written to the wire that haven't yet received responses. - in_flight_requests: InFlightRequests>, + in_flight_requests: InFlightRequests, /// Configures limits to prevent unlimited resource usage. config: Config, /// Produces errors that can be sent in response to any unprocessed requests at the time @@ -296,7 +301,7 @@ where { fn in_flight_requests<'a>( self: &'a mut Pin<&mut Self>, - ) -> &'a mut InFlightRequests> { + ) -> &'a mut InFlightRequests { self.as_mut().project().in_flight_requests } @@ -522,12 +527,10 @@ where let trace_context = ctx.trace_context; let deadline = ctx.deadline; - let client_context = context::ClientContext::new(ctx); - let request = ClientMessage::Request(Request { id: request_id, message: request, - context: client_context, + context: ClientContext::new(ctx), }); self.in_flight_requests() @@ -580,7 +583,7 @@ where fn complete(mut self: Pin<&mut Self>, response: Response) -> bool { if let Some(span) = self.in_flight_requests().complete_request( response.request_id, - response.message.map_err(RpcError::Server), + response.message.map_err(RpcError::Server).map(|m| (response.context, m)), ) { let _entered = span.enter(); tracing::debug!("ReceiveResponse"); @@ -688,11 +691,11 @@ where /// the lifecycle of the request. #[derive(Debug)] struct DispatchRequest { - pub ctx: context::SharedContext, + pub ctx: context::SharedContextg, ///TODO: <-- this should be a &mut ClientContext pub span: Span, pub request_id: u64, pub request: Req, - pub response_completion: oneshot::Sender>, + pub response_completion: oneshot::Sender>, } #[cfg(test)] @@ -752,7 +755,7 @@ mod tests { .await .unwrap(); assert_matches!(dispatch.as_mut().poll(cx), Poll::Pending); - assert_matches!(rx.try_recv(), Ok(Ok(resp)) if resp == "Resp"); + assert_matches!(rx.try_recv(), Ok(Ok((_, resp))) if resp == "Resp"); } #[tokio::test] @@ -774,12 +777,7 @@ mod tests { async fn dispatch_response_doesnt_cancel_after_complete() { let (cancellation, mut canceled_requests) = cancellations(); let (tx, mut response) = oneshot::channel(); - tx.send(Ok(Response { - request_id: 0, - context: ClientContext::current(), - message: Ok("well done"), - })) - .unwrap(); + tx.send(Ok((ClientContext::current(), "well done"))).unwrap(); // resp's drop() is run, but should not send a cancel message. ResponseGuard { response: &mut response, @@ -1116,37 +1114,11 @@ mod tests { (Box::pin(dispatch), channel, server_channel) } - async fn reserve_for_send<'a>( - channel: &'a mut Channel, - response_completion: oneshot::Sender>, - response: &'a mut oneshot::Receiver>, - ) -> impl FnOnce(&str) -> ResponseGuard<'a, String> { - let permit = channel.to_dispatch.reserve().await.unwrap(); - |request| { - let request_id = - u64::try_from(channel.next_request_id.fetch_add(1, Ordering::Relaxed)).unwrap(); - let request = DispatchRequest { - ctx: SharedContext::current(), - span: Span::current(), - request_id, - request: request.to_string(), - response_completion, - }; - permit.send(request); - ResponseGuard { - response, - cancellation: &channel.cancellation, - request_id, - cancel: true, - } - } - } - async fn send_request<'a>( channel: &'a mut Channel, request: &str, - response_completion: oneshot::Sender>, - response: &'a mut oneshot::Receiver>, + response_completion: oneshot::Sender>, + response: &'a mut oneshot::Receiver>, ) -> ResponseGuard<'a, String> { let request_id = u64::try_from(channel.next_request_id.fetch_add(1, Ordering::Relaxed)).unwrap(); @@ -1167,6 +1139,32 @@ mod tests { response_guard } + async fn reserve_for_send<'a>( + channel: &'a mut Channel, + response_completion: oneshot::Sender>, + response: &'a mut oneshot::Receiver>, + ) -> impl FnOnce(&str) -> ResponseGuard<'a, String> { + let permit = channel.to_dispatch.reserve().await.unwrap(); + |request| { + let request_id = + u64::try_from(channel.next_request_id.fetch_add(1, Ordering::Relaxed)).unwrap(); + let request = DispatchRequest { + ctx: SharedContext::current(), + span: Span::current(), + request_id, + request: request.to_string(), + response_completion, + }; + permit.send(request); + ResponseGuard { + response, + cancellation: &channel.cancellation, + request_id, + cancel: true, + } + } + } + async fn send_response( channel: &mut UnboundedChannel< ClientMessage, diff --git a/tarpc/src/client/in_flight_requests.rs b/tarpc/src/client/in_flight_requests.rs index 7a554de27..90f60c527 100644 --- a/tarpc/src/client/in_flight_requests.rs +++ b/tarpc/src/client/in_flight_requests.rs @@ -1,7 +1,4 @@ -use crate::{ - trace, - util::{Compact, TimeUntil}, -}; +use crate::{trace, util::{Compact, TimeUntil}}; use fnv::FnvHashMap; use std::time::Instant; use std::{ @@ -11,6 +8,8 @@ use std::{ use tokio::sync::oneshot; use tokio_util::time::delay_queue::{self, DelayQueue}; use tracing::Span; +use crate::client::RpcError; +use crate::context::ClientContext; /// Requests already written to the wire that haven't yet received responses. #[derive(Debug)] @@ -32,7 +31,7 @@ impl Default for InFlightRequests { struct RequestData { ctx: trace::Context, span: Span, - response_completion: oneshot::Sender, + response_completion: oneshot::Sender>, /// The key to remove the timer for the request's deadline. deadline_key: delay_queue::Key, } @@ -60,7 +59,7 @@ impl InFlightRequests { ctx: trace::Context, deadline: Instant, span: Span, - response_completion: oneshot::Sender, + response_completion: oneshot::Sender>, ) -> Result<(), AlreadyExistsError> { match self.request_data.entry(request_id) { hash_map::Entry::Vacant(vacant) => { @@ -78,8 +77,8 @@ impl InFlightRequests { } } - /// Removes a request without aborting. Returns true iff the request was found. - pub fn complete_request(&mut self, request_id: u64, result: Res) -> Option { + /// Removes a request without aborting. Returns true if the request was found. + pub fn complete_request(&mut self, request_id: u64, result: Result<(ClientContext, Res), RpcError>) -> Option { if let Some(request_data) = self.request_data.remove(&request_id) { self.request_data.compact(0.1); self.deadlines.remove(&request_data.deadline_key); @@ -97,7 +96,7 @@ impl InFlightRequests { /// Returns Spans for all completes requests. pub fn complete_all_requests<'a>( &'a mut self, - mut result: impl FnMut() -> Res + 'a, + mut result: impl FnMut() -> Result<(ClientContext, Res), RpcError> + 'a, ) -> impl Iterator + 'a { self.deadlines.clear(); self.request_data.drain().map(move |(_, request_data)| { @@ -123,7 +122,7 @@ impl InFlightRequests { pub fn poll_expired( &mut self, cx: &mut Context, - expired_error: impl Fn() -> Res, + expired_error: impl Fn() -> Result<(ClientContext, Res), RpcError>, ) -> Poll> { self.deadlines.poll_expired(cx).map(|expired| { let request_id = expired?.into_inner(); From 6ded9edff5f4dd391a8cc5a3702ec08d391f1f43 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C3=81kos=20Vandra-Meyer?= Date: Tue, 25 Nov 2025 16:10:30 +0100 Subject: [PATCH 11/23] fix typo --- tarpc/src/client.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tarpc/src/client.rs b/tarpc/src/client.rs index ee85d842d..a42c94491 100644 --- a/tarpc/src/client.rs +++ b/tarpc/src/client.rs @@ -691,7 +691,7 @@ where /// the lifecycle of the request. #[derive(Debug)] struct DispatchRequest { - pub ctx: context::SharedContextg, ///TODO: <-- this should be a &mut ClientContext + pub ctx: context::SharedContext, ///TODO: <-- this should be a &mut ClientContext pub span: Span, pub request_id: u64, pub request: Req, From 6fa12926de7dc54eb38799817c744b187e556b6a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C3=81kos=20Vandra-Meyer?= Date: Wed, 26 Nov 2025 13:09:55 +0100 Subject: [PATCH 12/23] make servertransport generic, defined by the service implementation. --- example-service/src/server.rs | 1 + plugins/src/lib.rs | 17 +- plugins/tests/service.rs | 4 + tarpc/Cargo.toml | 2 + tarpc/examples/compression.rs | 2 + tarpc/examples/custom_transport.rs | 2 + tarpc/examples/pubsub.rs | 2 + tarpc/examples/readme.rs | 1 + tarpc/examples/tls_over_tcp.rs | 1 + tarpc/examples/tracing.rs | 2 + tarpc/src/client.rs | 37 +-- tarpc/src/client/in_flight_requests.rs | 12 +- tarpc/src/client/stub.rs | 14 +- tarpc/src/client/stub/load_balance.rs | 10 +- tarpc/src/client/stub/mock.rs | 13 +- tarpc/src/client/stub/retry.rs | 4 +- tarpc/src/context.rs | 43 +++- tarpc/src/lib.rs | 8 +- tarpc/src/server.rs | 219 ++++++++++-------- tarpc/src/server/incoming.rs | 2 +- tarpc/src/server/limits/channels_per_key.rs | 1 + .../src/server/limits/requests_per_channel.rs | 14 +- tarpc/src/server/request_hook.rs | 18 +- tarpc/src/server/request_hook/after.rs | 17 +- tarpc/src/server/request_hook/before.rs | 82 ++++--- .../server/request_hook/before_and_after.rs | 19 +- tarpc/src/server/testing.rs | 24 +- tarpc/tests/dataservice.rs | 1 + tarpc/tests/service_functional.rs | 12 +- 29 files changed, 347 insertions(+), 237 deletions(-) diff --git a/example-service/src/server.rs b/example-service/src/server.rs index fe61904b9..5e176dfa2 100644 --- a/example-service/src/server.rs +++ b/example-service/src/server.rs @@ -38,6 +38,7 @@ struct Flags { struct HelloServer(SocketAddr); impl World for HelloServer { + type Context = ServerContext; async fn hello(self, _: &mut context::ServerContext, name: String) -> String { let sleep_time = Duration::from_millis(Uniform::new_inclusive(1, 10).sample(&mut thread_rng())); diff --git a/plugins/src/lib.rs b/plugins/src/lib.rs index 71d7d3c80..432b2f1c8 100644 --- a/plugins/src/lib.rs +++ b/plugins/src/lib.rs @@ -402,7 +402,8 @@ fn collect_cfg_attrs(rpcs: &[RpcMethod]) -> Vec> { /// #[derive(Clone)] /// struct CalculatorServer; /// impl Calculator for CalculatorServer { -/// async fn add(self, context: &mut ServerContext, a: i32, b: i32) -> i32 { +/// type Context = ServerContext; +/// async fn add(self, context: &mut Self::Context, a: i32, b: i32) -> i32 { /// a + b /// } /// } @@ -559,7 +560,7 @@ impl ServiceGenerator<'_> { )| { quote! { #( #attrs )* - async fn #ident(self, context: &mut ::tarpc::context::ServerContext, #( #args ),*) -> #output; + async fn #ident(self, context: &mut Self::Context, #( #args ),*) -> #output; } }, ); @@ -568,6 +569,8 @@ impl ServiceGenerator<'_> { quote! { #( #attrs )* #vis trait #service_ident: ::core::marker::Sized { + type Context: ::tarpc::context::ExtractContext<::tarpc::context::SharedContext>; + #( #rpc_fns )* /// Returns a serving function to use with @@ -578,11 +581,11 @@ impl ServiceGenerator<'_> { } #[doc = #stub_doc] - #vis trait #client_stub_ident: ::tarpc::client::stub::Stub { + #vis trait #client_stub_ident: ::tarpc::client::stub::Stub { } impl #client_stub_ident for S - where S: ::tarpc::client::stub::Stub + where S: ::tarpc::client::stub::Stub { } } @@ -621,9 +624,9 @@ impl ServiceGenerator<'_> { { type Req = #request_ident; type Resp = #response_ident; + type ServerCtx = S::Context; - - async fn serve(self, ctx: &mut ::tarpc::context::ServerContext, req: #request_ident) + async fn serve(self, ctx: &mut Self::ServerCtx, req: #request_ident) -> ::core::result::Result<#response_ident, ::tarpc::ServerError> { match req { #( @@ -787,7 +790,7 @@ impl ServiceGenerator<'_> { #( #[allow(unused)] #( #method_attrs )* - #vis fn #method_idents<'a>(&'a self, ctx: &'a mut ::tarpc::context::ClientContext, #( #args ),*) + #vis fn #method_idents<'a>(&'a self, ctx: &'a mut Stub::ServerCtx, #( #args ),*) -> impl ::core::future::Future> + '_ { let request = #request_ident::#camel_case_idents { #( #arg_pats ),* }; let resp = self.0.call(ctx, request); diff --git a/plugins/tests/service.rs b/plugins/tests/service.rs index 756766621..ef49b9666 100644 --- a/plugins/tests/service.rs +++ b/plugins/tests/service.rs @@ -1,6 +1,7 @@ use serde::{Deserialize, Serialize}; use std::hash::Hash; use tarpc::context; +use tarpc::context::ServerContext; #[test] fn att_service_trait() { @@ -12,6 +13,7 @@ fn att_service_trait() { } impl Foo for () { + type Context = ServerContext; async fn two_part( self, _: &mut context::ServerContext, @@ -42,6 +44,7 @@ fn raw_idents() { } impl r#trait for () { + type Context = ServerContext; async fn r#await( self, _: &mut context::ServerContext, @@ -69,6 +72,7 @@ fn service_with_cfg_rpc() { } impl Foo for () { + type Context = ServerContext; async fn foo(self, _: &mut context::ServerContext) {} } } diff --git a/tarpc/Cargo.toml b/tarpc/Cargo.toml index 778eb0938..0a5efc137 100644 --- a/tarpc/Cargo.toml +++ b/tarpc/Cargo.toml @@ -61,6 +61,8 @@ tracing = { version = "0.1", default-features = false, features = [ tracing-opentelemetry = { version = "0.31.0", default-features = false } opentelemetry = { version = "0.30.0", default-features = false } opentelemetry-semantic-conventions = "0.30.0" +anymap3 = "1.0.1" +serde-value = "0.7" [dev-dependencies] assert_matches = "1.4" diff --git a/tarpc/examples/compression.rs b/tarpc/examples/compression.rs index 0801ce9f4..c00ffc9f3 100644 --- a/tarpc/examples/compression.rs +++ b/tarpc/examples/compression.rs @@ -16,6 +16,7 @@ use tarpc::{ server::{BaseChannel, Channel}, tokio_serde::formats::Bincode, }; +use tarpc::context::ServerContext; /// Type of compression that should be enabled on the request. The transport is free to ignore this. #[derive(Debug, PartialEq, Eq, Clone, Copy, Deserialize, Serialize)] @@ -109,6 +110,7 @@ pub trait World { struct HelloServer; impl World for HelloServer { + type Context = ServerContext; async fn hello(self, _: &mut context::ServerContext, name: String) -> String { format!("Hey, {name}!") } diff --git a/tarpc/examples/custom_transport.rs b/tarpc/examples/custom_transport.rs index 7c23a1fa7..415cb5442 100644 --- a/tarpc/examples/custom_transport.rs +++ b/tarpc/examples/custom_transport.rs @@ -4,6 +4,7 @@ // license that can be found in the LICENSE file or at // https://opensource.org/licenses/MIT. +use console_subscriber::Server; use futures::prelude::*; use tarpc::context::{ClientContext, ServerContext, SharedContext}; use tarpc::serde_transport as transport; @@ -22,6 +23,7 @@ pub trait PingService { struct Service; impl PingService for Service { + type Context = ServerContext; async fn ping(self, _: &mut ServerContext) {} } #[tokio::main] diff --git a/tarpc/examples/pubsub.rs b/tarpc/examples/pubsub.rs index 8094c490d..6755e49ca 100644 --- a/tarpc/examples/pubsub.rs +++ b/tarpc/examples/pubsub.rs @@ -82,6 +82,7 @@ struct Subscriber { } impl subscriber::Subscriber for Subscriber { + type Context = ServerContext; async fn topics(self, _: &mut context::ServerContext) -> Vec { self.topics.clone() } @@ -271,6 +272,7 @@ impl Publisher { } impl publisher::Publisher for Publisher { + type Context = ServerContext; async fn publish(self, _: &mut context::ServerContext, topic: String, message: String) { info!("received message to publish."); let mut subscribers = match self.subscriptions.read().unwrap().get(&topic) { diff --git a/tarpc/examples/readme.rs b/tarpc/examples/readme.rs index b20d4ab91..c7e8de00b 100644 --- a/tarpc/examples/readme.rs +++ b/tarpc/examples/readme.rs @@ -25,6 +25,7 @@ pub trait World { struct HelloServer; impl World for HelloServer { + type Context = ServerContext; async fn hello(self, _: &mut context::ServerContext, name: String) -> String { format!("Hello, {name}!") } diff --git a/tarpc/examples/tls_over_tcp.rs b/tarpc/examples/tls_over_tcp.rs index 2d90650a5..4ed3298bb 100644 --- a/tarpc/examples/tls_over_tcp.rs +++ b/tarpc/examples/tls_over_tcp.rs @@ -33,6 +33,7 @@ pub trait PingService { struct Service; impl PingService for Service { + type Context = ServerContext; async fn ping(self, _: &mut ServerContext) -> String { "🔒".to_owned() } diff --git a/tarpc/examples/tracing.rs b/tarpc/examples/tracing.rs index 0930aae1d..f747c9d75 100644 --- a/tarpc/examples/tracing.rs +++ b/tarpc/examples/tracing.rs @@ -58,6 +58,7 @@ pub mod double { struct AddServer; impl AddService for AddServer { + type Context = ServerContext; async fn add(self, _: &mut context::ServerContext, x: i32, y: i32) -> i32 { x + y } @@ -72,6 +73,7 @@ impl DoubleService for DoubleServer where Stub: AddStub + Clone + Send + Sync + 'static, { + type Context = ServerContext; async fn double(self, _: &mut context::ServerContext, x: i32) -> Result { self.add_client .add(&mut context::ClientContext::current(), x, x) diff --git a/tarpc/src/client.rs b/tarpc/src/client.rs index a42c94491..b6763a9b2 100644 --- a/tarpc/src/client.rs +++ b/tarpc/src/client.rs @@ -9,7 +9,7 @@ mod in_flight_requests; pub mod stub; -use crate::context::ClientContext; +use crate::context::{ClientContext, ExtractContext, SharedContext}; use crate::{ ChannelError, ClientMessage, Request, RequestName, Response, ServerError, Transport, cancellations::{CanceledRequests, RequestCancellation, cancellations}, @@ -125,23 +125,24 @@ where skip(self, ctx, request), fields( rpc.trace_id = tracing::field::Empty, - rpc.deadline = %humantime::format_rfc3339(SystemTime::now() + ctx.deadline.time_until()), + rpc.deadline = %humantime::format_rfc3339(SystemTime::now() + ctx.extract().deadline.time_until()), otel.kind = "client", otel.name = %request.name()) )] - pub async fn call( + pub async fn call>( &self, - ctx: &mut context::ClientContext, + ctx: &mut Ctx, request: Req, ) -> Result { let span = Span::current(); - ctx.trace_context = trace::Context::try_from(&span).unwrap_or_else(|_| { + let mut shared_context = ctx.extract(); + shared_context.trace_context = trace::Context::try_from(&span).unwrap_or_else(|_| { tracing::trace!( "OpenTelemetry subscriber not installed; making unsampled child context." ); - ctx.trace_context.new_child() + shared_context.trace_context.new_child() }); - span.record("rpc.trace_id", tracing::field::display(ctx.trace_id())); + span.record("rpc.trace_id", tracing::field::display(shared_context.trace_id())); let (response_completion, mut response) = oneshot::channel(); let request_id = u64::try_from(self.next_request_id.fetch_add(1, Ordering::Relaxed)).unwrap(); @@ -158,7 +159,7 @@ where }; self.to_dispatch .send(DispatchRequest { - ctx: ctx.shared_context.clone(), + ctx: shared_context, span, request_id, request, @@ -169,7 +170,7 @@ where let (response_ctx, r) = response_guard.response().await?; - ctx.shared_context = response_ctx.shared_context; + ctx.update(response_ctx); Ok(r) } @@ -178,7 +179,7 @@ where /// A server response that is completed by request dispatch when the corresponding response /// arrives off the wire. struct ResponseGuard<'a, Resp> { - response: &'a mut oneshot::Receiver>, + response: &'a mut oneshot::Receiver>, cancellation: &'a RequestCancellation, request_id: u64, cancel: bool, @@ -206,7 +207,7 @@ pub enum RpcError { } impl ResponseGuard<'_, Resp> { - async fn response(mut self) -> Result<(ClientContext, Resp), RpcError> { + async fn response(mut self) -> Result<(SharedContext, Resp), RpcError> { let response = (&mut self.response).await; // Cancel drop logic once a response has been received. self.cancel = false; @@ -583,7 +584,7 @@ where fn complete(mut self: Pin<&mut Self>, response: Response) -> bool { if let Some(span) = self.in_flight_requests().complete_request( response.request_id, - response.message.map_err(RpcError::Server).map(|m| (response.context, m)), + response.message.map_err(RpcError::Server).map(|m| (response.context.shared_context, m)), ) { let _entered = span.enter(); tracing::debug!("ReceiveResponse"); @@ -695,7 +696,7 @@ struct DispatchRequest { pub span: Span, pub request_id: u64, pub request: Req, - pub response_completion: oneshot::Sender>, + pub response_completion: oneshot::Sender>, } #[cfg(test)] @@ -777,7 +778,7 @@ mod tests { async fn dispatch_response_doesnt_cancel_after_complete() { let (cancellation, mut canceled_requests) = cancellations(); let (tx, mut response) = oneshot::channel(); - tx.send(Ok((ClientContext::current(), "well done"))).unwrap(); + tx.send(Ok((SharedContext::current(), "well done"))).unwrap(); // resp's drop() is run, but should not send a cancel message. ResponseGuard { response: &mut response, @@ -1117,8 +1118,8 @@ mod tests { async fn send_request<'a>( channel: &'a mut Channel, request: &str, - response_completion: oneshot::Sender>, - response: &'a mut oneshot::Receiver>, + response_completion: oneshot::Sender>, + response: &'a mut oneshot::Receiver>, ) -> ResponseGuard<'a, String> { let request_id = u64::try_from(channel.next_request_id.fetch_add(1, Ordering::Relaxed)).unwrap(); @@ -1141,8 +1142,8 @@ mod tests { async fn reserve_for_send<'a>( channel: &'a mut Channel, - response_completion: oneshot::Sender>, - response: &'a mut oneshot::Receiver>, + response_completion: oneshot::Sender>, + response: &'a mut oneshot::Receiver>, ) -> impl FnOnce(&str) -> ResponseGuard<'a, String> { let permit = channel.to_dispatch.reserve().await.unwrap(); |request| { diff --git a/tarpc/src/client/in_flight_requests.rs b/tarpc/src/client/in_flight_requests.rs index 90f60c527..5b648098b 100644 --- a/tarpc/src/client/in_flight_requests.rs +++ b/tarpc/src/client/in_flight_requests.rs @@ -9,7 +9,7 @@ use tokio::sync::oneshot; use tokio_util::time::delay_queue::{self, DelayQueue}; use tracing::Span; use crate::client::RpcError; -use crate::context::ClientContext; +use crate::context::{SharedContext}; /// Requests already written to the wire that haven't yet received responses. #[derive(Debug)] @@ -31,7 +31,7 @@ impl Default for InFlightRequests { struct RequestData { ctx: trace::Context, span: Span, - response_completion: oneshot::Sender>, + response_completion: oneshot::Sender>, /// The key to remove the timer for the request's deadline. deadline_key: delay_queue::Key, } @@ -59,7 +59,7 @@ impl InFlightRequests { ctx: trace::Context, deadline: Instant, span: Span, - response_completion: oneshot::Sender>, + response_completion: oneshot::Sender>, ) -> Result<(), AlreadyExistsError> { match self.request_data.entry(request_id) { hash_map::Entry::Vacant(vacant) => { @@ -78,7 +78,7 @@ impl InFlightRequests { } /// Removes a request without aborting. Returns true if the request was found. - pub fn complete_request(&mut self, request_id: u64, result: Result<(ClientContext, Res), RpcError>) -> Option { + pub fn complete_request(&mut self, request_id: u64, result: Result<(SharedContext, Res), RpcError>) -> Option { if let Some(request_data) = self.request_data.remove(&request_id) { self.request_data.compact(0.1); self.deadlines.remove(&request_data.deadline_key); @@ -96,7 +96,7 @@ impl InFlightRequests { /// Returns Spans for all completes requests. pub fn complete_all_requests<'a>( &'a mut self, - mut result: impl FnMut() -> Result<(ClientContext, Res), RpcError> + 'a, + mut result: impl FnMut() -> Result<(SharedContext, Res), RpcError> + 'a, ) -> impl Iterator + 'a { self.deadlines.clear(); self.request_data.drain().map(move |(_, request_data)| { @@ -122,7 +122,7 @@ impl InFlightRequests { pub fn poll_expired( &mut self, cx: &mut Context, - expired_error: impl Fn() -> Result<(ClientContext, Res), RpcError>, + expired_error: impl Fn() -> Result<(SharedContext, Res), RpcError>, ) -> Poll> { self.deadlines.poll_expired(cx).map(|expired| { let request_id = expired?.into_inner(); diff --git a/tarpc/src/client/stub.rs b/tarpc/src/client/stub.rs index b99f8e42c..6fa159dd7 100644 --- a/tarpc/src/client/stub.rs +++ b/tarpc/src/client/stub.rs @@ -6,6 +6,7 @@ use crate::{ context, server::Serve, }; +use crate::context::{ClientContext, ServerContext}; pub mod load_balance; pub mod retry; @@ -23,10 +24,13 @@ pub trait Stub { /// The service response type. type Resp; + ///TODO: document + type ServerCtx; + /// Calls a remote service. async fn call( &self, - ctx: &mut context::ClientContext, + ctx: &mut Self::ServerCtx, request: Self::Req, ) -> Result; } @@ -37,10 +41,11 @@ where { type Req = Req; type Resp = Resp; + type ServerCtx = ClientContext; async fn call( &self, - ctx: &mut context::ClientContext, + ctx: &mut Self::ServerCtx, request: Req, ) -> Result { Self::call(self, ctx, request).await @@ -49,13 +54,14 @@ where impl Stub for S where - S: Serve + Clone, + S: Serve + Clone, { type Req = S::Req; type Resp = S::Resp; + type ServerCtx = ClientContext; async fn call( &self, - ctx: &mut context::ClientContext, + ctx: &mut ClientContext, req: Self::Req, ) -> Result { let mut server_ctx = context::ServerContext::new(ctx.shared_context.clone()); diff --git a/tarpc/src/client/stub/load_balance.rs b/tarpc/src/client/stub/load_balance.rs index 62c8bf677..5b319c6c8 100644 --- a/tarpc/src/client/stub/load_balance.rs +++ b/tarpc/src/client/stub/load_balance.rs @@ -7,7 +7,6 @@ pub use round_robin::RoundRobin; mod round_robin { use crate::{ client::{RpcError, stub}, - context, }; use cycle::AtomicCycle; @@ -17,10 +16,11 @@ mod round_robin { { type Req = Stub::Req; type Resp = Stub::Resp; + type ServerCtx = Stub::ServerCtx; async fn call( &self, - ctx: &mut context::ClientContext, + ctx: &mut Self::ServerCtx, request: Self::Req, ) -> Result { let next = self.stubs.next(); @@ -99,8 +99,7 @@ mod round_robin { /// the same stub. mod consistent_hash { use crate::{ - client::{RpcError, stub}, - context, + client::{RpcError, stub} }; use std::{ collections::hash_map::RandomState, @@ -116,10 +115,11 @@ mod consistent_hash { { type Req = Stub::Req; type Resp = Stub::Resp; + type ServerCtx = Stub::ServerCtx; async fn call( &self, - ctx: &mut context::ClientContext, + ctx: &mut Self::ServerCtx, request: Self::Req, ) -> Result { let index = usize::try_from(self.hasher.hash_one(&request) % self.stubs_len).expect( diff --git a/tarpc/src/client/stub/mock.rs b/tarpc/src/client/stub/mock.rs index bebd8fc99..9a22d101e 100644 --- a/tarpc/src/client/stub/mock.rs +++ b/tarpc/src/client/stub/mock.rs @@ -1,16 +1,17 @@ use crate::{ RequestName, ServerError, client::{RpcError, stub::Stub}, - context, }; use std::{collections::HashMap, hash::Hash, io}; +use std::marker::PhantomData; /// A mock stub that returns user-specified responses. -pub struct Mock { +pub struct Mock { responses: HashMap, + ghost: PhantomData } -impl Mock +impl Mock where Req: Eq + Hash, { @@ -18,21 +19,23 @@ where pub fn new(responses: [(Req, Resp); N]) -> Self { Self { responses: HashMap::from(responses), + ghost: PhantomData } } } -impl Stub for Mock +impl Stub for Mock where Req: Eq + Hash + RequestName, Resp: Clone, { type Req = Req; type Resp = Resp; + type ServerCtx = ServerCtx; async fn call( &self, - _: &mut context::ClientContext, + _: &mut Self::ServerCtx, request: Self::Req, ) -> Result { self.responses diff --git a/tarpc/src/client/stub/retry.rs b/tarpc/src/client/stub/retry.rs index d93daa156..2cf950aed 100644 --- a/tarpc/src/client/stub/retry.rs +++ b/tarpc/src/client/stub/retry.rs @@ -3,7 +3,6 @@ use crate::{ RequestName, client::{RpcError, stub}, - context, }; use std::sync::Arc; @@ -15,10 +14,11 @@ where { type Req = Req; type Resp = Stub::Resp; + type ServerCtx = Stub::ServerCtx; async fn call( &self, - ctx: &mut context::ClientContext, + ctx: &mut Self::ServerCtx, request: Self::Req, ) -> Result { let request = Arc::new(request); diff --git a/tarpc/src/context.rs b/tarpc/src/context.rs index bbbc3721d..798044c93 100644 --- a/tarpc/src/context.rs +++ b/tarpc/src/context.rs @@ -23,7 +23,6 @@ use tracing_opentelemetry::OpenTelemetrySpanExt; /// The context should not be stored directly in a server implementation, because the context will /// be different for each request in scope. #[derive(Debug, Clone)] -#[non_exhaustive] #[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))] pub struct SharedContext { /// When the client expects the request to be complete by. The server should cancel the request @@ -36,7 +35,25 @@ pub struct SharedContext { /// When a service handles a request by making requests itself, those requests should /// include the same `trace_id` as that included on the original request. This way, /// users can trace related actions across a distributed system. - pub trace_context: trace::Context, + pub trace_context: trace::Context +} + +///TODO +pub trait ExtractContext { + ///TODO + fn extract(&self) -> Ctx; + ///TODO + fn update(&mut self, value: Ctx); +} + +impl ExtractContext for T where T: Clone { + fn extract(&self) -> T { + self.clone() + } + + fn update(&mut self, value: T) { + *self = value + } } /// Request context that carries request-scoped server side information like deadlines and trace information @@ -100,6 +117,28 @@ impl ClientContext { } } +impl ExtractContext for ClientContext { + fn extract(&self) -> SharedContext { + self.shared_context.clone() + } + + fn update(&mut self, value: SharedContext) { + self.shared_context = value + } +} + +impl ExtractContext for ServerContext { + fn extract(&self) -> SharedContext { + self.shared_context.clone() + } + + fn update(&mut self, value: SharedContext) { + self.shared_context = value + } +} + + + impl Deref for ClientContext { type Target = SharedContext; diff --git a/tarpc/src/lib.rs b/tarpc/src/lib.rs index 565fe9f89..0578a392f 100644 --- a/tarpc/src/lib.rs +++ b/tarpc/src/lib.rs @@ -124,8 +124,9 @@ //! struct HelloServer; //! //! impl World for HelloServer { +//! type Context = context::ServerContext; //! // Each defined rpc generates an async fn that serves the RPC -//! async fn hello(self, _: &mut context::ServerContext, name: String) -> String { +//! async fn hello(self, _: &mut Self::Context, name: String) -> String { //! format!("Hello, {name}!") //! } //! } @@ -160,8 +161,9 @@ //! # #[derive(Clone)] //! # struct HelloServer; //! # impl World for HelloServer { -//! // Each defined rpc generates an async fn that serves the RPC -//! # async fn hello(self, _: &mut context::ServerContext, name: String) -> String { +//! # type Context = ServerContext; +//! # // Each defined rpc generates an async fn that serves the RPC +//! # async fn hello(self, _: &mut Self::Context, name: String) -> String { //! # format!("Hello, {name}!") //! # } //! # } diff --git a/tarpc/src/server.rs b/tarpc/src/server.rs index 649f21022..fe7440f7e 100644 --- a/tarpc/src/server.rs +++ b/tarpc/src/server.rs @@ -6,11 +6,10 @@ //! Provides a server that concurrently handles many connections sending multiplexed requests. -use crate::context::ServerContext; use crate::{ ChannelError, ClientMessage, Request, RequestName, Response, ServerError, Transport, cancellations::{CanceledRequests, RequestCancellation, cancellations}, - context::{self, SpanExt}, + context::{SpanExt}, trace, util::TimeUntil, }; @@ -28,6 +27,7 @@ use std::{ convert::TryFrom, error::Error, fmt, marker::PhantomData, pin::Pin, sync::Arc, time::SystemTime, }; use tracing::{Span, info_span, instrument::Instrument}; +use crate::context::{ExtractContext, SharedContext}; mod in_flight_requests; pub mod request_hook; @@ -59,9 +59,10 @@ impl Default for Config { impl Config { /// Returns a channel backed by `transport` and configured with `self`. - pub fn channel(self, transport: T) -> BaseChannel + pub fn channel(self, transport: T) -> BaseChannel where - T: Transport, ClientMessage>, + T: Transport, ClientMessage>, + ServerCtx: ExtractContext, { BaseChannel::new(self, transport) } @@ -70,6 +71,9 @@ impl Config { /// Equivalent to a `FnOnce(Req) -> impl Future`. #[allow(async_fn_in_trait)] pub trait Serve { + ///TODO document + type ServerCtx; + /// Type of request. type Req: RequestName; @@ -79,19 +83,19 @@ pub trait Serve { /// Responds to a single request. async fn serve( self, - ctx: &mut context::ServerContext, + ctx: &mut Self::ServerCtx, req: Self::Req, ) -> Result; } /// A Serve wrapper around a Fn. #[derive(Debug)] -pub struct ServeFn { +pub struct ServeFn { f: F, - data: PhantomData Resp>, + data: PhantomData<(Req, Resp, ServerCtx)>, } -impl Clone for ServeFn +impl Clone for ServeFn where F: Clone, { @@ -103,16 +107,13 @@ where } } -impl Copy for ServeFn where F: Copy {} +impl Copy for ServeFn where F: Copy {} /// Creates a [`Serve`] wrapper around a `FnOnce(context::Context, Req) -> impl Future>`. -pub fn serve(f: F) -> ServeFn +pub fn serve(f: F) -> ServeFn where - for<'a> F: FnOnce( - &'a mut context::ServerContext, - Req, - ) -> Pin> + 'a + Send>>, + for<'a> F: FnOnce(&'a mut ServerCtx, Req) -> Pin> + 'a + Send>>, { ServeFn { f, @@ -120,18 +121,19 @@ where } } -impl Serve for ServeFn +impl Serve for ServeFn where Req: RequestName, for<'a> F: FnOnce( - &'a mut context::ServerContext, + &'a mut ServerCtx, Req, ) -> Pin> + 'a + Send>>, { + type ServerCtx = ServerCtx; type Req = Req; type Resp = Resp; - async fn serve(self, ctx: &mut context::ServerContext, req: Req) -> Result { + async fn serve(self, ctx: &mut ServerCtx, req: Req) -> Result { (self.f)(ctx, req).await } } @@ -147,7 +149,7 @@ where /// messages. Instead, it internally handles them by cancelling corresponding requests (removing /// the corresponding in-flight requests and aborting their handlers). #[pin_project] -pub struct BaseChannel { +pub struct BaseChannel { config: Config, /// Writes responses to the wire and reads requests off the wire. #[pin] @@ -160,12 +162,13 @@ pub struct BaseChannel { /// Holds data necessary to clean up in-flight requests. in_flight_requests: InFlightRequests, /// Types the request and response. - ghost: PhantomData<(fn() -> Req, fn(Resp))>, + ghost: PhantomData<(Req, Resp, ServeCtx)>, } -impl BaseChannel +impl BaseChannel where - T: Transport, ClientMessage>, + T: Transport, ClientMessage>, + ServerCtx: ExtractContext { /// Creates a new channel backed by `transport` and configured with `config`. pub fn new(config: Config, transport: T) -> Self { @@ -211,28 +214,29 @@ where fn start_request( mut self: Pin<&mut Self>, - mut request: Request, - ) -> Result, AlreadyExistsError> { + request: Request, + ) -> Result, AlreadyExistsError> { + let mut shared_context = request.context.extract(); let span = info_span!( "RPC", - rpc.trace_id = %request.context.trace_id(), - rpc.deadline = %humantime::format_rfc3339(SystemTime::now() + request.context.deadline.time_until()), + rpc.trace_id = %shared_context.trace_id(), + rpc.deadline = %humantime::format_rfc3339(SystemTime::now() + shared_context.deadline.time_until()), otel.kind = "server", otel.name = tracing::field::Empty, ); - span.set_context(&request.context); - request.context.trace_context = trace::Context::try_from(&span).unwrap_or_else(|_| { + span.set_context(&shared_context); + shared_context.trace_context = trace::Context::try_from(&span).unwrap_or_else(|_| { tracing::trace!( "OpenTelemetry subscriber not installed; making unsampled \ child context." ); - request.context.trace_context.new_child() + shared_context.trace_context.new_child() }); let entered = span.enter(); tracing::debug!("ReceiveRequest"); let start = self.in_flight_requests_mut().start_request( request.id, - request.context.deadline, + shared_context.deadline, span.clone(), ); match start { @@ -257,7 +261,7 @@ where } } -impl fmt::Debug for BaseChannel { +impl fmt::Debug for BaseChannel { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { write!(f, "BaseChannel") } @@ -265,9 +269,9 @@ impl fmt::Debug for BaseChannel { /// A request tracked by a [`Channel`]. #[derive(Debug)] -pub struct TrackedRequest { +pub struct TrackedRequest { /// The request sent by the client. - pub request: Request, + pub request: Request, /// A registration to abort a future when the [`Channel`] that produced this request stops /// tracking it. pub abort_registration: AbortRegistration, @@ -305,8 +309,8 @@ pub struct TrackedRequest { pub trait Channel where Self: Transport< - Response::Resp>, - TrackedRequest<::Req>, + Response::Resp>, + TrackedRequest::Req>, >, { /// Type of request item. @@ -317,6 +321,8 @@ where /// The wrapped transport. type Transport; + ///TODO document + type ServerCtx; /// Configuration of the channel. fn config(&self) -> &Config; @@ -381,7 +387,7 @@ where /// assert_eq!(client.call(&mut context, 1).await.unwrap(), 2); /// } /// ``` - fn requests(self) -> Requests + fn requests(self) -> Requests where Self: Sized, { @@ -428,17 +434,18 @@ where where Self: Sized, Self::Req: RequestName, - S: Serve + Clone, + S: Serve + Clone, { self.requests().execute(serve) } } -impl Stream for BaseChannel +impl Stream for BaseChannel where - T: Transport, ClientMessage>, + T: Transport, ClientMessage>, + ServerCtx: ExtractContext { - type Item = Result, ChannelError>; + type Item = Result, ChannelError>; fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll> { #[derive(Clone, Copy, Debug)] @@ -541,10 +548,11 @@ where } } -impl Sink> for BaseChannel +impl Sink> for BaseChannel where - T: Transport, ClientMessage>, + T: Transport, ClientMessage>, T::Error: Error, + ServerCtx: ExtractContext { type Error = ChannelError; @@ -557,7 +565,7 @@ where fn start_send( mut self: Pin<&mut Self>, - response: Response, + response: Response, ) -> Result<(), Self::Error> { if let Some(span) = self .in_flight_requests_mut() @@ -591,19 +599,22 @@ where } } -impl AsRef for BaseChannel { +impl AsRef for BaseChannel { fn as_ref(&self) -> &T { self.transport.get_ref() } } -impl Channel for BaseChannel +impl Channel for BaseChannel where - T: Transport, ClientMessage>, + T: Transport, ClientMessage>, + ServerCtx: ExtractContext, { + type Req = Req; type Resp = Resp; type Transport = T; + type ServerCtx = ServerCtx; fn config(&self) -> &Config { &self.config @@ -621,19 +632,19 @@ where /// A stream of requests coming over a channel. `Requests` also drives the sending of responses, so /// it must be continually polled to ensure progress. #[pin_project] -pub struct Requests +pub struct Requests where C: Channel, { #[pin] channel: C, /// Responses waiting to be written to the wire. - pending_responses: mpsc::Receiver>, + pending_responses: mpsc::Receiver>, /// Handed out to request handlers to fan in responses. - responses_tx: mpsc::Sender>, + responses_tx: mpsc::Sender>, } -impl Requests +impl Requests where C: Channel, { @@ -650,14 +661,14 @@ where /// Returns the inner channel over which messages are sent and received. pub fn pending_responses_mut<'a>( self: &'a mut Pin<&mut Self>, - ) -> &'a mut mpsc::Receiver> { + ) -> &'a mut mpsc::Receiver> { self.as_mut().project().pending_responses } fn pump_read( mut self: Pin<&mut Self>, cx: &mut Context<'_>, - ) -> Poll, C::Error>>> { + ) -> Poll, C::Error>>> { self.channel_pin_mut().poll_next(cx).map_ok( |TrackedRequest { request, @@ -722,7 +733,7 @@ where fn poll_next_response( mut self: Pin<&mut Self>, cx: &mut Context<'_>, - ) -> Poll, C::Error>>> { + ) -> Poll, C::Error>>> { ready!(self.ensure_writeable(cx)?); match ready!(self.pending_responses_mut().poll_recv(cx)) { @@ -779,7 +790,7 @@ where pub fn execute(self, serve: S) -> impl Stream> where C::Req: RequestName, - S: Serve + Clone, + S: Serve + Clone, { self.take_while(|result| { if let Err(e) = result { @@ -795,7 +806,7 @@ where } } -impl fmt::Debug for Requests +impl fmt::Debug for Requests where C: Channel, { @@ -826,17 +837,17 @@ impl Drop for ResponseGuard { /// If dropped without calling [`execute`](InFlightRequest::execute), a cancellation message will /// be sent to the Channel to clean up associated request state. #[derive(Debug)] -pub struct InFlightRequest { - request: Request, +pub struct InFlightRequest { + request: Request, abort_registration: AbortRegistration, response_guard: ResponseGuard, span: Span, - response_tx: mpsc::Sender>, + response_tx: mpsc::Sender>, } -impl InFlightRequest { +impl InFlightRequest { /// Returns a reference to the request. - pub fn get(&self) -> &Request { + pub fn get(&self) -> &Request { &self.request } @@ -889,7 +900,7 @@ impl InFlightRequest { pub async fn execute(self, serve: S) where Req: RequestName, - S: Serve, + S: Serve, { let Self { response_tx, @@ -934,11 +945,11 @@ fn print_err(e: &(dyn Error + 'static)) -> String { .join(": ") } -impl Stream for Requests +impl Stream for Requests where C: Channel, { - type Item = Result, C::Error>; + type Item = Result, C::Error>; fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { loop { @@ -984,7 +995,6 @@ mod tests { request_hook::{AfterRequest, BeforeRequest, RequestHook}, serve, }; - use crate::context::ServerContext; use crate::{ ClientMessage, Request, Response, ServerError, context, trace, transport::channel::{self, UnboundedChannel}, @@ -1002,6 +1012,7 @@ mod tests { task::Poll, time::{Duration, Instant}, }; + use crate::context::{ExtractContext, SharedContext}; fn test_channel() -> ( Pin< @@ -1010,13 +1021,14 @@ mod tests { Req, Resp, UnboundedChannel< - ClientMessage, - Response, + ClientMessage, + Response, >, + SharedContext >, >, >, - UnboundedChannel, ClientMessage>, + UnboundedChannel, ClientMessage>, ) { let (tx, rx) = crate::transport::channel::unbounded(); (Box::pin(BaseChannel::new(Config::default(), rx)), tx) @@ -1026,19 +1038,20 @@ mod tests { Pin< Box< Requests< - ServerContext, BaseChannel< Req, Resp, UnboundedChannel< - ClientMessage, - Response, + ClientMessage, + Response, >, + SharedContext >, + >, >, >, - UnboundedChannel, ClientMessage>, + UnboundedChannel, ClientMessage>, ) { let (tx, rx) = crate::transport::channel::unbounded(); ( @@ -1053,19 +1066,19 @@ mod tests { Pin< Box< Requests< - ServerContext, BaseChannel< Req, Resp, channel::Channel< - ClientMessage, - Response, + ClientMessage, + Response, >, + SharedContext >, >, >, >, - channel::Channel, ClientMessage>, + channel::Channel, ClientMessage>, ) { let (tx, rx) = crate::transport::channel::bounded(capacity); // Add 1 because capacity 0 is not supported (but is supported by transport::channel::bounded). @@ -1075,9 +1088,9 @@ mod tests { (Box::pin(BaseChannel::new(config, rx).requests()), tx) } - fn fake_request(req: Req) -> ClientMessage { + fn fake_request(req: Req) -> ClientMessage { ClientMessage::Request(Request { - context: context::ServerContext::current(), + context: context::SharedContext::current(), id: 0, message: req, }) @@ -1101,13 +1114,15 @@ mod tests { #[tokio::test] async fn serve_before_mutates_context() -> anyhow::Result<()> { struct SetDeadline(Instant); - impl BeforeRequest for SetDeadline { + impl BeforeRequest for SetDeadline where ServerCtx: ExtractContext { async fn before( &mut self, - ctx: &mut context::ServerContext, + ctx: &mut ServerCtx, _: &Req, ) -> Result<(), ServerError> { - ctx.deadline = self.0; + let mut inner = ctx.extract(); + inner.deadline = self.0; + ctx.update(inner); Ok(()) } } @@ -1115,7 +1130,7 @@ mod tests { let some_time = Instant::now() + Duration::from_secs(37); let some_other_time = Instant::now() + Duration::from_secs(83); - let serve = serve(move |ctx: &mut context::ServerContext, i| { + let serve = serve(move |ctx: &mut context::SharedContext, i| { async move { assert_eq!(ctx.deadline, some_time); Ok(i) @@ -1123,7 +1138,7 @@ mod tests { .boxed() }); let deadline_hook = serve.before(SetDeadline(some_time)); - let mut ctx = context::ServerContext::current(); + let mut ctx = context::SharedContext::current(); ctx.deadline = some_other_time; deadline_hook.serve(&mut ctx, 7).await?; Ok(()) @@ -1143,20 +1158,20 @@ mod tests { } } } - impl BeforeRequest for PrintLatency { + impl BeforeRequest for PrintLatency { async fn before( &mut self, - _: &mut context::ServerContext, + _: &mut ServerCtx, _: &Req, ) -> Result<(), ServerError> { self.start = Instant::now(); Ok(()) } } - impl AfterRequest for PrintLatency { + impl AfterRequest for PrintLatency { async fn after( &mut self, - _: &mut context::ServerContext, + _: &mut ServerCtx, _: &mut Result, ) { tracing::debug!("Elapsed: {:?}", self.start.elapsed()); @@ -1192,14 +1207,14 @@ mod tests { .as_mut() .start_request(Request { id: 0, - context: context::ServerContext::current(), + context: context::SharedContext::current(), message: (), }) .unwrap(); assert_matches!( channel.as_mut().start_request(Request { id: 0, - context: context::ServerContext::current(), + context: context::SharedContext::current(), message: () }), Err(AlreadyExistsError) @@ -1215,7 +1230,7 @@ mod tests { .as_mut() .start_request(Request { id: 0, - context: context::ServerContext::current(), + context: context::SharedContext::current(), message: (), }) .unwrap(); @@ -1223,7 +1238,7 @@ mod tests { .as_mut() .start_request(Request { id: 1, - context: context::ServerContext::current(), + context: context::SharedContext::current(), message: (), }) .unwrap(); @@ -1246,7 +1261,7 @@ mod tests { .as_mut() .start_request(Request { id: 0, - context: context::ServerContext::current(), + context: context::SharedContext::current(), message: (), }) .unwrap(); @@ -1275,7 +1290,7 @@ mod tests { .as_mut() .start_request(Request { id: 0, - context: context::ServerContext::current(), + context: context::SharedContext::current(), message: (), }) .unwrap(); @@ -1317,7 +1332,7 @@ mod tests { .as_mut() .start_request(Request { id: 0, - context: context::ServerContext::current(), + context: context::SharedContext::current(), message: (), }) .unwrap(); @@ -1340,7 +1355,7 @@ mod tests { .as_mut() .start_request(Request { id: 0, - context: ServerContext::current(), + context: SharedContext::current(), message: (), }) .unwrap(); @@ -1349,7 +1364,7 @@ mod tests { .as_mut() .start_send(Response { request_id: 0, - context: ServerContext::current(), + context: SharedContext::current(), message: Ok(()), }) .unwrap(); @@ -1408,7 +1423,7 @@ mod tests { .channel_pin_mut() .start_request(Request { id: 0, - context: context::ServerContext::current(), + context: context::SharedContext::current(), message: (), }) .unwrap(); @@ -1417,7 +1432,7 @@ mod tests { .channel_pin_mut() .start_send(Response { request_id: 0, - context: ServerContext::current(), + context: SharedContext::current(), message: Ok(()), }) .unwrap(); @@ -1429,7 +1444,7 @@ mod tests { .responses_tx .send(Response { request_id: 1, - context: ServerContext::current(), + context: SharedContext::current(), message: Ok(()), }) .await @@ -1440,7 +1455,7 @@ mod tests { .channel_pin_mut() .start_request(Request { id: 1, - context: ServerContext::current(), + context: SharedContext::current(), message: (), }) .unwrap(); @@ -1461,7 +1476,7 @@ mod tests { .channel_pin_mut() .start_request(Request { id: 0, - context: context::ServerContext::current(), + context: context::SharedContext::current(), message: (), }) .unwrap(); @@ -1470,7 +1485,7 @@ mod tests { .channel_pin_mut() .start_send(Response { request_id: 0, - context: ServerContext::current(), + context: SharedContext::current(), message: Ok(()), }) .unwrap(); @@ -1481,7 +1496,7 @@ mod tests { .channel_pin_mut() .start_request(Request { id: 1, - context: ServerContext::current(), + context: SharedContext::current(), message: (), }) .unwrap(); @@ -1491,7 +1506,7 @@ mod tests { .responses_tx .send(Response { request_id: 1, - context: ServerContext::current(), + context: SharedContext::current(), message: Ok(()), }) .await diff --git a/tarpc/src/server/incoming.rs b/tarpc/src/server/incoming.rs index 1868cbe47..56f393b84 100644 --- a/tarpc/src/server/incoming.rs +++ b/tarpc/src/server/incoming.rs @@ -33,7 +33,7 @@ where ) -> impl Stream>> where C::Req: RequestName, - S: Serve + Clone, + S: Serve + Clone, { self.map(move |channel| channel.execute(serve.clone())) } diff --git a/tarpc/src/server/limits/channels_per_key.rs b/tarpc/src/server/limits/channels_per_key.rs index 64b644278..3ffdfac89 100644 --- a/tarpc/src/server/limits/channels_per_key.rs +++ b/tarpc/src/server/limits/channels_per_key.rs @@ -107,6 +107,7 @@ where type Req = C::Req; type Resp = C::Resp; type Transport = C::Transport; + type ServerCtx = C::ServerCtx; fn config(&self) -> &server::Config { self.inner.config() diff --git a/tarpc/src/server/limits/requests_per_channel.rs b/tarpc/src/server/limits/requests_per_channel.rs index 395ded512..383abb9c8 100644 --- a/tarpc/src/server/limits/requests_per_channel.rs +++ b/tarpc/src/server/limits/requests_per_channel.rs @@ -4,7 +4,6 @@ // license that can be found in the LICENSE file or at // https://opensource.org/licenses/MIT. -use crate::context::ServerContext; use crate::{ Response, ServerError, server::{Channel, Config}, @@ -82,7 +81,7 @@ where } } -impl Sink::Resp>> for MaxRequests +impl Sink::Resp>> for MaxRequests where C: Channel, { @@ -94,7 +93,7 @@ where fn start_send( self: Pin<&mut Self>, - item: Response::Resp>, + item: Response::Resp>, ) -> Result<(), Self::Error> { self.project().inner.start_send(item) } @@ -121,6 +120,7 @@ where type Req = ::Req; type Resp = ::Resp; type Transport = ::Transport; + type ServerCtx = ::ServerCtx; fn in_flight_requests(&self) -> usize { self.inner.in_flight_requests() @@ -190,6 +190,7 @@ mod tests { time::{Duration, Instant}, }; use tracing::Span; + use crate::context::{ServerContext, SharedContext}; #[tokio::test] async fn throttler_in_flight_requests() { @@ -270,7 +271,7 @@ mod tests { } impl PendingSink<(), ()> { pub fn default() - -> PendingSink>, Response> + -> PendingSink>, Response> { PendingSink { ghost: PhantomData } } @@ -297,11 +298,12 @@ mod tests { } } impl Channel - for PendingSink>, Response> + for PendingSink>, Response> { type Req = Req; type Resp = Resp; type Transport = (); + type ServerCtx = ServerContext; fn config(&self) -> &Config { unimplemented!() } @@ -331,7 +333,7 @@ mod tests { .as_mut() .start_send(Response { request_id: 0, - context: ServerContext::current(), + context: SharedContext::current(), message: Ok(1), }) .unwrap(); diff --git a/tarpc/src/server/request_hook.rs b/tarpc/src/server/request_hook.rs index 38b0998bf..338059f7d 100644 --- a/tarpc/src/server/request_hook.rs +++ b/tarpc/src/server/request_hook.rs @@ -62,9 +62,9 @@ pub trait RequestHook: Serve { /// let response = serve.serve(&mut context, 1); /// assert!(block_on(response).is_err()); /// ``` - fn before(self, hook: Hook) -> HookThenServe + fn before(self, hook: Hook) -> HookThenServe where - Hook: BeforeRequest, + Hook: BeforeRequest, Self: Sized, { HookThenServe::new(self, hook) @@ -107,7 +107,7 @@ pub trait RequestHook: Serve { /// ``` fn after(self, hook: Hook) -> ServeThenHook where - Hook: AfterRequest, + Hook: AfterRequest, Self: Sized, { ServeThenHook::new(self, hook) @@ -133,17 +133,17 @@ pub trait RequestHook: Serve { /// /// struct PrintLatency(Instant); /// - /// impl BeforeRequest for PrintLatency { - /// async fn before(&mut self, _: &mut context::ServerContext, _: &Req) -> Result<(), ServerError> { + /// impl BeforeRequest for PrintLatency { + /// async fn before(&mut self, _: &mut ServerCtx, _: &Req) -> Result<(), ServerError> { /// self.0 = Instant::now(); /// Ok(()) /// } /// } /// - /// impl AfterRequest for PrintLatency { + /// impl AfterRequest for PrintLatency { /// async fn after( /// &mut self, - /// _: &mut context::ServerContext, + /// _: &mut ServerCtx, /// _: &mut Result, /// ) { /// tracing::info!("Elapsed: {:?}", self.0.elapsed()); @@ -160,9 +160,9 @@ pub trait RequestHook: Serve { fn before_and_after( self, hook: Hook, - ) -> HookThenServeThenHook + ) -> HookThenServeThenHook where - Hook: BeforeRequest + AfterRequest, + Hook: BeforeRequest + AfterRequest, Self: Sized, { HookThenServeThenHook::new(self, hook) diff --git a/tarpc/src/server/request_hook/after.rs b/tarpc/src/server/request_hook/after.rs index 64d65807f..ce6319e25 100644 --- a/tarpc/src/server/request_hook/after.rs +++ b/tarpc/src/server/request_hook/after.rs @@ -6,30 +6,30 @@ //! Provides a hook that runs after request execution. -use crate::{ServerError, context, server::Serve}; +use crate::{ServerError, server::Serve}; use futures::prelude::*; /// A hook that runs after request execution. #[allow(async_fn_in_trait)] -pub trait AfterRequest { +pub trait AfterRequest { /// The function that is called after request execution. /// /// The hook can modify the request context and the response. async fn after( &mut self, - ctx: &mut context::ServerContext, + ctx: &mut ServerCtx, resp: &mut Result, ); } -impl AfterRequest for F +impl AfterRequest for F where - F: FnMut(&mut context::ServerContext, &mut Result) -> Fut, + F: FnMut(&mut ServerCtx, &mut Result) -> Fut, Fut: Future, { async fn after( &mut self, - ctx: &mut context::ServerContext, + ctx: &mut ServerCtx, resp: &mut Result, ) { self(ctx, resp).await @@ -60,14 +60,15 @@ impl Clone for ServeThenHook { impl Serve for ServeThenHook where Serv: Serve, - Hook: AfterRequest, + Hook: AfterRequest, { type Req = Serv::Req; type Resp = Serv::Resp; + type ServerCtx = Serv::ServerCtx; async fn serve( self, - ctx: &mut context::ServerContext, + ctx: &mut Serv::ServerCtx, req: Serv::Req, ) -> Result { let ServeThenHook { diff --git a/tarpc/src/server/request_hook/before.rs b/tarpc/src/server/request_hook/before.rs index 1f647227f..3e2e091c8 100644 --- a/tarpc/src/server/request_hook/before.rs +++ b/tarpc/src/server/request_hook/before.rs @@ -6,12 +6,13 @@ //! Provides a hook that runs before request execution. -use crate::{ServerError, context, server::Serve}; +use std::marker::PhantomData; +use crate::{ServerError, server::Serve}; use futures::prelude::*; /// A hook that runs before request execution. #[allow(async_fn_in_trait)] -pub trait BeforeRequest { +pub trait BeforeRequest { /// The function that is called before request execution. /// /// If this function returns an error, the request will not be executed and the error will be @@ -21,24 +22,24 @@ pub trait BeforeRequest { /// enforce a maximum deadline on all requests. async fn before( &mut self, - ctx: &mut context::ServerContext, + ctx: &mut ServerCtx, req: &Req, ) -> Result<(), ServerError>; } /// A list of hooks that run in order before request execution. -pub trait BeforeRequestList: BeforeRequest { +pub trait BeforeRequestList: BeforeRequest { /// The hook returned by `BeforeRequestList::then`. - type Then: BeforeRequest + type Then: BeforeRequest where - Next: BeforeRequest; + Next: BeforeRequest; /// Returns a hook that, when run, runs two hooks, first `self` and then `next`. - fn then>(self, next: Next) -> Self::Then; + fn then>(self, next: Next) -> Self::Then; /// Same as `then`, but helps the compiler with type inference when Next is a closure. fn then_fn< - Next: FnMut(&mut context::ServerContext, &Req) -> Fut, + Next: FnMut(&mut ServerCtx, &Req) -> Fut, Fut: Future>, >( self, @@ -51,21 +52,21 @@ pub trait BeforeRequestList: BeforeRequest { } /// The service fn returned by `BeforeRequestList::serving`. - type Serve>: Serve; + type Serve>: Serve; /// Runs the list of request hooks before execution of the given serve fn. /// This is equivalent to `serve.before(before_request_chain)` but may be syntactically nicer. - fn serving>(self, serve: S) -> Self::Serve; + fn serving>(self, serve: S) -> Self::Serve; } -impl BeforeRequest for F +impl BeforeRequest for F where - F: FnMut(&mut context::ServerContext, &Req) -> Fut, + F: FnMut(&mut ServerCtx, &Req) -> Fut, Fut: Future>, { async fn before( &mut self, - ctx: &mut context::ServerContext, + ctx: &mut ServerCtx, req: &Req, ) -> Result<(), ServerError> { self(ctx, req).await @@ -73,29 +74,36 @@ where } /// A Service function that runs a hook before request execution. -#[derive(Clone)] -pub struct HookThenServe { +pub struct HookThenServe { serve: Serv, hook: Hook, + ghost: PhantomData } -impl HookThenServe { +impl Clone for HookThenServe { + fn clone(&self) -> Self { + Self::new(self.serve.clone(), self.hook.clone()) + } +} + +impl HookThenServe { pub(crate) fn new(serve: Serv, hook: Hook) -> Self { - Self { serve, hook } + Self { serve, hook, ghost: PhantomData } } } -impl Serve for HookThenServe +impl Serve for HookThenServe where - Serv: Serve, - Hook: BeforeRequest, + Serv: Serve, + Hook: BeforeRequest, { + type ServerCtx = ServerCtx; type Req = Serv::Req; type Resp = Serv::Resp; async fn serve( self, - ctx: &mut context::ServerContext, + ctx: &mut ServerCtx, req: Self::Req, ) -> Result { let HookThenServe { @@ -129,7 +137,7 @@ where /// Ok(()) /// }) /// .serving(serve(|_ctx, i| async move { Ok(i + 1) }.boxed())); -/// let mut context = context::ServerContext::current(); +/// let mut context = context::SharedContext::current(); /// let response = serve.clone().serve(&mut context, 1); /// assert!(block_on(response).is_ok()); /// assert!(i.get() == 2); @@ -146,12 +154,12 @@ pub struct BeforeRequestCons(First, Rest); #[derive(Clone, Copy)] pub struct BeforeRequestNil; -impl, Rest: BeforeRequest> BeforeRequest +impl, Rest: BeforeRequest, ServerCtx> BeforeRequest for BeforeRequestCons { async fn before( &mut self, - ctx: &mut context::ServerContext, + ctx: &mut ServerCtx, req: &Req, ) -> Result<(), ServerError> { let BeforeRequestCons(first, rest) = self; @@ -161,45 +169,45 @@ impl, Rest: BeforeRequest> BeforeRequest BeforeRequest for BeforeRequestNil { - async fn before(&mut self, _: &mut context::ServerContext, _: &Req) -> Result<(), ServerError> { +impl BeforeRequest for BeforeRequestNil { + async fn before(&mut self, _: &mut ServerCtx, _: &Req) -> Result<(), ServerError> { Ok(()) } } -impl, Rest: BeforeRequestList> BeforeRequestList +impl, Rest: BeforeRequestList, ServerCtx> BeforeRequestList for BeforeRequestCons { type Then = BeforeRequestCons> where - Next: BeforeRequest; + Next: BeforeRequest; - fn then>(self, next: Next) -> Self::Then { + fn then>(self, next: Next) -> Self::Then { let BeforeRequestCons(first, rest) = self; BeforeRequestCons(first, rest.then(next)) } - type Serve> = HookThenServe; + type Serve> = HookThenServe; - fn serving>(self, serve: S) -> Self::Serve { + fn serving>(self, serve: S) -> Self::Serve { HookThenServe::new(serve, self) } } -impl BeforeRequestList for BeforeRequestNil { +impl BeforeRequestList for BeforeRequestNil { type Then = BeforeRequestCons where - Next: BeforeRequest; + Next: BeforeRequest; - fn then>(self, next: Next) -> Self::Then { + fn then>(self, next: Next) -> Self::Then { BeforeRequestCons(next, BeforeRequestNil) } - type Serve> = S; + type Serve> = S; - fn serving>(self, serve: S) -> S { + fn serving>(self, serve: S) -> S { serve } } @@ -223,7 +231,7 @@ fn before_request_list() { Ok(()) }) .serving(serve(|_ctx, i| async move { Ok(i + 1) }.boxed())); - let mut context = context::ServerContext::current(); + let mut context = crate::context::SharedContext::current(); let response = serve.clone().serve(&mut context, 1); assert!(block_on(response).is_ok()); assert!(i.get() == 2); diff --git a/tarpc/src/server/request_hook/before_and_after.rs b/tarpc/src/server/request_hook/before_and_after.rs index dff0abe0b..080c53b21 100644 --- a/tarpc/src/server/request_hook/before_and_after.rs +++ b/tarpc/src/server/request_hook/before_and_after.rs @@ -7,17 +7,17 @@ //! Provides a hook that runs both before and after request execution. use super::{after::AfterRequest, before::BeforeRequest}; -use crate::{RequestName, ServerError, context, server::Serve}; +use crate::{RequestName, ServerError, server::Serve}; use std::marker::PhantomData; /// A Service function that runs a hook both before and after request execution. -pub struct HookThenServeThenHook { +pub struct HookThenServeThenHook { serve: Serv, hook: Hook, - fns: PhantomData<(fn(Req), fn(Resp))>, + fns: PhantomData<(Req, Resp, ServerCtx)>, } -impl HookThenServeThenHook { +impl HookThenServeThenHook { pub(crate) fn new(serve: Serv, hook: Hook) -> Self { Self { serve, @@ -27,7 +27,7 @@ impl HookThenServeThenHook { } } -impl Clone for HookThenServeThenHook { +impl Clone for HookThenServeThenHook { fn clone(&self) -> Self { Self { serve: self.serve.clone(), @@ -37,18 +37,19 @@ impl Clone for HookThenServeThenHook Serve for HookThenServeThenHook +impl Serve for HookThenServeThenHook where Req: RequestName, - Serv: Serve, - Hook: BeforeRequest + AfterRequest, + Serv: Serve, + Hook: BeforeRequest + AfterRequest, { type Req = Req; type Resp = Resp; + type ServerCtx = ServerCtx; async fn serve( self, - ctx: &mut context::ServerContext, + ctx: &mut ServerCtx, req: Req, ) -> Result { let HookThenServeThenHook { diff --git a/tarpc/src/server/testing.rs b/tarpc/src/server/testing.rs index 63b65d697..9a941f711 100644 --- a/tarpc/src/server/testing.rs +++ b/tarpc/src/server/testing.rs @@ -4,7 +4,7 @@ // license that can be found in the LICENSE file or at // https://opensource.org/licenses/MIT. -use crate::context::ServerContext; +use crate::context::{SharedContext}; use crate::{ Request, Response, cancellations::{CanceledRequests, RequestCancellation, cancellations}, @@ -39,8 +39,8 @@ where } } -impl Sink> - for FakeChannel> +impl Sink> + for FakeChannel> { type Error = io::Error; @@ -50,7 +50,7 @@ impl Sink> fn start_send( mut self: Pin<&mut Self>, - response: Response, + response: Response, ) -> Result<(), Self::Error> { self.as_mut() .project() @@ -72,13 +72,14 @@ impl Sink> } impl Channel - for FakeChannel>, Response> + for FakeChannel>, Response> where Req: Unpin, { type Req = Req; type Resp = Resp; type Transport = (); + type ServerCtx = SharedContext; fn config(&self) -> &Config { &self.config @@ -93,16 +94,16 @@ where } } -impl FakeChannel>, Response> { +impl FakeChannel>, Response> { pub fn push_req(&mut self, id: u64, message: Req) { let (_, abort_registration) = futures::future::AbortHandle::new_pair(); let (request_cancellation, _) = cancellations(); self.stream.push_back(Ok(TrackedRequest { request: Request { - context: context::ServerContext::new(context::SharedContext { + context: context::SharedContext { deadline: Instant::now(), trace_context: Default::default(), - }), + }, id, message, }, @@ -119,8 +120,13 @@ impl FakeChannel>, Response { pub fn default() - -> FakeChannel>, Response> { + -> FakeChannel>, Response> { let (request_cancellation, canceled_requests) = cancellations(); + + let mut x = anymap3::AnyMap::new(); + + x.entry::<&str>(); + FakeChannel { stream: Default::default(), sink: Default::default(), diff --git a/tarpc/tests/dataservice.rs b/tarpc/tests/dataservice.rs index 54fadf77d..05f1790d0 100644 --- a/tarpc/tests/dataservice.rs +++ b/tarpc/tests/dataservice.rs @@ -24,6 +24,7 @@ pub trait ColorProtocol { struct ColorServer; impl ColorProtocol for ColorServer { + type Context = ServerContext; async fn get_opposite_color(self, _: &mut context::ServerContext, color: TestData) -> TestData { match color { TestData::White => TestData::Black, diff --git a/tarpc/tests/service_functional.rs b/tarpc/tests/service_functional.rs index b65a66104..ee44c58d8 100644 --- a/tarpc/tests/service_functional.rs +++ b/tarpc/tests/service_functional.rs @@ -14,6 +14,7 @@ use tarpc::{ transport::channel, }; use tokio::join; +use tarpc::context::{ServerContext}; #[tarpc_plugins::service] trait Service { @@ -25,11 +26,12 @@ trait Service { struct Server; impl Service for Server { - async fn add(self, _: &mut context::ServerContext, x: i32, y: i32) -> i32 { + type Context = ServerContext; + async fn add(self, _: &mut Self::Context, x: i32, y: i32) -> i32 { x + y } - async fn hey(self, _: &mut context::ServerContext, name: String) -> String { + async fn hey(self, _: &mut Self::Context, name: String) -> String { format!("Hey, {name}.") } } @@ -67,7 +69,8 @@ async fn dropped_channel_aborts_in_flight_requests() -> anyhow::Result<()> { struct LoopServer; impl Loop for LoopServer { - async fn r#loop(self, _: &mut context::ServerContext) { + type Context = ServerContext; + async fn r#loop(self, _: &mut Self::Context) { loop { futures::pending!(); } @@ -284,7 +287,8 @@ async fn counter() -> anyhow::Result<()> { struct CountService(u32); impl Counter for &mut CountService { - async fn count(self, _: &mut context::ServerContext) -> u32 { + type Context = ServerContext; + async fn count(self, _: &mut Self::Context) -> u32 { self.0 += 1; self.0 } From 0045581c4f1ac30be739550e3ee8e4d67f54d81a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C3=81kos=20Vandra-Meyer?= Date: Wed, 26 Nov 2025 13:25:41 +0100 Subject: [PATCH 13/23] remove servercontext entirely --- example-service/src/client.rs | 2 +- example-service/src/server.rs | 12 ++--- plugins/src/lib.rs | 8 +-- plugins/tests/service.rs | 22 ++++---- tarpc/examples/compression.rs | 9 ++-- tarpc/examples/custom_transport.rs | 9 ++-- tarpc/examples/pubsub.rs | 16 +++--- tarpc/examples/readme.rs | 6 +-- tarpc/examples/tls_over_tcp.rs | 9 ++-- tarpc/examples/tracing.rs | 18 +++---- tarpc/src/client/stub.rs | 9 ++-- tarpc/src/context.rs | 47 ++++------------- tarpc/src/lib.rs | 6 +-- tarpc/src/server.rs | 18 +++---- tarpc/src/server/incoming.rs | 2 +- .../src/server/limits/requests_per_channel.rs | 8 +-- tarpc/src/server/request_hook.rs | 10 ++-- tarpc/src/transport/channel.rs | 50 +++---------------- tarpc/tests/dataservice.rs | 9 ++-- tarpc/tests/service_functional.rs | 12 ++--- 20 files changed, 102 insertions(+), 180 deletions(-) diff --git a/example-service/src/client.rs b/example-service/src/client.rs index 71c9704ea..40402867f 100644 --- a/example-service/src/client.rs +++ b/example-service/src/client.rs @@ -9,7 +9,7 @@ use futures::{SinkExt, future}; use service::{WorldClient, init_tracing}; use std::{net::SocketAddr, time::Duration}; use tarpc::context::ClientContext; -use tarpc::transport::channel::{map_transport_to_client, map_transport_to_server}; +use tarpc::transport::channel::{map_transport_to_client}; use tarpc::{client, tokio_serde::formats::Json}; use tokio::time::sleep; use tracing::Instrument; diff --git a/example-service/src/server.rs b/example-service/src/server.rs index 5e176dfa2..019a2d7b1 100644 --- a/example-service/src/server.rs +++ b/example-service/src/server.rs @@ -16,8 +16,7 @@ use std::{ net::{IpAddr, Ipv6Addr, SocketAddr}, time::Duration, }; -use tarpc::context::{ServerContext, SharedContext}; -use tarpc::transport::channel::{map_transport_to_server}; +use tarpc::context::{SharedContext}; use tarpc::{ ClientMessage, context, server::{self, Channel, incoming::Incoming}, @@ -38,8 +37,8 @@ struct Flags { struct HelloServer(SocketAddr); impl World for HelloServer { - type Context = ServerContext; - async fn hello(self, _: &mut context::ServerContext, name: String) -> String { + type Context = SharedContext; + async fn hello(self, _: &mut Self::Context, name: String) -> String { let sleep_time = Duration::from_millis(Uniform::new_inclusive(1, 10).sample(&mut thread_rng())); time::sleep(sleep_time).await; @@ -66,14 +65,13 @@ async fn main() -> anyhow::Result<()> { listener // Ignore accept errors. .filter_map(|r| future::ready(r.ok())) - .map(map_transport_to_server) .map(server::BaseChannel::with_defaults) // Limit channels to 1 per IP. - .max_channels_per_key(1, |t| t.transport().get_ref().get_ref().get_ref().peer_addr().unwrap().ip()) + .max_channels_per_key(1, |t| t.transport().get_ref().peer_addr().unwrap().ip()) // serve is generated by the service attribute. It takes as input any type implementing // the generated World trait. .map(|channel| { - let server = HelloServer(channel.transport().get_ref().get_ref().get_ref().peer_addr().unwrap()); + let server = HelloServer(channel.transport().get_ref().peer_addr().unwrap()); channel.execute(server.serve()).for_each(spawn) }) // Max 10 channels. diff --git a/plugins/src/lib.rs b/plugins/src/lib.rs index 432b2f1c8..250ffff04 100644 --- a/plugins/src/lib.rs +++ b/plugins/src/lib.rs @@ -375,8 +375,10 @@ fn collect_cfg_attrs(rpcs: &[RpcMethod]) -> Vec> { /// # Example /// /// ```no_run -/// use tarpc::{client, transport, service, server::{self, Channel}, context::ServerContext}; -/// use futures_util::{TryStreamExt, sink::SinkExt}; +/// use tarpc::{client, transport, service, server::{self, Channel}}; +/// use futures_util::{TryStreamExt, sink::SinkExt};/// +/// +/// use tarpc::context::SharedContext; /// /// #[service] /// pub trait Calculator { @@ -402,7 +404,7 @@ fn collect_cfg_attrs(rpcs: &[RpcMethod]) -> Vec> { /// #[derive(Clone)] /// struct CalculatorServer; /// impl Calculator for CalculatorServer { -/// type Context = ServerContext; +/// type Context = SharedContext; /// async fn add(self, context: &mut Self::Context, a: i32, b: i32) -> i32 { /// a + b /// } diff --git a/plugins/tests/service.rs b/plugins/tests/service.rs index ef49b9666..d8213f4d4 100644 --- a/plugins/tests/service.rs +++ b/plugins/tests/service.rs @@ -1,7 +1,7 @@ use serde::{Deserialize, Serialize}; use std::hash::Hash; use tarpc::context; -use tarpc::context::ServerContext; +use tarpc::context::SharedContext; #[test] fn att_service_trait() { @@ -13,21 +13,21 @@ fn att_service_trait() { } impl Foo for () { - type Context = ServerContext; + type Context = SharedContext; async fn two_part( self, - _: &mut context::ServerContext, + _: &mut context::SharedContext, s: String, i: i32, ) -> (String, i32) { (s, i) } - async fn bar(self, _: &mut context::ServerContext, s: String) -> String { + async fn bar(self, _: &mut Self::Context, s: String) -> String { s } - async fn baz(self, _: &mut context::ServerContext) {} + async fn baz(self, _: &mut Self::Context) {} } } @@ -44,21 +44,21 @@ fn raw_idents() { } impl r#trait for () { - type Context = ServerContext; + type Context = SharedContext; async fn r#await( self, - _: &mut context::ServerContext, + _: &mut Self::Context, r#struct: r#yield, r#enum: i32, ) -> (r#yield, i32) { (r#struct, r#enum) } - async fn r#fn(self, _: &mut context::ServerContext, r#impl: r#yield) -> r#yield { + async fn r#fn(self, _: &mut Self::Context, r#impl: r#yield) -> r#yield { r#impl } - async fn r#async(self, _: &mut context::ServerContext) {} + async fn r#async(self, _: &mut Self::Context) {} } } @@ -72,8 +72,8 @@ fn service_with_cfg_rpc() { } impl Foo for () { - type Context = ServerContext; - async fn foo(self, _: &mut context::ServerContext) {} + type Context = SharedContext; + async fn foo(self, _: &mut Self::Context) {} } } diff --git a/tarpc/examples/compression.rs b/tarpc/examples/compression.rs index c00ffc9f3..6a1440bd2 100644 --- a/tarpc/examples/compression.rs +++ b/tarpc/examples/compression.rs @@ -9,14 +9,14 @@ use futures::{Sink, SinkExt, Stream, StreamExt, TryStreamExt, prelude::*}; use serde::{Deserialize, Serialize}; use serde_bytes::ByteBuf; use std::{io, io::Read, io::Write}; -use tarpc::transport::channel::{map_transport_to_client, map_transport_to_server}; +use tarpc::transport::channel::{map_transport_to_client}; use tarpc::{ client, context, serde_transport::tcp, server::{BaseChannel, Channel}, tokio_serde::formats::Bincode, }; -use tarpc::context::ServerContext; +use tarpc::context::SharedContext; /// Type of compression that should be enabled on the request. The transport is free to ignore this. #[derive(Debug, PartialEq, Eq, Clone, Copy, Deserialize, Serialize)] @@ -110,8 +110,8 @@ pub trait World { struct HelloServer; impl World for HelloServer { - type Context = ServerContext; - async fn hello(self, _: &mut context::ServerContext, name: String) -> String { + type Context = SharedContext; + async fn hello(self, _: &mut Self::Context, name: String) -> String { format!("Hey, {name}!") } } @@ -128,7 +128,6 @@ async fn main() -> anyhow::Result<()> { tokio::spawn(async move { let transport = incoming.next().await.unwrap().unwrap(); let transport = add_compression(transport); - let transport = map_transport_to_server(transport); BaseChannel::with_defaults(transport) .execute(HelloServer.serve()) .for_each(spawn) diff --git a/tarpc/examples/custom_transport.rs b/tarpc/examples/custom_transport.rs index 415cb5442..92c723b4d 100644 --- a/tarpc/examples/custom_transport.rs +++ b/tarpc/examples/custom_transport.rs @@ -6,12 +6,12 @@ use console_subscriber::Server; use futures::prelude::*; -use tarpc::context::{ClientContext, ServerContext, SharedContext}; +use tarpc::context::{ClientContext, SharedContext}; use tarpc::serde_transport as transport; use tarpc::server::{BaseChannel, Channel}; use tarpc::tokio_serde::formats::Bincode; use tarpc::tokio_util::codec::length_delimited::LengthDelimitedCodec; -use tarpc::transport::channel::{map_transport_to_client, map_transport_to_server}; +use tarpc::transport::channel::{map_transport_to_client}; use tokio::net::{UnixListener, UnixStream}; #[tarpc::service] @@ -23,8 +23,8 @@ pub trait PingService { struct Service; impl PingService for Service { - type Context = ServerContext; - async fn ping(self, _: &mut ServerContext) {} + type Context = SharedContext; + async fn ping(self, _: &mut Self::Context) {} } #[tokio::main] async fn main() -> anyhow::Result<()> { @@ -42,7 +42,6 @@ async fn main() -> anyhow::Result<()> { let (conn, _addr) = listener.accept().await.unwrap(); let framed = codec_builder.new_framed(conn); let transport = transport::new(framed, Bincode::default()); - let transport = map_transport_to_server(transport); let fut = BaseChannel::with_defaults(transport) .execute(Service.serve()) diff --git a/tarpc/examples/pubsub.rs b/tarpc/examples/pubsub.rs index 6755e49ca..fbe19078a 100644 --- a/tarpc/examples/pubsub.rs +++ b/tarpc/examples/pubsub.rs @@ -48,8 +48,8 @@ use std::{ sync::{Arc, Mutex, RwLock}, }; use subscriber::Subscriber as _; -use tarpc::context::{ClientContext, ServerContext, SharedContext}; -use tarpc::transport::channel::{map_transport_to_client, map_transport_to_server}; +use tarpc::context::{ClientContext, SharedContext}; +use tarpc::transport::channel::{map_transport_to_client}; use tarpc::{ ClientMessage, client, context, serde_transport::tcp, @@ -82,12 +82,12 @@ struct Subscriber { } impl subscriber::Subscriber for Subscriber { - type Context = ServerContext; - async fn topics(self, _: &mut context::ServerContext) -> Vec { + type Context = SharedContext; + async fn topics(self, _: &mut Self::Context) -> Vec { self.topics.clone() } - async fn receive(self, _: &mut context::ServerContext, topic: String, message: String) { + async fn receive(self, _: &mut Self::Context, topic: String, message: String) { info!(local_addr = %self.local_addr, %topic, %message, "ReceivedMessage") } } @@ -107,7 +107,6 @@ impl Subscriber { ) -> anyhow::Result { let publisher = tcp::connect(publisher_addr, Json::default).await?; let local_addr = publisher.local_addr()?; - let publisher = map_transport_to_server(publisher); let mut handler = server::BaseChannel::with_defaults(publisher).requests(); let subscriber = Subscriber { local_addr, topics }; // The first request is for the topics being subscribed to. @@ -168,7 +167,6 @@ impl Publisher { let publisher = connecting_publishers.next().await.unwrap().unwrap(); info!(publisher.peer_addr = ?publisher.peer_addr(), "publisher connected."); - let publisher = map_transport_to_server(publisher); server::BaseChannel::with_defaults(publisher) .execute(self.serve()) @@ -272,8 +270,8 @@ impl Publisher { } impl publisher::Publisher for Publisher { - type Context = ServerContext; - async fn publish(self, _: &mut context::ServerContext, topic: String, message: String) { + type Context = SharedContext; + async fn publish(self, _: &mut Self::Context, topic: String, message: String) { info!("received message to publish."); let mut subscribers = match self.subscriptions.read().unwrap().get(&topic) { None => return, diff --git a/tarpc/examples/readme.rs b/tarpc/examples/readme.rs index c7e8de00b..db93d2e74 100644 --- a/tarpc/examples/readme.rs +++ b/tarpc/examples/readme.rs @@ -5,7 +5,7 @@ // https://opensource.org/licenses/MIT. use futures::prelude::*; -use tarpc::context::{ClientContext, ServerContext, SharedContext}; +use tarpc::context::{ClientContext, SharedContext}; use tarpc::{ ClientMessage, client, context, server::{self, Channel}, @@ -25,8 +25,8 @@ pub trait World { struct HelloServer; impl World for HelloServer { - type Context = ServerContext; - async fn hello(self, _: &mut context::ServerContext, name: String) -> String { + type Context = SharedContext; + async fn hello(self, _: &mut Self::Context, name: String) -> String { format!("Hello, {name}!") } } diff --git a/tarpc/examples/tls_over_tcp.rs b/tarpc/examples/tls_over_tcp.rs index 4ed3298bb..0e00cdca8 100644 --- a/tarpc/examples/tls_over_tcp.rs +++ b/tarpc/examples/tls_over_tcp.rs @@ -10,12 +10,12 @@ use std::io::{self, BufReader, Cursor}; use std::net::{IpAddr, Ipv4Addr}; use std::sync::Arc; -use tarpc::context::{ClientContext, ServerContext, SharedContext}; +use tarpc::context::{ClientContext, SharedContext}; use tarpc::serde_transport as transport; use tarpc::server::{BaseChannel, Channel}; use tarpc::tokio_serde::formats::Bincode; use tarpc::tokio_util::codec::length_delimited::LengthDelimitedCodec; -use tarpc::transport::channel::{map_transport_to_client, map_transport_to_server}; +use tarpc::transport::channel::{map_transport_to_client}; use tokio::net::TcpListener; use tokio::net::TcpStream; use tokio_rustls::rustls::{ @@ -33,8 +33,8 @@ pub trait PingService { struct Service; impl PingService for Service { - type Context = ServerContext; - async fn ping(self, _: &mut ServerContext) -> String { + type Context = SharedContext; + async fn ping(self, _: &mut Self::Context) -> String { "🔒".to_owned() } } @@ -116,7 +116,6 @@ async fn main() -> anyhow::Result<()> { let framed = codec_builder.new_framed(tls_stream); let transport = transport::new(framed, Bincode::default()); - let transport = map_transport_to_server(transport); let fut = BaseChannel::with_defaults(transport) .execute(Service.serve()) diff --git a/tarpc/examples/tracing.rs b/tarpc/examples/tracing.rs index f747c9d75..b69e0c1a0 100644 --- a/tarpc/examples/tracing.rs +++ b/tarpc/examples/tracing.rs @@ -19,8 +19,8 @@ use std::{ atomic::{AtomicBool, Ordering}, }, }; -use tarpc::context::{ClientContext, ServerContext, SharedContext}; -use tarpc::transport::channel::{map_transport_to_client, map_transport_to_server}; +use tarpc::context::{ClientContext, SharedContext}; +use tarpc::transport::channel::{map_transport_to_client}; use tarpc::{ ClientMessage, RequestName, Response, ServerError, Transport, client::{ @@ -58,8 +58,8 @@ pub mod double { struct AddServer; impl AddService for AddServer { - type Context = ServerContext; - async fn add(self, _: &mut context::ServerContext, x: i32, y: i32) -> i32 { + type Context = SharedContext; + async fn add(self, _: &mut Self::Context, x: i32, y: i32) -> i32 { x + y } } @@ -73,8 +73,8 @@ impl DoubleService for DoubleServer where Stub: AddStub + Clone + Send + Sync + 'static, { - type Context = ServerContext; - async fn double(self, _: &mut context::ServerContext, x: i32) -> Result { + type Context = SharedContext; + async fn double(self, _: &mut Self::Context, x: i32) -> Result { self.add_client .add(&mut context::ClientContext::current(), x, x) .await @@ -180,7 +180,6 @@ async fn main() -> anyhow::Result<()> { .serving(AddServer.serve()); let add_server = add_listener1 .chain(add_listener2) - .map(map_transport_to_server) .map(BaseChannel::with_defaults); tokio::spawn(spawn_incoming(add_server.execute(server))); @@ -191,9 +190,8 @@ async fn main() -> anyhow::Result<()> { let double_listener = tarpc::serde_transport::tcp::listen("localhost:0", Json::default) .await? - .filter_map(|r| future::ready(r.ok())) - .map(map_transport_to_server); - let addr = double_listener.get_ref().get_ref().local_addr(); + .filter_map(|r| future::ready(r.ok())); + let addr = double_listener.get_ref().local_addr(); let double_server = double_listener.map(BaseChannel::with_defaults).take(1); let server = DoubleServer { add_client }.serve(); tokio::spawn(spawn_incoming(double_server.execute(server))); diff --git a/tarpc/src/client/stub.rs b/tarpc/src/client/stub.rs index 6fa159dd7..9989f0577 100644 --- a/tarpc/src/client/stub.rs +++ b/tarpc/src/client/stub.rs @@ -3,10 +3,9 @@ use crate::{ RequestName, client::{Channel, RpcError}, - context, server::Serve, }; -use crate::context::{ClientContext, ServerContext}; +use crate::context::{ClientContext, SharedContext}; pub mod load_balance; pub mod retry; @@ -54,7 +53,7 @@ where impl Stub for S where - S: Serve + Clone, + S: Serve + Clone, { type Req = S::Req; type Resp = S::Resp; @@ -64,7 +63,7 @@ where ctx: &mut ClientContext, req: Self::Req, ) -> Result { - let mut server_ctx = context::ServerContext::new(ctx.shared_context.clone()); + let mut server_ctx = ctx.shared_context.clone(); let res = self .clone() @@ -72,7 +71,7 @@ where .await .map_err(RpcError::Server); - ctx.shared_context = server_ctx.shared_context; + ctx.shared_context = server_ctx; res } diff --git a/tarpc/src/context.rs b/tarpc/src/context.rs index 798044c93..5cc9389f1 100644 --- a/tarpc/src/context.rs +++ b/tarpc/src/context.rs @@ -56,43 +56,6 @@ impl ExtractContext for T where T: Clone { } } -/// Request context that carries request-scoped server side information like deadlines and trace information -/// as well as any server side extensions defined by the transport, hooks or service implementations. -/// It is build from the shared context sent from client to server. -/// -/// The context should not be stored directly in a server implementation, because the context will -/// be different for each request in scope. -#[derive(Debug)] -pub struct ServerContext { - /// Shared context sent from client to server which contains information used by both sides. - pub shared_context: SharedContext, -} - -impl ServerContext { - /// Creates a new ServerContext from the given SharedContext with no extensions. - pub fn new(shared_context: SharedContext) -> Self { - Self { shared_context } - } - - /// Creates a new ServerContext for the current shared context with no extensions. - pub fn current() -> Self { - Self::new(SharedContext::current()) - } -} - -impl Deref for ServerContext { - type Target = SharedContext; - - fn deref(&self) -> &Self::Target { - &self.shared_context - } -} -impl DerefMut for ServerContext { - fn deref_mut(&mut self) -> &mut Self::Target { - &mut self.shared_context - } -} - /// Request context that carries request-scoped client side information like deadlines and trace information /// as well as any server side extensions defined by the transport, hooks and stubs. /// The shared part of the context is sent from client to server, while the client side extensions are only seen on the client side. @@ -103,12 +66,20 @@ impl DerefMut for ServerContext { pub struct ClientContext { /// Shared context sent from client to server which contains information used by both sides. pub shared_context: SharedContext, + + /// Client side extensions that are not seen by the server + /// XXX, YYY, and ZZZ can use this to store per-request data, and communicate with eachother. + /// Note that this is NOT sent to the server, and they will always see an empty map here. + pub client_context: anymap3::Map, } impl ClientContext { /// Creates a new ServerContext from the given SharedContext with no extensions. pub fn new(shared_context: SharedContext) -> Self { - Self { shared_context } + Self { + shared_context, + client_context: anymap3::Map::new(), + } } /// Creates a new ServerContext for the current shared context with no extensions. diff --git a/tarpc/src/lib.rs b/tarpc/src/lib.rs index 0578a392f..e0869d9f6 100644 --- a/tarpc/src/lib.rs +++ b/tarpc/src/lib.rs @@ -124,7 +124,7 @@ //! struct HelloServer; //! //! impl World for HelloServer { -//! type Context = context::ServerContext; +//! type Context = context::SharedContext; //! // Each defined rpc generates an async fn that serves the RPC //! async fn hello(self, _: &mut Self::Context, name: String) -> String { //! format!("Hello, {name}!") @@ -145,7 +145,7 @@ //! # use tarpc::{ //! # ClientMessage, //! # client, context, -//! # context::{ClientContext, ServerContext, SharedContext}, +//! # context::{ClientContext, SharedContext}, //! # transport::channel, //! # server::{self, Channel}, //! # }; @@ -161,7 +161,7 @@ //! # #[derive(Clone)] //! # struct HelloServer; //! # impl World for HelloServer { -//! # type Context = ServerContext; +//! # type Context = SharedContext; //! # // Each defined rpc generates an async fn that serves the RPC //! # async fn hello(self, _: &mut Self::Context, name: String) -> String { //! # format!("Hello, {name}!") diff --git a/tarpc/src/server.rs b/tarpc/src/server.rs index fe7440f7e..e6c395836 100644 --- a/tarpc/src/server.rs +++ b/tarpc/src/server.rs @@ -363,7 +363,7 @@ where /// use tarpc::{ /// ClientMessage, /// context, - /// context::{ClientContext, SharedContext, ServerContext}, + /// context::{ClientContext, SharedContext}, /// client::{self, NewClient}, /// server::{self, BaseChannel, Channel, serve}, /// transport, @@ -407,7 +407,7 @@ where /// # Example /// /// ```rust - /// use tarpc::{ClientMessage, context, client, server::{self, BaseChannel, Channel, serve}, transport, context::{ClientContext, SharedContext, ServerContext}}; + /// use tarpc::{ClientMessage, context, client, server::{self, BaseChannel, Channel, serve}, transport, context::{ClientContext, SharedContext}}; /// use futures::prelude::*; /// use tracing_subscriber::prelude::*; /// @@ -767,7 +767,7 @@ where /// /// ```rust /// use tarpc::{context, client, server::{self, BaseChannel, Channel, serve}, transport, ClientMessage}; - /// use tarpc::context::{ClientContext, SharedContext, ServerContext}; + /// use tarpc::context::{ClientContext, SharedContext}; /// use futures::prelude::*; /// /// # #[cfg(not(feature = "tokio1"))] @@ -872,7 +872,7 @@ impl InFlightRequest { /// use tarpc::{ /// ClientMessage, /// context, - /// context::{ClientContext, SharedContext, ServerContext}, + /// context::{ClientContext, SharedContext}, /// client::{self, NewClient}, /// server::{self, BaseChannel, Channel, serve}, /// transport, @@ -1106,7 +1106,7 @@ mod tests { async fn test_serve() { let serve = serve(|_, i| async move { Ok(i) }.boxed()); assert_matches!( - serve.serve(&mut context::ServerContext::current(), 7).await, + serve.serve(&mut context::SharedContext::current(), 7).await, Ok(7) ); } @@ -1178,10 +1178,10 @@ mod tests { } } - let serve = serve(move |_: &mut context::ServerContext, i| async move { Ok(i) }.boxed()); + let serve = serve(move |_: &mut context::SharedContext, i| async move { Ok(i) }.boxed()); serve .before_and_after(PrintLatency::new()) - .serve(&mut context::ServerContext::current(), 7) + .serve(&mut context::SharedContext::current(), 7) .await?; Ok(()) } @@ -1189,11 +1189,11 @@ mod tests { #[tokio::test] async fn serve_before_error_aborts_request() -> anyhow::Result<()> { let serve = serve(|_, _| async { panic!("Shouldn't get here") }.boxed()); - let deadline_hook = serve.before(|_: &mut context::ServerContext, _: &i32| async { + let deadline_hook = serve.before(|_: &mut context::SharedContext, _: &i32| async { Err(ServerError::new(io::ErrorKind::Other, "oops".into())) }); let resp: Result = deadline_hook - .serve(&mut context::ServerContext::current(), 7) + .serve(&mut context::SharedContext::current(), 7) .await; assert_matches!(resp, Err(_)); Ok(()) diff --git a/tarpc/src/server/incoming.rs b/tarpc/src/server/incoming.rs index 56f393b84..2baa27c89 100644 --- a/tarpc/src/server/incoming.rs +++ b/tarpc/src/server/incoming.rs @@ -50,7 +50,7 @@ where /// use tarpc::{ /// ClientMessage, /// context, -/// context::{ClientContext, ServerContext, SharedContext}, +/// context::{ClientContext, SharedContext}, /// client::{self, NewClient}, /// server::{self, BaseChannel, Channel, incoming::{Incoming, spawn_incoming}, serve}, /// transport, diff --git a/tarpc/src/server/limits/requests_per_channel.rs b/tarpc/src/server/limits/requests_per_channel.rs index 383abb9c8..deb723bda 100644 --- a/tarpc/src/server/limits/requests_per_channel.rs +++ b/tarpc/src/server/limits/requests_per_channel.rs @@ -190,7 +190,7 @@ mod tests { time::{Duration, Instant}, }; use tracing::Span; - use crate::context::{ServerContext, SharedContext}; + use crate::context::{SharedContext}; #[tokio::test] async fn throttler_in_flight_requests() { @@ -271,7 +271,7 @@ mod tests { } impl PendingSink<(), ()> { pub fn default() - -> PendingSink>, Response> + -> PendingSink>, Response> { PendingSink { ghost: PhantomData } } @@ -298,12 +298,12 @@ mod tests { } } impl Channel - for PendingSink>, Response> + for PendingSink>, Response> { type Req = Req; type Resp = Resp; type Transport = (); - type ServerCtx = ServerContext; + type ServerCtx = SharedContext; fn config(&self) -> &Config { unimplemented!() } diff --git a/tarpc/src/server/request_hook.rs b/tarpc/src/server/request_hook.rs index 338059f7d..4f3d60377 100644 --- a/tarpc/src/server/request_hook.rs +++ b/tarpc/src/server/request_hook.rs @@ -48,7 +48,7 @@ pub trait RequestHook: Serve { /// use std::io; /// /// let serve = serve(|_ctx, i| async move { Ok(i + 1) }.boxed()) - /// .before(|_ctx: &mut context::ServerContext, req: &i32| { + /// .before(|_ctx: &mut context::SharedContext, req: &i32| { /// future::ready( /// if *req == 1 { /// Err(ServerError::new( @@ -58,7 +58,7 @@ pub trait RequestHook: Serve { /// Ok(()) /// }) /// }); - /// let mut context = context::ServerContext::current(); + /// let mut context = context::SharedContext::current(); /// let response = serve.serve(&mut context, 1); /// assert!(block_on(response).is_err()); /// ``` @@ -95,13 +95,13 @@ pub trait RequestHook: Serve { /// Ok(i + 1) /// } /// }.boxed()) - /// .after(|_ctx: &mut context::ServerContext, resp: &mut Result| { + /// .after(|_ctx: &mut context::SharedContext, resp: &mut Result| { /// if let Err(e) = resp { /// eprintln!("server error: {e:?}"); /// } /// future::ready(()) /// }); - /// let mut context = context::ServerContext::current(); + /// let mut context = context::SharedContext::current(); /// let response = serve.serve(&mut context, 1); /// assert!(block_on(response).is_err()); /// ``` @@ -153,7 +153,7 @@ pub trait RequestHook: Serve { /// let serve = serve(|_ctx, i| async move { /// Ok(i + 1) /// }.boxed()).before_and_after(PrintLatency(Instant::now())); - /// let mut context = context::ServerContext::current(); + /// let mut context = context::SharedContext::current(); /// let response = serve.serve(&mut context, 1); /// assert!(block_on(response).is_ok()); /// ``` diff --git a/tarpc/src/transport/channel.rs b/tarpc/src/transport/channel.rs index 7615d8fe1..476f60738 100644 --- a/tarpc/src/transport/channel.rs +++ b/tarpc/src/transport/channel.rs @@ -6,13 +6,14 @@ //! Transports backed by in-memory channels. -use crate::context::{ClientContext, ServerContext, SharedContext}; +use crate::context::{ClientContext, SharedContext}; use crate::{ClientMessage, Response, Transport}; use futures::future::{Ready}; use futures::sink::With; use futures::{Sink, SinkExt, Stream, TryStreamExt, task::*}; use pin_project::pin_project; use std::{error::Error, future, pin::Pin}; +use std::convert::identity; use tokio::sync::mpsc; /// Errors that occur in the sending or receiving of messages over a channel. @@ -87,13 +88,13 @@ where /// [`Sink`]. pub fn unbounded_for_client_server_context() -> ( impl Transport, Response>, - impl Transport, ClientMessage>, + impl Transport, ClientMessage>, ) { unbounded_mapped( map_req_client_context_to_shared, - map_req_shared_context_to_server, + identity, map_resp_shared_context_to_client, - map_resp_server_context_to_shared, + identity, ) } @@ -104,21 +105,7 @@ fn map_req_client_context_to_shared( msg.map_context(|ctx| ctx.shared_context) } -/// Convenience function to map a ClientMessage with SharedContext to one with ServerContext. -fn map_req_shared_context_to_server( - msg: ClientMessage, -) -> ClientMessage { - msg.map_context(ServerContext::new) -} - -/// Convenience function to map a ClientMessage with ClientContext to one with SharedContext. -fn map_resp_server_context_to_shared( - resp: Response, -) -> Response { - resp.map_context(|ctx| ctx.shared_context) -} - -/// Convenience function to map a ClientMessage with SharedContext to one with ServerContext. +/// Convenience function to map a ClientMessage with SharedContext to one with ClientContext. fn map_resp_shared_context_to_client( msg: Response, ) -> Response { @@ -148,31 +135,6 @@ where t.with(f).map_ok(map_resp_shared_context_to_client) } -/// TODO: document -/// -/// Yuck, but impl trait will loose our ability to do t.as_ref() -pub fn map_transport_to_server( - t: T, -) -> futures::stream::MapOk< - With< - T, - Response, - Response, - Ready, E>>, - fn(Response) -> Ready, E>>, - >, - fn(ClientMessage) -> ClientMessage, -> -where - T: Transport, ClientMessage>, - E: From -{ - let f: fn(Response) -> Ready, E>> = |resp| futures::future::ok(map_resp_server_context_to_shared(resp)); - - t.with(f) - .map_ok(map_req_shared_context_to_server) -} - /// A bi-directional channel backed by an [`UnboundedSender`](mpsc::UnboundedSender) /// and [`UnboundedReceiver`](mpsc::UnboundedReceiver). #[derive(Debug)] diff --git a/tarpc/tests/dataservice.rs b/tarpc/tests/dataservice.rs index 05f1790d0..1a1b5207a 100644 --- a/tarpc/tests/dataservice.rs +++ b/tarpc/tests/dataservice.rs @@ -1,6 +1,6 @@ use futures::prelude::*; -use tarpc::context::{ClientContext, ServerContext, SharedContext}; -use tarpc::transport::channel::{map_transport_to_client, map_transport_to_server}; +use tarpc::context::{ClientContext, SharedContext}; +use tarpc::transport::channel::{map_transport_to_client}; use tarpc::{ClientMessage, serde_transport}; use tarpc::{ client, context, @@ -24,8 +24,8 @@ pub trait ColorProtocol { struct ColorServer; impl ColorProtocol for ColorServer { - type Context = ServerContext; - async fn get_opposite_color(self, _: &mut context::ServerContext, color: TestData) -> TestData { + type Context = SharedContext; + async fn get_opposite_color(self, _: &mut Self::Context, color: TestData) -> TestData { match color { TestData::White => TestData::Black, TestData::Black => TestData::White, @@ -46,7 +46,6 @@ async fn test_call() -> anyhow::Result<()> { transport .take(1) .filter_map(|r| async { r.ok() }) - .map(map_transport_to_server) .map(BaseChannel::with_defaults) .execute(ColorServer.serve()) .map(|channel| channel.for_each(spawn)) diff --git a/tarpc/tests/service_functional.rs b/tarpc/tests/service_functional.rs index ee44c58d8..4692b17cd 100644 --- a/tarpc/tests/service_functional.rs +++ b/tarpc/tests/service_functional.rs @@ -4,7 +4,7 @@ use futures::{ prelude::*, }; use std::time::{Duration, Instant}; -use tarpc::transport::channel::{map_transport_to_client, map_transport_to_server}; +use tarpc::transport::channel::{map_transport_to_client}; use tarpc::{ ClientMessage, client::{self}, @@ -14,7 +14,7 @@ use tarpc::{ transport::channel, }; use tokio::join; -use tarpc::context::{ServerContext}; +use tarpc::context::SharedContext; #[tarpc_plugins::service] trait Service { @@ -26,7 +26,7 @@ trait Service { struct Server; impl Service for Server { - type Context = ServerContext; + type Context = SharedContext; async fn add(self, _: &mut Self::Context, x: i32, y: i32) -> i32 { x + y } @@ -69,7 +69,7 @@ async fn dropped_channel_aborts_in_flight_requests() -> anyhow::Result<()> { struct LoopServer; impl Loop for LoopServer { - type Context = ServerContext; + type Context = SharedContext; async fn r#loop(self, _: &mut Self::Context) { loop { futures::pending!(); @@ -118,7 +118,6 @@ async fn serde_tcp() -> anyhow::Result<()> { transport .take(1) .filter_map(|r| async { r.ok() }) - .map(map_transport_to_server) .map(BaseChannel::with_defaults) .execute(Server.serve()) .map(|channel| channel.for_each(spawn)) @@ -157,7 +156,6 @@ async fn serde_uds() -> anyhow::Result<()> { transport .take(1) .filter_map(|r| async { r.ok() }) - .map(map_transport_to_server) .map(BaseChannel::with_defaults) .execute(Server.serve()) .map(|channel| channel.for_each(spawn)) @@ -287,7 +285,7 @@ async fn counter() -> anyhow::Result<()> { struct CountService(u32); impl Counter for &mut CountService { - type Context = ServerContext; + type Context = SharedContext; async fn count(self, _: &mut Self::Context) -> u32 { self.0 += 1; self.0 From 7989bc084b78e41c3a59cc3eef2135d8ee22c63d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C3=81kos=20Vandra-Meyer?= Date: Wed, 26 Nov 2025 16:09:34 +0100 Subject: [PATCH 14/23] make clientContext generic as well --- example-service/src/client.rs | 9 +- plugins/src/lib.rs | 43 +++--- tarpc/examples/compression.rs | 4 +- tarpc/examples/custom_transport.rs | 6 +- tarpc/examples/pubsub.rs | 46 +++--- tarpc/examples/readme.rs | 6 +- tarpc/examples/tls_over_tcp.rs | 6 +- tarpc/examples/tracing.rs | 32 +++-- tarpc/src/client.rs | 131 ++++++++++-------- tarpc/src/client/stub.rs | 21 +-- tarpc/src/client/stub/load_balance.rs | 14 +- tarpc/src/client/stub/mock.rs | 4 +- tarpc/src/client/stub/retry.rs | 4 +- tarpc/src/context.rs | 68 --------- tarpc/src/lib.rs | 9 +- tarpc/src/server.rs | 24 ++-- tarpc/src/server/incoming.rs | 6 +- tarpc/src/transport/channel.rs | 109 +-------------- .../compile_fail/must_use_request_dispatch.rs | 4 +- .../must_use_request_dispatch.stderr | 6 +- tarpc/tests/dataservice.rs | 9 +- tarpc/tests/service_functional.rs | 44 +++--- 22 files changed, 234 insertions(+), 371 deletions(-) diff --git a/example-service/src/client.rs b/example-service/src/client.rs index 40402867f..b8ff22c97 100644 --- a/example-service/src/client.rs +++ b/example-service/src/client.rs @@ -8,8 +8,7 @@ use clap::Parser; use futures::{SinkExt, future}; use service::{WorldClient, init_tracing}; use std::{net::SocketAddr, time::Duration}; -use tarpc::context::ClientContext; -use tarpc::transport::channel::{map_transport_to_client}; +use tarpc::context::{SharedContext}; use tarpc::{client, tokio_serde::formats::Json}; use tokio::time::sleep; use tracing::Instrument; @@ -32,15 +31,15 @@ async fn main() -> anyhow::Result<()> { let mut transport = tarpc::serde_transport::tcp::connect(flags.server_addr, Json::default); transport.config_mut().max_frame_length(usize::MAX); - let transport = map_transport_to_client(transport.await?); + let transport = transport.await?; // WorldClient is generated by the service attribute. It has a constructor `new` that takes a // config and any Transport as input. let client = WorldClient::new(client::Config::default(), transport).spawn(); let hello = async move { - let mut context = ClientContext::current(); - let mut context2 = ClientContext::current(); + let mut context = SharedContext::current(); + let mut context2 = SharedContext::current(); // Send the request twice, just to be safe! ;) tokio::select! { diff --git a/plugins/src/lib.rs b/plugins/src/lib.rs index 250ffff04..cf107d0ad 100644 --- a/plugins/src/lib.rs +++ b/plugins/src/lib.rs @@ -395,7 +395,7 @@ fn collect_cfg_attrs(rpcs: &[RpcMethod]) -> Vec> { /// let resp = CalculatorResponse::Add(12); /// /// // This could be any transport. -/// let (client_side, server_side) = transport::channel::unbounded_for_client_server_context(); +/// let (client_side, server_side) = transport::channel::unbounded(); /// /// // A client can be made like so: /// let client = CalculatorClient::new(client::Config::default(), client_side); @@ -583,11 +583,11 @@ impl ServiceGenerator<'_> { } #[doc = #stub_doc] - #vis trait #client_stub_ident: ::tarpc::client::stub::Stub { + #vis trait #client_stub_ident: ::tarpc::client::stub::Stub { } - impl #client_stub_ident for S - where S: ::tarpc::client::stub::Stub + impl #client_stub_ident for S + where S: ::tarpc::client::stub::Stub { } } @@ -717,12 +717,19 @@ impl ServiceGenerator<'_> { quote! { #[allow(unused)] - #[derive(Clone, Debug)] + #[derive(Debug)] /// The client stub that makes RPC calls to the server. All request methods return /// [Futures](::core::future::Future). #vis struct #client_ident< - Stub = ::tarpc::client::Channel<#request_ident, #response_ident> - >(Stub); + ClientCtx, + Stub = ::tarpc::client::Channel<#request_ident, #response_ident, ClientCtx> + >(Stub, ::std::marker::PhantomData); + + impl ::std::clone::Clone for #client_ident { + fn clone(&self) -> Self { + Self(self.0.clone(), ::std::marker::PhantomData) + } + } } } @@ -736,32 +743,33 @@ impl ServiceGenerator<'_> { } = self; quote! { - impl #client_ident { + impl #client_ident { /// Returns a new client stub that sends requests over the given transport. #vis fn new(config: ::tarpc::client::Config, transport: T) -> ::tarpc::client::NewClient< Self, - ::tarpc::client::RequestDispatch<#request_ident, #response_ident, T> + ::tarpc::client::RequestDispatch<#request_ident, #response_ident, ClientCtx, T> > where - T: ::tarpc::Transport<::tarpc::ClientMessage<::tarpc::context::ClientContext, #request_ident>, ::tarpc::Response<::tarpc::context::ClientContext, #response_ident>> + T: ::tarpc::Transport<::tarpc::ClientMessage, ::tarpc::Response> { let new_client = ::tarpc::client::new(config, transport); ::tarpc::client::NewClient { - client: #client_ident(new_client.client), + client: #client_ident(new_client.client, ::std::marker::PhantomData), dispatch: new_client.dispatch, } } } - impl ::core::convert::From for #client_ident + impl ::core::convert::From for #client_ident where Stub: ::tarpc::client::stub::Stub< Req = #request_ident, - Resp = #response_ident> + Resp = #response_ident, + ClientCtx = ClientCtx> { /// Returns a new client stub that sends requests over the given transport. fn from(stub: Stub) -> Self { - #client_ident(stub) + #client_ident::(stub, ::std::marker::PhantomData) } } @@ -784,15 +792,16 @@ impl ServiceGenerator<'_> { } = self; quote! { - impl #client_ident + impl #client_ident where Stub: ::tarpc::client::stub::Stub< Req = #request_ident, - Resp = #response_ident> + Resp = #response_ident, + ClientCtx = ClientCtx> { #( #[allow(unused)] #( #method_attrs )* - #vis fn #method_idents<'a>(&'a self, ctx: &'a mut Stub::ServerCtx, #( #args ),*) + #vis fn #method_idents<'a>(&'a self, ctx: &'a mut Stub::ClientCtx, #( #args ),*) -> impl ::core::future::Future> + '_ { let request = #request_ident::#camel_case_idents { #( #arg_pats ),* }; let resp = self.0.call(ctx, request); diff --git a/tarpc/examples/compression.rs b/tarpc/examples/compression.rs index 6a1440bd2..f201521ad 100644 --- a/tarpc/examples/compression.rs +++ b/tarpc/examples/compression.rs @@ -9,7 +9,6 @@ use futures::{Sink, SinkExt, Stream, StreamExt, TryStreamExt, prelude::*}; use serde::{Deserialize, Serialize}; use serde_bytes::ByteBuf; use std::{io, io::Read, io::Write}; -use tarpc::transport::channel::{map_transport_to_client}; use tarpc::{ client, context, serde_transport::tcp, @@ -136,13 +135,12 @@ async fn main() -> anyhow::Result<()> { let transport = tcp::connect(addr, Bincode::default).await?; let transport = add_compression(transport); - let transport = map_transport_to_client(transport); let client = WorldClient::new(client::Config::default(), transport).spawn(); println!( "{}", client - .hello(&mut context::ClientContext::current(), "friend".into()) + .hello(&mut context::SharedContext::current(), "friend".into()) .await? ); Ok(()) diff --git a/tarpc/examples/custom_transport.rs b/tarpc/examples/custom_transport.rs index 92c723b4d..c9eb871ea 100644 --- a/tarpc/examples/custom_transport.rs +++ b/tarpc/examples/custom_transport.rs @@ -6,12 +6,11 @@ use console_subscriber::Server; use futures::prelude::*; -use tarpc::context::{ClientContext, SharedContext}; +use tarpc::context::{SharedContext}; use tarpc::serde_transport as transport; use tarpc::server::{BaseChannel, Channel}; use tarpc::tokio_serde::formats::Bincode; use tarpc::tokio_util::codec::length_delimited::LengthDelimitedCodec; -use tarpc::transport::channel::{map_transport_to_client}; use tokio::net::{UnixListener, UnixStream}; #[tarpc::service] @@ -52,10 +51,9 @@ async fn main() -> anyhow::Result<()> { let conn = UnixStream::connect(bind_addr).await?; let transport = transport::new(codec_builder.new_framed(conn), Bincode::default()); - let transport = map_transport_to_client(transport); PingServiceClient::new(Default::default(), transport) .spawn() - .ping(&mut ClientContext::current()) + .ping(&mut SharedContext::current()) .await?; Ok(()) diff --git a/tarpc/examples/pubsub.rs b/tarpc/examples/pubsub.rs index fbe19078a..07a93becf 100644 --- a/tarpc/examples/pubsub.rs +++ b/tarpc/examples/pubsub.rs @@ -47,9 +47,11 @@ use std::{ net::SocketAddr, sync::{Arc, Mutex, RwLock}, }; +use std::ops::Shl; +use serde::de::DeserializeOwned; +use serde::{Deserialize, Serialize}; use subscriber::Subscriber as _; -use tarpc::context::{ClientContext, SharedContext}; -use tarpc::transport::channel::{map_transport_to_client}; +use tarpc::context::{ExtractContext, SharedContext}; use tarpc::{ ClientMessage, client, context, serde_transport::tcp, @@ -135,10 +137,19 @@ struct Subscription { topics: Vec, } -#[derive(Clone, Debug)] -struct Publisher { +#[derive(Debug)] +struct Publisher { clients: Arc>>, - subscriptions: Arc>>>, + subscriptions: Arc>>>>, +} + +impl Clone for Publisher { + fn clone(&self) -> Self { + Publisher { + clients: self.clients.clone(), + subscriptions: self.subscriptions.clone(), + } + } } struct PublisherAddrs { @@ -150,7 +161,7 @@ async fn spawn(fut: impl Future + Send + 'static) { tokio::spawn(fut); } -impl Publisher { +impl Publisher where ClientCtx: ExtractContext + From + Serialize + DeserializeOwned + Send + Sync + 'static { // TODO: Remove serde bounds here async fn start(self) -> io::Result { let mut connecting_publishers = tcp::listen("localhost:0", Json::default).await?; @@ -187,7 +198,6 @@ impl Publisher { tokio::spawn(async move { while let Some(conn) = connecting_subscribers.next().await { let subscriber_addr = conn.peer_addr().unwrap(); - let conn = map_transport_to_client(conn); let tarpc::client::NewClient { client: subscriber, dispatch, @@ -211,11 +221,11 @@ impl Publisher { async fn initialize_subscription( &mut self, subscriber_addr: SocketAddr, - subscriber: subscriber::SubscriberClient, + subscriber: subscriber::SubscriberClient, ) { // Populate the topics if let Ok(topics) = subscriber - .topics(&mut context::ClientContext::current()) + .topics(&mut ClientCtx::from(context::SharedContext::current())) .await { self.clients.lock().unwrap().insert( @@ -269,8 +279,8 @@ impl Publisher { } } -impl publisher::Publisher for Publisher { - type Context = SharedContext; +impl publisher::Publisher for Publisher where ClientCtx: ExtractContext + From + Send + Sync + 'static { + type Context = ClientCtx; async fn publish(self, _: &mut Self::Context, topic: String, message: String) { info!("received message to publish."); let mut subscribers = match self.subscriptions.read().unwrap().get(&topic) { @@ -283,7 +293,7 @@ impl publisher::Publisher for Publisher { publications.push(async { client .receive( - &mut context::ClientContext::current(), + &mut ClientCtx::from(context::SharedContext::current()), topic.clone(), message.clone(), ) @@ -333,7 +343,7 @@ pub fn init_tracing( async fn main() -> anyhow::Result<()> { let tracer_provider = init_tracing("Pub/Sub")?; - let addrs = Publisher { + let addrs = Publisher:: { clients: Arc::new(Mutex::new(HashMap::new())), subscriptions: Arc::new(RwLock::new(HashMap::new())), } @@ -354,13 +364,13 @@ async fn main() -> anyhow::Result<()> { let publisher = publisher::PublisherClient::new( client::Config::default(), - map_transport_to_client(tcp::connect(addrs.publisher, Json::default).await?), + tcp::connect(addrs.publisher, Json::default).await?, ) .spawn(); publisher .publish( - &mut ClientContext::current(), + &mut SharedContext::current(), "calculus".into(), "sqrt(2)".into(), ) @@ -368,7 +378,7 @@ async fn main() -> anyhow::Result<()> { publisher .publish( - &mut ClientContext::current(), + &mut SharedContext::current(), "cool shorts".into(), "hello to all".into(), ) @@ -376,7 +386,7 @@ async fn main() -> anyhow::Result<()> { publisher .publish( - &mut ClientContext::current(), + &mut SharedContext::current(), "history".into(), "napoleon".to_string(), ) @@ -386,7 +396,7 @@ async fn main() -> anyhow::Result<()> { publisher .publish( - &mut ClientContext::current(), + &mut SharedContext::current(), "cool shorts".into(), "hello to who?".into(), ) diff --git a/tarpc/examples/readme.rs b/tarpc/examples/readme.rs index db93d2e74..359b4af8b 100644 --- a/tarpc/examples/readme.rs +++ b/tarpc/examples/readme.rs @@ -5,7 +5,7 @@ // https://opensource.org/licenses/MIT. use futures::prelude::*; -use tarpc::context::{ClientContext, SharedContext}; +use tarpc::context::{SharedContext}; use tarpc::{ ClientMessage, client, context, server::{self, Channel}, @@ -38,7 +38,7 @@ async fn spawn(fut: impl Future + Send + 'static) { #[tokio::main] async fn main() -> anyhow::Result<()> { let (client_transport, server_transport) = - transport::channel::unbounded_for_client_server_context(); + transport::channel::unbounded(); let server = server::BaseChannel::with_defaults(server_transport); tokio::spawn(server.execute(HelloServer.serve()).for_each(spawn)); @@ -51,7 +51,7 @@ async fn main() -> anyhow::Result<()> { // args as defined, with the addition of a Context, which is always the first arg. The Context // specifies a deadline and trace information which can be helpful in debugging requests. let hello = client - .hello(&mut context::ClientContext::current(), "Stim".to_string()) + .hello(&mut context::SharedContext::current(), "Stim".to_string()) .await?; println!("{hello}"); diff --git a/tarpc/examples/tls_over_tcp.rs b/tarpc/examples/tls_over_tcp.rs index 0e00cdca8..c203bf0b8 100644 --- a/tarpc/examples/tls_over_tcp.rs +++ b/tarpc/examples/tls_over_tcp.rs @@ -10,12 +10,11 @@ use std::io::{self, BufReader, Cursor}; use std::net::{IpAddr, Ipv4Addr}; use std::sync::Arc; -use tarpc::context::{ClientContext, SharedContext}; +use tarpc::context::{SharedContext}; use tarpc::serde_transport as transport; use tarpc::server::{BaseChannel, Channel}; use tarpc::tokio_serde::formats::Bincode; use tarpc::tokio_util::codec::length_delimited::LengthDelimitedCodec; -use tarpc::transport::channel::{map_transport_to_client}; use tokio::net::TcpListener; use tokio::net::TcpStream; use tokio_rustls::rustls::{ @@ -145,10 +144,9 @@ async fn main() -> anyhow::Result<()> { let stream = connector.connect(domain, stream).await?; let transport = transport::new(codec_builder.new_framed(stream), Bincode::default()); - let transport = map_transport_to_client(transport); let answer = PingServiceClient::new(Default::default(), transport) .spawn() - .ping(&mut ClientContext::current()) + .ping(&mut SharedContext::current()) .await?; println!("ping answer: {answer}"); diff --git a/tarpc/examples/tracing.rs b/tarpc/examples/tracing.rs index b69e0c1a0..525a16a47 100644 --- a/tarpc/examples/tracing.rs +++ b/tarpc/examples/tracing.rs @@ -19,8 +19,8 @@ use std::{ atomic::{AtomicBool, Ordering}, }, }; -use tarpc::context::{ClientContext, SharedContext}; -use tarpc::transport::channel::{map_transport_to_client}; +use std::marker::PhantomData; +use tarpc::context::{ExtractContext, SharedContext}; use tarpc::{ ClientMessage, RequestName, Response, ServerError, Transport, client::{ @@ -65,18 +65,20 @@ impl AddService for AddServer { } #[derive(Clone)] -struct DoubleServer { - add_client: add::AddClient, +struct DoubleServer { + add_client: add::AddClient, + ghost: PhantomData } -impl DoubleService for DoubleServer +impl DoubleService for DoubleServer where - Stub: AddStub + Clone + Send + Sync + 'static, + Stub: AddStub + Clone + Send + Sync + 'static, + ClientCtx: From + Send + Sync + 'static { type Context = SharedContext; async fn double(self, _: &mut Self::Context, x: i32) -> Result { self.add_client - .add(&mut context::ClientContext::current(), x, x) + .add(&mut ClientCtx::from(context::SharedContext::current()), x, x) .await .map_err(|e| e.to_string()) } @@ -127,18 +129,19 @@ where Ok((listener, addr)) } -fn make_stub( - backends: [impl Transport>, Response> +fn make_stub( + backends: [impl Transport>, Response> + Send + Sync + 'static; N], ) -> retry::Retry< impl Fn(&Result, u32) -> bool + Clone, - load_balance::RoundRobin, Resp>>, + load_balance::RoundRobin, Resp, ClientCtx>>, > where Req: RequestName + Send + Sync + 'static, Resp: Send + Sync + 'static, + ClientCtx: ExtractContext + From + Send + Sync + 'static { let stub = load_balance::RoundRobin::new( backends @@ -184,8 +187,8 @@ async fn main() -> anyhow::Result<()> { tokio::spawn(spawn_incoming(add_server.execute(server))); let add_client = add::AddClient::from(make_stub([ - map_transport_to_client(tarpc::serde_transport::tcp::connect(addr1, Json::default).await?), - map_transport_to_client(tarpc::serde_transport::tcp::connect(addr2, Json::default).await?), + tarpc::serde_transport::tcp::connect(addr1, Json::default).await?, + tarpc::serde_transport::tcp::connect(addr2, Json::default).await?, ])); let double_listener = tarpc::serde_transport::tcp::listen("localhost:0", Json::default) @@ -193,11 +196,10 @@ async fn main() -> anyhow::Result<()> { .filter_map(|r| future::ready(r.ok())); let addr = double_listener.get_ref().local_addr(); let double_server = double_listener.map(BaseChannel::with_defaults).take(1); - let server = DoubleServer { add_client }.serve(); + let server = DoubleServer::<_, SharedContext> { add_client, ghost: PhantomData }.serve(); tokio::spawn(spawn_incoming(double_server.execute(server))); let to_double_server = tarpc::serde_transport::tcp::connect(addr, Json::default).await?; - let to_double_server = map_transport_to_client(to_double_server); let double_client = double::DoubleClient::new(client::Config::default(), to_double_server).spawn(); @@ -205,7 +207,7 @@ async fn main() -> anyhow::Result<()> { tracing::info!( "{:?}", double_client - .double(&mut context::ClientContext::current(), 1) + .double(&mut context::SharedContext::current(), 1) .await? ); } diff --git a/tarpc/src/client.rs b/tarpc/src/client.rs index b6763a9b2..40ba7e461 100644 --- a/tarpc/src/client.rs +++ b/tarpc/src/client.rs @@ -9,7 +9,7 @@ mod in_flight_requests; pub mod stub; -use crate::context::{ClientContext, ExtractContext, SharedContext}; +use crate::context::{ExtractContext, SharedContext}; use crate::{ ChannelError, ClientMessage, Request, RequestName, Response, ServerError, Transport, cancellations::{CanceledRequests, RequestCancellation, cancellations}, @@ -30,6 +30,7 @@ use std::{ }, time::SystemTime, }; +use std::marker::PhantomData; use tokio::sync::{mpsc, oneshot}; use tracing::Span; @@ -96,27 +97,32 @@ const _CHECK_USIZE: () = assert!( /// Handles communication from the client to request dispatch. #[derive(Debug)] -pub struct Channel { +pub struct Channel { to_dispatch: mpsc::Sender>, /// Channel to send a cancel message to the dispatcher. cancellation: RequestCancellation, /// The ID to use for the next request to stage. next_request_id: Arc, + + ///TODO: Document + ghost: PhantomData } -impl Clone for Channel { +impl Clone for Channel { fn clone(&self) -> Self { Self { to_dispatch: self.to_dispatch.clone(), cancellation: self.cancellation.clone(), next_request_id: self.next_request_id.clone(), + ghost: PhantomData } } } -impl Channel +impl Channel where Req: RequestName, + ClientCtx: ExtractContext { /// Sends a request to the dispatch task to forward to the server, returning a [`Future`] that /// resolves to the response. @@ -129,9 +135,9 @@ where otel.kind = "client", otel.name = %request.name()) )] - pub async fn call>( + pub async fn call( &self, - ctx: &mut Ctx, + ctx: &mut ClientCtx, request: Req, ) -> Result { let span = Span::current(); @@ -245,12 +251,12 @@ impl Drop for ResponseGuard<'_, Resp> { /// Returns a channel and dispatcher that manages the lifecycle of requests initiated by the /// channel. -pub fn new( +pub fn new( config: Config, transport: C, -) -> NewClient, RequestDispatch> +) -> NewClient, RequestDispatch> where - C: Transport, Response>, + C: Transport, Response>, { let (to_dispatch, pending_requests) = mpsc::channel(config.pending_request_buffer); let (cancellation, canceled_requests) = cancellations(); @@ -260,6 +266,7 @@ where to_dispatch, cancellation, next_request_id: Arc::new(AtomicUsize::new(0)), + ghost: PhantomData }, dispatch: RequestDispatch { config, @@ -268,6 +275,7 @@ where in_flight_requests: InFlightRequests::default(), pending_requests, terminal_error: None, + ghost: PhantomData }, } } @@ -277,7 +285,7 @@ where #[must_use] #[pin_project()] #[derive(Debug)] -pub struct RequestDispatch { +pub struct RequestDispatch { /// Writes requests to the wire and reads responses off the wire. #[pin] transport: Fuse, @@ -294,11 +302,14 @@ pub struct RequestDispatch { /// RequestDispatch::poll, which relies on downcasting the Any to a concrete error type /// determined within the poll function. terminal_error: Option>, + + ghost: PhantomData, } -impl RequestDispatch +impl RequestDispatch where - C: Transport, Response>, + C: Transport, Response>, + ClientCtx: ExtractContext + From { fn in_flight_requests<'a>( self: &'a mut Pin<&mut Self>, @@ -321,7 +332,7 @@ where fn start_send( self: &mut Pin<&mut Self>, - message: ClientMessage, + message: ClientMessage, ) -> Result<(), C::Error> { self.transport_pin_mut().start_send(message) } @@ -531,7 +542,7 @@ where let request = ClientMessage::Request(Request { id: request_id, message: request, - context: ClientContext::new(ctx), + context: ctx.into(), }); self.in_flight_requests() @@ -581,10 +592,10 @@ where } /// Sends a server response to the client task that initiated the associated request. - fn complete(mut self: Pin<&mut Self>, response: Response) -> bool { + fn complete(mut self: Pin<&mut Self>, response: Response) -> bool { if let Some(span) = self.in_flight_requests().complete_request( response.request_id, - response.message.map_err(RpcError::Server).map(|m| (response.context.shared_context, m)), + response.message.map_err(RpcError::Server).map(|m| (response.context.extract(), m)), ) { let _entered = span.enter(); tracing::debug!("ReceiveResponse"); @@ -659,9 +670,10 @@ where } } -impl Future for RequestDispatch +impl Future for RequestDispatch where - C: Transport, Response>, + C: Transport, Response>, + ClientCtx: ExtractContext + From { type Output = Result<(), ChannelError>; @@ -704,7 +716,7 @@ mod tests { use super::{ Channel, DispatchRequest, RequestDispatch, ResponseGuard, RpcError, cancellations, }; - use crate::context::{ClientContext, SharedContext}; + use crate::context::{SharedContext}; use crate::{ ChannelError, ClientMessage, Response, client::{Config, in_flight_requests::InFlightRequests}, @@ -735,7 +747,7 @@ mod tests { let cx = &mut Context::from_waker(noop_waker_ref()); let (tx, mut rx) = oneshot::channel(); - let context = ClientContext::current(); + let context = SharedContext::current(); dispatch .in_flight_requests @@ -750,7 +762,7 @@ mod tests { server_channel .send(Response { request_id: 0, - context: ClientContext::current(), + context: SharedContext::current(), message: Ok("Resp".into()), }) .await @@ -796,7 +808,7 @@ mod tests { #[tokio::test] async fn stage_request() { - let (mut dispatch, mut channel, _server_channel) = set_up(); + let (mut dispatch, mut channel, _server_channel) = set_up::(); let cx = &mut Context::from_waker(noop_waker_ref()); let (tx, mut rx) = oneshot::channel(); @@ -826,7 +838,7 @@ mod tests { &mut server_channel, Response { request_id: 0, - context: ClientContext::current(), + context: SharedContext::current(), message: Ok("hello".into()), }, ) @@ -837,7 +849,7 @@ mod tests { #[allow(unstable_name_collisions)] #[tokio::test] async fn stage_request_response_future_dropped_is_canceled_before_sending() { - let (mut dispatch, mut channel, _server_channel) = set_up(); + let (mut dispatch, mut channel, _server_channel) = set_up::(); let cx = &mut Context::from_waker(noop_waker_ref()); let (tx, mut rx) = oneshot::channel(); @@ -853,7 +865,7 @@ mod tests { #[allow(unstable_name_collisions)] #[tokio::test] async fn stage_request_response_future_dropped_is_canceled_after_sending() { - let (mut dispatch, mut channel, _server_channel) = set_up(); + let (mut dispatch, mut channel, _server_channel) = set_up::(); let cx = &mut Context::from_waker(noop_waker_ref()); let (tx, mut rx) = oneshot::channel(); @@ -874,7 +886,7 @@ mod tests { #[tokio::test] async fn stage_request_response_closed_skipped() { - let (mut dispatch, mut channel, _server_channel) = set_up(); + let (mut dispatch, mut channel, _server_channel) = set_up::(); let cx = &mut Context::from_waker(noop_waker_ref()); let (tx, mut rx) = oneshot::channel(); @@ -890,7 +902,7 @@ mod tests { #[tokio::test] async fn test_permit_before_transport_error() { let _ = tracing_subscriber::fmt().with_test_writer().try_init(); - let (mut dispatch, mut channel, mut cx) = set_up_always_err(TransportError::Flush); + let (mut dispatch, mut channel, mut cx) = set_up_always_err::(TransportError::Flush); let (tx, mut rx) = oneshot::channel(); // reserve succeeds let permit = reserve_for_send(&mut channel, tx, &mut rx).await; @@ -907,11 +919,11 @@ mod tests { #[tokio::test] async fn test_shutdown() { let _ = tracing_subscriber::fmt().with_test_writer().try_init(); - let (dispatch, channel, _server_channel) = set_up(); + let (dispatch, channel, _server_channel) = set_up::(); drop(dispatch); // error on send let resp = channel - .call(&mut ClientContext::current(), "hi".to_string()) + .call(&mut SharedContext::current(), "hi".to_string()) .await; assert_matches!(resp, Err(RpcError::Shutdown)); } @@ -919,7 +931,7 @@ mod tests { #[tokio::test] async fn test_transport_error_write() { let cause = TransportError::Write; - let (mut dispatch, mut channel, mut cx) = set_up_always_err(cause); + let (mut dispatch, mut channel, mut cx) = set_up_always_err::(cause); let (tx, mut rx) = oneshot::channel(); let resp = send_request(&mut channel, "hi", tx, &mut rx).await; @@ -942,7 +954,7 @@ mod tests { #[tokio::test] async fn test_transport_error_read() { let cause = TransportError::Read; - let (mut dispatch, mut channel, mut cx) = set_up_always_err(cause); + let (mut dispatch, mut channel, mut cx) = set_up_always_err::(cause); let (tx, mut rx) = oneshot::channel(); let resp = send_request(&mut channel, "hi", tx, &mut rx).await; assert_eq!( @@ -959,7 +971,7 @@ mod tests { #[tokio::test] async fn test_transport_error_ready() { let cause = TransportError::Ready; - let (mut dispatch, _, mut cx) = set_up_always_err(cause); + let (mut dispatch, _, mut cx) = set_up_always_err::(cause); assert_eq!( dispatch.as_mut().poll(&mut cx), Poll::Ready(Err(ChannelError::Ready(Arc::new(cause)))) @@ -969,7 +981,7 @@ mod tests { #[tokio::test] async fn test_transport_error_flush() { let cause = TransportError::Flush; - let (mut dispatch, _, mut cx) = set_up_always_err(cause); + let (mut dispatch, _, mut cx) = set_up_always_err::(cause); assert_eq!( dispatch.as_mut().poll(&mut cx), Poll::Ready(Err(ChannelError::Flush(Arc::new(cause)))) @@ -979,7 +991,7 @@ mod tests { #[tokio::test] async fn test_transport_error_close() { let cause = TransportError::Close; - let (mut dispatch, channel, mut cx) = set_up_always_err(cause); + let (mut dispatch, channel, mut cx) = set_up_always_err::(cause); drop(channel); assert_eq!( dispatch.as_mut().poll(&mut cx), @@ -988,34 +1000,36 @@ mod tests { } /// Sets up a RequestDispatch with a transport that always errors. - fn set_up_always_err( + fn set_up_always_err( cause: TransportError, ) -> ( - Pin>>>, - Channel, + Pin>>>, + Channel, Context<'static>, ) { let (to_dispatch, pending_requests) = mpsc::channel(1); let (cancellation, canceled_requests) = cancellations(); - let transport: AlwaysErrorTransport = AlwaysErrorTransport(cause, PhantomData); - let dispatch = Box::pin(RequestDispatch:: { + let transport: AlwaysErrorTransport = AlwaysErrorTransport(cause, PhantomData); + let dispatch = Box::pin(RequestDispatch:: { transport: transport.fuse(), pending_requests, canceled_requests, in_flight_requests: InFlightRequests::default(), config: Config::default(), terminal_error: None, + ghost: PhantomData }); let channel = Channel { to_dispatch, cancellation, next_request_id: Arc::new(AtomicUsize::new(0)), + ghost: PhantomData }; let cx = Context::from_waker(noop_waker_ref()); (dispatch, channel, cx) } - struct AlwaysErrorTransport(TransportError, PhantomData); + struct AlwaysErrorTransport(TransportError, PhantomData<( I, ClientCtx)>); #[derive(Debug, Error, PartialEq, Eq, Clone, Copy)] enum TransportError { @@ -1032,7 +1046,7 @@ mod tests { } } - impl Sink for AlwaysErrorTransport { + impl Sink for AlwaysErrorTransport { type Error = TransportError; fn poll_ready(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll> { match self.0 { @@ -1064,8 +1078,8 @@ mod tests { } } - impl Stream for AlwaysErrorTransport { - type Item = Result, TransportError>; + impl Stream for AlwaysErrorTransport { + type Item = Result, TransportError>; fn poll_next(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll> { if matches!(self.0, TransportError::Read) { Poll::Ready(Some(Err(self.0))) @@ -1075,21 +1089,22 @@ mod tests { } } - fn set_up() -> ( + fn set_up() -> ( Pin< Box< RequestDispatch< String, String, + ClientCtx, UnboundedChannel< - Response, - ClientMessage, + Response, + ClientMessage, >, >, >, >, - Channel, - UnboundedChannel, Response>, + Channel, + UnboundedChannel, Response>, ) { let _ = tracing_subscriber::fmt().with_test_writer().try_init(); @@ -1097,26 +1112,28 @@ mod tests { let (cancellation, canceled_requests) = cancellations(); let (client_channel, server_channel) = transport::channel::unbounded(); - let dispatch = RequestDispatch:: { + let dispatch = RequestDispatch:: { transport: client_channel.fuse(), pending_requests, canceled_requests, in_flight_requests: InFlightRequests::default(), config: Config::default(), terminal_error: None, + ghost: PhantomData }; let channel = Channel { to_dispatch, cancellation, next_request_id: Arc::new(AtomicUsize::new(0)), + ghost: PhantomData }; (Box::pin(dispatch), channel, server_channel) } - async fn send_request<'a>( - channel: &'a mut Channel, + async fn send_request<'a, ClientCtx>( + channel: &'a mut Channel, request: &str, response_completion: oneshot::Sender>, response: &'a mut oneshot::Receiver>, @@ -1140,8 +1157,8 @@ mod tests { response_guard } - async fn reserve_for_send<'a>( - channel: &'a mut Channel, + async fn reserve_for_send<'a, ClientCtx>( + channel: &'a mut Channel, response_completion: oneshot::Sender>, response: &'a mut oneshot::Receiver>, ) -> impl FnOnce(&str) -> ResponseGuard<'a, String> { @@ -1166,12 +1183,12 @@ mod tests { } } - async fn send_response( + async fn send_response( channel: &mut UnboundedChannel< - ClientMessage, - Response, + ClientMessage, + Response, >, - response: Response, + response: Response, ) { channel.send(response).await.unwrap(); } diff --git a/tarpc/src/client/stub.rs b/tarpc/src/client/stub.rs index 9989f0577..992f6d611 100644 --- a/tarpc/src/client/stub.rs +++ b/tarpc/src/client/stub.rs @@ -5,7 +5,7 @@ use crate::{ client::{Channel, RpcError}, server::Serve, }; -use crate::context::{ClientContext, SharedContext}; +use crate::context::{ExtractContext, SharedContext}; pub mod load_balance; pub mod retry; @@ -24,27 +24,28 @@ pub trait Stub { type Resp; ///TODO: document - type ServerCtx; + type ClientCtx; /// Calls a remote service. async fn call( &self, - ctx: &mut Self::ServerCtx, + ctx: &mut Self::ClientCtx, request: Self::Req, ) -> Result; } -impl Stub for Channel +impl Stub for Channel where Req: RequestName, + ClientCtx: ExtractContext { type Req = Req; type Resp = Resp; - type ServerCtx = ClientContext; + type ClientCtx = ClientCtx; async fn call( &self, - ctx: &mut Self::ServerCtx, + ctx: &mut Self::ClientCtx, request: Req, ) -> Result { Self::call(self, ctx, request).await @@ -57,13 +58,13 @@ where { type Req = S::Req; type Resp = S::Resp; - type ServerCtx = ClientContext; + type ClientCtx = SharedContext; async fn call( &self, - ctx: &mut ClientContext, + ctx: &mut Self::ClientCtx, req: Self::Req, ) -> Result { - let mut server_ctx = ctx.shared_context.clone(); + let mut server_ctx = ctx.clone(); let res = self .clone() @@ -71,7 +72,7 @@ where .await .map_err(RpcError::Server); - ctx.shared_context = server_ctx; + *ctx = server_ctx; res } diff --git a/tarpc/src/client/stub/load_balance.rs b/tarpc/src/client/stub/load_balance.rs index 5b319c6c8..60efafc91 100644 --- a/tarpc/src/client/stub/load_balance.rs +++ b/tarpc/src/client/stub/load_balance.rs @@ -16,11 +16,11 @@ mod round_robin { { type Req = Stub::Req; type Resp = Stub::Resp; - type ServerCtx = Stub::ServerCtx; + type ClientCtx = Stub::ClientCtx; async fn call( &self, - ctx: &mut Self::ServerCtx, + ctx: &mut Self::ClientCtx, request: Self::Req, ) -> Result { let next = self.stubs.next(); @@ -115,11 +115,11 @@ mod consistent_hash { { type Req = Stub::Req; type Resp = Stub::Resp; - type ServerCtx = Stub::ServerCtx; + type ClientCtx = Stub::ClientCtx; async fn call( &self, - ctx: &mut Self::ServerCtx, + ctx: &mut Self::ClientCtx, request: Self::Req, ) -> Result { let index = usize::try_from(self.hasher.hash_one(&request) % self.stubs_len).expect( @@ -201,17 +201,17 @@ mod consistent_hash { for _ in 0..2 { let resp = stub - .call(&mut context::ClientContext::current(), 'a') + .call(&mut context::SharedContext::current(), 'a') .await?; assert_eq!(resp, 1); let resp = stub - .call(&mut context::ClientContext::current(), 'b') + .call(&mut context::SharedContext::current(), 'b') .await?; assert_eq!(resp, 2); let resp = stub - .call(&mut context::ClientContext::current(), 'c') + .call(&mut context::SharedContext::current(), 'c') .await?; assert_eq!(resp, 3); } diff --git a/tarpc/src/client/stub/mock.rs b/tarpc/src/client/stub/mock.rs index 9a22d101e..577ef5362 100644 --- a/tarpc/src/client/stub/mock.rs +++ b/tarpc/src/client/stub/mock.rs @@ -31,11 +31,11 @@ where { type Req = Req; type Resp = Resp; - type ServerCtx = ServerCtx; + type ClientCtx = ServerCtx; async fn call( &self, - _: &mut Self::ServerCtx, + _: &mut Self::ClientCtx, request: Self::Req, ) -> Result { self.responses diff --git a/tarpc/src/client/stub/retry.rs b/tarpc/src/client/stub/retry.rs index 2cf950aed..5499f60e4 100644 --- a/tarpc/src/client/stub/retry.rs +++ b/tarpc/src/client/stub/retry.rs @@ -14,11 +14,11 @@ where { type Req = Req; type Resp = Stub::Resp; - type ServerCtx = Stub::ServerCtx; + type ClientCtx = Stub::ClientCtx; async fn call( &self, - ctx: &mut Self::ServerCtx, + ctx: &mut Self::ClientCtx, request: Self::Req, ) -> Result { let request = Arc::new(request); diff --git a/tarpc/src/context.rs b/tarpc/src/context.rs index 5cc9389f1..5ca5b8256 100644 --- a/tarpc/src/context.rs +++ b/tarpc/src/context.rs @@ -56,74 +56,6 @@ impl ExtractContext for T where T: Clone { } } -/// Request context that carries request-scoped client side information like deadlines and trace information -/// as well as any server side extensions defined by the transport, hooks and stubs. -/// The shared part of the context is sent from client to server, while the client side extensions are only seen on the client side. -/// -/// The context should not be stored directly in a stub implementation, because the context will -/// be different for each request in scope. -#[derive(Debug)] -pub struct ClientContext { - /// Shared context sent from client to server which contains information used by both sides. - pub shared_context: SharedContext, - - /// Client side extensions that are not seen by the server - /// XXX, YYY, and ZZZ can use this to store per-request data, and communicate with eachother. - /// Note that this is NOT sent to the server, and they will always see an empty map here. - pub client_context: anymap3::Map, -} - -impl ClientContext { - /// Creates a new ServerContext from the given SharedContext with no extensions. - pub fn new(shared_context: SharedContext) -> Self { - Self { - shared_context, - client_context: anymap3::Map::new(), - } - } - - /// Creates a new ServerContext for the current shared context with no extensions. - pub fn current() -> Self { - Self::new(SharedContext::current()) - } -} - -impl ExtractContext for ClientContext { - fn extract(&self) -> SharedContext { - self.shared_context.clone() - } - - fn update(&mut self, value: SharedContext) { - self.shared_context = value - } -} - -impl ExtractContext for ServerContext { - fn extract(&self) -> SharedContext { - self.shared_context.clone() - } - - fn update(&mut self, value: SharedContext) { - self.shared_context = value - } -} - - - -impl Deref for ClientContext { - type Target = SharedContext; - - fn deref(&self) -> &Self::Target { - &self.shared_context - } -} - -impl DerefMut for ClientContext { - fn deref_mut(&mut self) -> &mut Self::Target { - &mut self.shared_context - } -} - #[cfg(feature = "serde1")] mod absolute_to_relative_time { pub use serde::{Deserialize, Deserializer, Serialize, Serializer}; diff --git a/tarpc/src/lib.rs b/tarpc/src/lib.rs index e0869d9f6..fc79e3056 100644 --- a/tarpc/src/lib.rs +++ b/tarpc/src/lib.rs @@ -145,7 +145,7 @@ //! # use tarpc::{ //! # ClientMessage, //! # client, context, -//! # context::{ClientContext, SharedContext}, +//! # context::{SharedContext}, //! # transport::channel, //! # server::{self, Channel}, //! # }; @@ -172,7 +172,8 @@ //! # #[cfg(feature = "tokio1")] //! #[tokio::main] //! async fn main() -> anyhow::Result<()> { -//! let (client_transport, server_transport) = channel::unbounded_for_client_server_context(); +//! use futures::future::Shared; +//! let (client_transport, server_transport) = channel::unbounded(); //! let server = server::BaseChannel::with_defaults(server_transport); //! tokio::spawn( //! server.execute(HelloServer.serve()) @@ -183,12 +184,12 @@ //! //! // WorldClient is generated by the #[tarpc::service] attribute. It has a constructor `new` //! // that takes a config and any Transport as input. -//! let mut client = WorldClient::new(client::Config::default(), client_transport).spawn(); +//! let mut client = WorldClient::::new(client::Config::default(), client_transport).spawn(); //! //! // The client has an RPC method for each RPC defined in the annotated trait. It takes the same //! // args as defined, with the addition of a Context, which is always the first arg. The Context //! // specifies a deadline and trace information which can be helpful in debugging requests. -//! let mut context = context::ClientContext::current(); +//! let mut context = context::SharedContext::current(); //! let hello = client.hello(&mut context, "Stim".to_string()).await?; //! //! println!("{hello}"); diff --git a/tarpc/src/server.rs b/tarpc/src/server.rs index e6c395836..7d345a203 100644 --- a/tarpc/src/server.rs +++ b/tarpc/src/server.rs @@ -363,7 +363,7 @@ where /// use tarpc::{ /// ClientMessage, /// context, - /// context::{ClientContext, SharedContext}, + /// context::{SharedContext}, /// client::{self, NewClient}, /// server::{self, BaseChannel, Channel, serve}, /// transport, @@ -372,7 +372,7 @@ where /// /// #[tokio::main] /// async fn main() { - /// let (tx, rx) = transport::channel::unbounded_for_client_server_context(); + /// let (tx, rx) = transport::channel::unbounded(); /// let server = BaseChannel::new(server::Config::default(), rx); /// let NewClient { client, dispatch } = client::new(client::Config::default(), tx); /// tokio::spawn(dispatch); @@ -383,7 +383,7 @@ where /// tokio::spawn(request.execute(serve(|_, i| async move { Ok(i + 1) }.boxed()))); /// } /// }); - /// let mut context = context::ClientContext::current(); + /// let mut context = context::SharedContext::current(); /// assert_eq!(client.call(&mut context, 1).await.unwrap(), 2); /// } /// ``` @@ -407,7 +407,7 @@ where /// # Example /// /// ```rust - /// use tarpc::{ClientMessage, context, client, server::{self, BaseChannel, Channel, serve}, transport, context::{ClientContext, SharedContext}}; + /// use tarpc::{ClientMessage, context, client, server::{self, BaseChannel, Channel, serve}, transport, context::{SharedContext}}; /// use futures::prelude::*; /// use tracing_subscriber::prelude::*; /// @@ -416,7 +416,7 @@ where /// # #[cfg(feature = "tokio1")] /// #[tokio::main] /// async fn main() { - /// let (tx, rx) = transport::channel::unbounded_for_client_server_context(); + /// let (tx, rx) = transport::channel::unbounded(); /// let client = client::new(client::Config::default(), tx).spawn(); /// let channel = BaseChannel::with_defaults(rx); /// tokio::spawn( @@ -424,7 +424,7 @@ where /// .for_each(|response| async move { /// tokio::spawn(response); /// }.boxed())); - /// let mut context = context::ClientContext::current(); + /// let mut context = context::SharedContext::current(); /// assert_eq!( /// client.call(&mut context, 1).await.unwrap(), /// 2); @@ -767,7 +767,7 @@ where /// /// ```rust /// use tarpc::{context, client, server::{self, BaseChannel, Channel, serve}, transport, ClientMessage}; - /// use tarpc::context::{ClientContext, SharedContext}; + /// use tarpc::context::{SharedContext}; /// use futures::prelude::*; /// /// # #[cfg(not(feature = "tokio1"))] @@ -775,7 +775,7 @@ where /// # #[cfg(feature = "tokio1")] /// #[tokio::main] /// async fn main() { - /// let (tx, rx) = transport::channel::unbounded_for_client_server_context(); + /// let (tx, rx) = transport::channel::unbounded(); /// let requests = BaseChannel::new(server::Config::default(), rx).requests(); /// let client = client::new(client::Config::default(), tx).spawn(); /// tokio::spawn( @@ -783,7 +783,7 @@ where /// .for_each(|response| async move { /// tokio::spawn(response); /// }.boxed())); - /// let mut context = context::ClientContext::current(); + /// let mut context = context::SharedContext::current(); /// assert_eq!(client.call(&mut context, 1).await.unwrap(), 2); /// } /// ``` @@ -872,7 +872,7 @@ impl InFlightRequest { /// use tarpc::{ /// ClientMessage, /// context, - /// context::{ClientContext, SharedContext}, + /// context::{SharedContext}, /// client::{self, NewClient}, /// server::{self, BaseChannel, Channel, serve}, /// transport, @@ -881,7 +881,7 @@ impl InFlightRequest { /// /// #[tokio::main] /// async fn main() { - /// let (tx, rx) = transport::channel::unbounded_for_client_server_context(); + /// let (tx, rx) = transport::channel::unbounded(); /// let server = BaseChannel::new(server::Config::default(), rx); /// let NewClient { client, dispatch } = client::new(client::Config::default(), tx); /// tokio::spawn(dispatch); @@ -892,7 +892,7 @@ impl InFlightRequest { /// in_flight_request.execute(serve(|_, i| async move { Ok(i + 1) }.boxed())).await; /// } /// }); - /// let mut context = context::ClientContext::current(); + /// let mut context = context::SharedContext::current(); /// assert_eq!(client.call(&mut context, 1).await.unwrap(), 2); /// } /// ``` diff --git a/tarpc/src/server/incoming.rs b/tarpc/src/server/incoming.rs index 2baa27c89..6a71124b1 100644 --- a/tarpc/src/server/incoming.rs +++ b/tarpc/src/server/incoming.rs @@ -50,7 +50,7 @@ where /// use tarpc::{ /// ClientMessage, /// context, -/// context::{ClientContext, SharedContext}, +/// context::{SharedContext}, /// client::{self, NewClient}, /// server::{self, BaseChannel, Channel, incoming::{Incoming, spawn_incoming}, serve}, /// transport, @@ -59,7 +59,7 @@ where /// /// #[tokio::main] /// async fn main() { -/// let (tx, rx) = transport::channel::unbounded_for_client_server_context(); +/// let (tx, rx) = transport::channel::unbounded(); /// let NewClient { client, dispatch } = client::new(client::Config::default(), tx); /// tokio::spawn(dispatch); /// @@ -67,7 +67,7 @@ where /// BaseChannel::new(server::Config::default(), rx) /// }).execute(serve(|_, i| async move { Ok(i + 1) }.boxed())); /// tokio::spawn(spawn_incoming(incoming)); -/// let mut context = context::ClientContext::current(); +/// let mut context = context::SharedContext::current(); /// assert_eq!(client.call(&mut context, 1).await.unwrap(), 2); /// } /// ``` diff --git a/tarpc/src/transport/channel.rs b/tarpc/src/transport/channel.rs index 476f60738..de9a8afdc 100644 --- a/tarpc/src/transport/channel.rs +++ b/tarpc/src/transport/channel.rs @@ -6,14 +6,9 @@ //! Transports backed by in-memory channels. -use crate::context::{ClientContext, SharedContext}; -use crate::{ClientMessage, Response, Transport}; -use futures::future::{Ready}; -use futures::sink::With; -use futures::{Sink, SinkExt, Stream, TryStreamExt, task::*}; +use futures::{Sink, Stream, task::*}; use pin_project::pin_project; -use std::{error::Error, future, pin::Pin}; -use std::convert::identity; +use std::{error::Error, pin::Pin}; use tokio::sync::mpsc; /// Errors that occur in the sending or receiving of messages over a channel. @@ -44,97 +39,6 @@ pub fn unbounded() -> ( ) } -/// Returns two mapped unbounded channel peers. Each [`Stream`] yields items sent through the other's -/// [`Sink`]. -pub fn unbounded_mapped< - SerializedSinkItem, - SerializedItem, - ClientSinkItem, - ServerSinkItem, - ClientItem, - ServerItem, - F, - G, - H, - I, ->( - mut f: F, - mut g: G, - mut h: H, - mut i: I, -) -> ( - impl Transport, - impl Transport, -) -where - F: FnMut(ClientSinkItem) -> SerializedSinkItem, - G: FnMut(SerializedSinkItem) -> ServerSinkItem, - H: FnMut(SerializedItem) -> ClientItem, - I: FnMut(ServerItem) -> SerializedItem, -{ - let (client, server) = unbounded(); - - let client = client - .with(move |msg: ClientSinkItem| future::ready(Ok(f(msg)))) - .map_ok(move |msg: SerializedItem| h(msg)); - let server = server - .map_ok(move |msg: SerializedSinkItem| g(msg)) - .with(move |msg: ServerItem| future::ready(Ok(i(msg)))); - - (client, server) -} - -/// Convenience functino to return two mapped unbounded channel peers for a basechannel and a client implementation. Each [`Stream`] yields items sent through the other's -/// [`Sink`]. -pub fn unbounded_for_client_server_context() -> ( - impl Transport, Response>, - impl Transport, ClientMessage>, -) { - unbounded_mapped( - map_req_client_context_to_shared, - identity, - map_resp_shared_context_to_client, - identity, - ) -} - -/// Convenience function to map a ClientMessage with ClientContext to one with SharedContext. -fn map_req_client_context_to_shared( - msg: ClientMessage, -) -> ClientMessage { - msg.map_context(|ctx| ctx.shared_context) -} - -/// Convenience function to map a ClientMessage with SharedContext to one with ClientContext. -fn map_resp_shared_context_to_client( - msg: Response, -) -> Response { - msg.map_context(ClientContext::new) -} - -/// TODO: document -/// Yuck, but impl trait will loose our ability to do t.as_ref() -pub fn map_transport_to_client( - t: T, -) -> futures::stream::MapOk< - With< - T, - ClientMessage, - ClientMessage, - Ready, E>>, - fn(ClientMessage) -> Ready, E>>, - >, - fn(Response) -> Response, -> -where - T: Transport, Response>, - E: From -{ - let f: fn(ClientMessage) -> Ready, E>> = |resp| futures::future::ok(map_req_client_context_to_shared(resp)); - - t.with(f).map_ok(map_resp_shared_context_to_client) -} - /// A bi-directional channel backed by an [`UnboundedSender`](mpsc::UnboundedSender) /// and [`UnboundedReceiver`](mpsc::UnboundedReceiver). #[derive(Debug)] @@ -271,6 +175,7 @@ mod tests { use futures::{prelude::*, stream}; use std::io; use tracing::trace; + use crate::context::SharedContext; #[test] fn ensure_is_transport() { @@ -284,12 +189,12 @@ mod tests { let _ = tracing_subscriber::fmt::try_init(); let (client_channel, server_channel) = - transport::channel::unbounded_for_client_server_context(); + transport::channel::unbounded(); tokio::spawn( stream::once(future::ready(server_channel)) .map(BaseChannel::with_defaults) - .execute(serve(|_ctx, request: String| { + .execute(serve(|_ctx: &mut SharedContext, request: String| { async move { request.parse::().map_err(|_| { ServerError::new( @@ -308,10 +213,10 @@ mod tests { let client = client::new(client::Config::default(), client_channel).spawn(); let response1 = client - .call(&mut context::ClientContext::current(), "123".into()) + .call(&mut context::SharedContext::current(), "123".into()) .await; let response2 = client - .call(&mut context::ClientContext::current(), "abc".into()) + .call(&mut context::SharedContext::current(), "abc".into()) .await; trace!("response1: {:?}, response2: {:?}", response1, response2); diff --git a/tarpc/tests/compile_fail/must_use_request_dispatch.rs b/tarpc/tests/compile_fail/must_use_request_dispatch.rs index 2915d3237..a5238fe8b 100644 --- a/tarpc/tests/compile_fail/must_use_request_dispatch.rs +++ b/tarpc/tests/compile_fail/must_use_request_dispatch.rs @@ -1,5 +1,5 @@ use tarpc::client; - +use tarpc::context::SharedContext; #[tarpc::service] trait World { async fn hello(name: String) -> String; @@ -10,6 +10,6 @@ fn main() { #[deny(unused_must_use)] { - WorldClient::new(client::Config::default(), client_transport).dispatch; + WorldClient::::new(client::Config::default(), client_transport).dispatch; } } diff --git a/tarpc/tests/compile_fail/must_use_request_dispatch.stderr b/tarpc/tests/compile_fail/must_use_request_dispatch.stderr index e652cc8e8..e0ec77ff3 100644 --- a/tarpc/tests/compile_fail/must_use_request_dispatch.stderr +++ b/tarpc/tests/compile_fail/must_use_request_dispatch.stderr @@ -1,8 +1,8 @@ error: unused `RequestDispatch` that must be used --> tests/compile_fail/must_use_request_dispatch.rs:13:9 | -13 | WorldClient::new(client::Config::default(), client_transport).dispatch; - | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +13 | WorldClient::::new(client::Config::default(), client_transport).dispatch; + | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ | note: the lint level is defined here --> tests/compile_fail/must_use_request_dispatch.rs:11:12 @@ -11,5 +11,5 @@ note: the lint level is defined here | ^^^^^^^^^^^^^^^ help: use `let _ = ...` to ignore the resulting value | -13 | let _ = WorldClient::new(client::Config::default(), client_transport).dispatch; +13 | let _ = WorldClient::::new(client::Config::default(), client_transport).dispatch; | +++++++ diff --git a/tarpc/tests/dataservice.rs b/tarpc/tests/dataservice.rs index 1a1b5207a..6bcd255c4 100644 --- a/tarpc/tests/dataservice.rs +++ b/tarpc/tests/dataservice.rs @@ -1,7 +1,6 @@ use futures::prelude::*; -use tarpc::context::{ClientContext, SharedContext}; -use tarpc::transport::channel::{map_transport_to_client}; -use tarpc::{ClientMessage, serde_transport}; +use tarpc::context::{SharedContext}; +use tarpc::{serde_transport}; use tarpc::{ client, context, server::{BaseChannel, incoming::Incoming}, @@ -53,12 +52,10 @@ async fn test_call() -> anyhow::Result<()> { ); let transport = serde_transport::tcp::connect(addr, Json::default).await?; - let transport = map_transport_to_client(transport); - let client = ColorProtocolClient::new(client::Config::default(), transport).spawn(); let color = client - .get_opposite_color(&mut context::ClientContext::current(), TestData::White) + .get_opposite_color(&mut context::SharedContext::current(), TestData::White) .await?; assert_eq!(color, TestData::Black); diff --git a/tarpc/tests/service_functional.rs b/tarpc/tests/service_functional.rs index 4692b17cd..7d1f96e18 100644 --- a/tarpc/tests/service_functional.rs +++ b/tarpc/tests/service_functional.rs @@ -4,7 +4,6 @@ use futures::{ prelude::*, }; use std::time::{Duration, Instant}; -use tarpc::transport::channel::{map_transport_to_client}; use tarpc::{ ClientMessage, client::{self}, @@ -38,7 +37,7 @@ impl Service for Server { #[tokio::test] async fn sequential() { - let (tx, rx) = transport::channel::unbounded_for_client_server_context(); + let (tx, rx) = channel::unbounded(); let client = client::new(client::Config::default(), tx).spawn(); let channel = BaseChannel::with_defaults(rx); @@ -51,7 +50,7 @@ async fn sequential() { ); assert_eq!( client - .call(&mut context::ClientContext::current(), 1) + .call(&mut context::SharedContext::current(), 1) .await .unwrap(), 2 @@ -79,14 +78,14 @@ async fn dropped_channel_aborts_in_flight_requests() -> anyhow::Result<()> { let _ = tracing_subscriber::fmt::try_init(); - let (tx, rx) = transport::channel::unbounded_for_client_server_context(); + let (tx, rx) = transport::channel::unbounded(); // Set up a client that initiates a long-lived request. // The request will complete in error when the server drops the connection. tokio::spawn(async move { let client = LoopClient::new(client::Config::default(), tx).spawn(); - let mut ctx = context::ClientContext::current(); + let mut ctx = context::SharedContext::current(); ctx.deadline = Instant::now() + Duration::from_secs(60 * 60); let _ = client.r#loop(&mut ctx).await; }); @@ -125,17 +124,16 @@ async fn serde_tcp() -> anyhow::Result<()> { ); let transport = serde_transport::tcp::connect(addr, Json::default).await?; - let transport = map_transport_to_client(transport); let client = ServiceClient::new(client::Config::default(), transport).spawn(); assert_matches!( client - .add(&mut context::ClientContext::current(), 1, 2) + .add(&mut context::SharedContext::current(), 1, 2) .await, Ok(3) ); assert_matches!( - client.hey(&mut context::ClientContext::current(), "Tim".to_string()).await, + client.hey(&mut context::SharedContext::current(), "Tim".to_string()).await, Ok(ref s) if s == "Hey, Tim." ); @@ -163,16 +161,15 @@ async fn serde_uds() -> anyhow::Result<()> { ); let transport = serde_transport::unix::connect(&sock, Json::default).await?; - let transport = map_transport_to_client(transport); let client = ServiceClient::new(client::Config::default(), transport).spawn(); // Save results using socket so we can clean the socket even if our test assertions fail let res1 = client - .add(&mut context::ClientContext::current(), 1, 2) + .add(&mut context::SharedContext::current(), 1, 2) .await; let res2 = client - .hey(&mut context::ClientContext::current(), "Tim".to_string()) + .hey(&mut context::SharedContext::current(), "Tim".to_string()) .await; assert_matches!(res1, Ok(3)); @@ -185,7 +182,7 @@ async fn serde_uds() -> anyhow::Result<()> { async fn concurrent() -> anyhow::Result<()> { let _ = tracing_subscriber::fmt::try_init(); - let (tx, rx) = transport::channel::unbounded_for_client_server_context(); + let (tx, rx) = transport::channel::unbounded(); tokio::spawn( stream::once(ready(rx)) @@ -197,7 +194,7 @@ async fn concurrent() -> anyhow::Result<()> { let client = ServiceClient::new(client::Config::default(), tx).spawn(); - let mut context = context::ClientContext::current(); + let mut context = context::SharedContext::current(); let req1 = client.add(&mut context, 1, 2); assert_matches!(req1.await, Ok(3)); @@ -215,7 +212,7 @@ async fn concurrent() -> anyhow::Result<()> { async fn concurrent_join() -> anyhow::Result<()> { let _ = tracing_subscriber::fmt::try_init(); - let (tx, rx) = transport::channel::unbounded_for_client_server_context(); + let (tx, rx) = transport::channel::unbounded(); tokio::spawn( stream::once(ready(rx)) @@ -227,9 +224,9 @@ async fn concurrent_join() -> anyhow::Result<()> { let client = ServiceClient::new(client::Config::default(), tx).spawn(); - let mut context1 = context::ClientContext::current(); - let mut context2 = context::ClientContext::current(); - let mut context3 = context::ClientContext::current(); + let mut context1 = context::SharedContext::current(); + let mut context2 = context::SharedContext::current(); + let mut context3 = context::SharedContext::current(); let req1 = client.add(&mut context1, 1, 2); let req2 = client.add(&mut context2, 3, 4); @@ -252,8 +249,7 @@ async fn spawn(fut: impl Future + Send + 'static) { async fn concurrent_join_all() -> anyhow::Result<()> { let _ = tracing_subscriber::fmt::try_init(); - let (tx, rx) = transport::channel::unbounded_for_client_server_context(); - + let (tx, rx) = transport::channel::unbounded(); tokio::spawn( BaseChannel::with_defaults(rx) .execute(Server.serve()) @@ -262,8 +258,8 @@ async fn concurrent_join_all() -> anyhow::Result<()> { let client = ServiceClient::new(client::Config::default(), tx).spawn(); - let mut context1 = context::ClientContext::current(); - let mut context2 = context::ClientContext::current(); + let mut context1 = context::SharedContext::current(); + let mut context2 = context::SharedContext::current(); let req1 = client.add(&mut context1, 1, 2); let req2 = client.add(&mut context2, 3, 4); @@ -292,7 +288,7 @@ async fn counter() -> anyhow::Result<()> { } } - let (tx, rx) = channel::unbounded_for_client_server_context(); + let (tx, rx) = channel::unbounded(); tokio::task::spawn(async move { let mut requests = BaseChannel::with_defaults(rx).requests(); @@ -305,11 +301,11 @@ async fn counter() -> anyhow::Result<()> { let client = CounterClient::new(client::Config::default(), tx).spawn(); assert_matches!( - client.count(&mut context::ClientContext::current()).await, + client.count(&mut context::SharedContext::current()).await, Ok(1) ); assert_matches!( - client.count(&mut context::ClientContext::current()).await, + client.count(&mut context::SharedContext::current()).await, Ok(2) ); From 8bd243a19d2f5a207e45539d7835b49a7ed83911 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C3=81kos=20Vandra-Meyer?= Date: Wed, 26 Nov 2025 16:15:32 +0100 Subject: [PATCH 15/23] fix merge conflict --- tarpc/src/context.rs | 1 - 1 file changed, 1 deletion(-) diff --git a/tarpc/src/context.rs b/tarpc/src/context.rs index 5ca5b8256..e89a7f044 100644 --- a/tarpc/src/context.rs +++ b/tarpc/src/context.rs @@ -10,7 +10,6 @@ use crate::trace::{self, TraceId}; use opentelemetry::trace::TraceContextExt; use static_assertions::assert_impl_all; -use std::ops::{Deref, DerefMut}; use std::{ convert::TryFrom, time::{Duration, Instant}, From cf3fa5371e0f4ed0fe978e46cf4c7ed3b23fd9d1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C3=81kos=20Vandra-Meyer?= Date: Wed, 26 Nov 2025 16:16:14 +0100 Subject: [PATCH 16/23] run cargo fmt --- example-service/src/client.rs | 2 +- example-service/src/server.rs | 2 +- plugins/src/lib.rs | 27 +++---- tarpc/examples/compression.rs | 2 +- tarpc/examples/custom_transport.rs | 2 +- tarpc/examples/pubsub.rs | 27 +++++-- tarpc/examples/readme.rs | 5 +- tarpc/examples/tls_over_tcp.rs | 2 +- tarpc/examples/tracing.rs | 20 +++-- tarpc/src/client.rs | 73 ++++++++++--------- tarpc/src/client/in_flight_requests.rs | 15 +++- tarpc/src/client/stub.rs | 10 +-- tarpc/src/client/stub/load_balance.rs | 8 +- tarpc/src/client/stub/mock.rs | 12 +-- tarpc/src/context.rs | 7 +- tarpc/src/server.rs | 56 +++++++------- .../src/server/limits/requests_per_channel.rs | 14 ++-- tarpc/src/server/request_hook/after.rs | 12 +-- tarpc/src/server/request_hook/before.rs | 42 ++++------- .../server/request_hook/before_and_after.rs | 13 ++-- tarpc/src/server/testing.rs | 9 ++- tarpc/src/transport/channel.rs | 5 +- tarpc/tests/dataservice.rs | 4 +- tarpc/tests/service_functional.rs | 2 +- 24 files changed, 186 insertions(+), 185 deletions(-) diff --git a/example-service/src/client.rs b/example-service/src/client.rs index b8ff22c97..627e67504 100644 --- a/example-service/src/client.rs +++ b/example-service/src/client.rs @@ -8,7 +8,7 @@ use clap::Parser; use futures::{SinkExt, future}; use service::{WorldClient, init_tracing}; use std::{net::SocketAddr, time::Duration}; -use tarpc::context::{SharedContext}; +use tarpc::context::SharedContext; use tarpc::{client, tokio_serde::formats::Json}; use tokio::time::sleep; use tracing::Instrument; diff --git a/example-service/src/server.rs b/example-service/src/server.rs index 019a2d7b1..9c9160e17 100644 --- a/example-service/src/server.rs +++ b/example-service/src/server.rs @@ -16,7 +16,7 @@ use std::{ net::{IpAddr, Ipv6Addr, SocketAddr}, time::Duration, }; -use tarpc::context::{SharedContext}; +use tarpc::context::SharedContext; use tarpc::{ ClientMessage, context, server::{self, Channel, incoming::Incoming}, diff --git a/plugins/src/lib.rs b/plugins/src/lib.rs index cf107d0ad..1a5b7e6db 100644 --- a/plugins/src/lib.rs +++ b/plugins/src/lib.rs @@ -550,22 +550,19 @@ impl ServiceGenerator<'_> { .. } = self; - let rpc_fns = rpcs - .iter() - .zip(return_types.iter()) - .map( - |( - RpcMethod { - attrs, ident, args, .. - }, - output, - )| { - quote! { - #( #attrs )* - async fn #ident(self, context: &mut Self::Context, #( #args ),*) -> #output; - } + let rpc_fns = rpcs.iter().zip(return_types.iter()).map( + |( + RpcMethod { + attrs, ident, args, .. }, - ); + output, + )| { + quote! { + #( #attrs )* + async fn #ident(self, context: &mut Self::Context, #( #args ),*) -> #output; + } + }, + ); let stub_doc = format!("The stub trait for service [`{service_ident}`]."); quote! { diff --git a/tarpc/examples/compression.rs b/tarpc/examples/compression.rs index f201521ad..1a3a7d566 100644 --- a/tarpc/examples/compression.rs +++ b/tarpc/examples/compression.rs @@ -9,13 +9,13 @@ use futures::{Sink, SinkExt, Stream, StreamExt, TryStreamExt, prelude::*}; use serde::{Deserialize, Serialize}; use serde_bytes::ByteBuf; use std::{io, io::Read, io::Write}; +use tarpc::context::SharedContext; use tarpc::{ client, context, serde_transport::tcp, server::{BaseChannel, Channel}, tokio_serde::formats::Bincode, }; -use tarpc::context::SharedContext; /// Type of compression that should be enabled on the request. The transport is free to ignore this. #[derive(Debug, PartialEq, Eq, Clone, Copy, Deserialize, Serialize)] diff --git a/tarpc/examples/custom_transport.rs b/tarpc/examples/custom_transport.rs index c9eb871ea..859bed0ed 100644 --- a/tarpc/examples/custom_transport.rs +++ b/tarpc/examples/custom_transport.rs @@ -6,7 +6,7 @@ use console_subscriber::Server; use futures::prelude::*; -use tarpc::context::{SharedContext}; +use tarpc::context::SharedContext; use tarpc::serde_transport as transport; use tarpc::server::{BaseChannel, Channel}; use tarpc::tokio_serde::formats::Bincode; diff --git a/tarpc/examples/pubsub.rs b/tarpc/examples/pubsub.rs index 07a93becf..5e915e1b0 100644 --- a/tarpc/examples/pubsub.rs +++ b/tarpc/examples/pubsub.rs @@ -40,6 +40,9 @@ use futures::{ }; use opentelemetry::trace::TracerProvider as _; use publisher::Publisher as _; +use serde::de::DeserializeOwned; +use serde::{Deserialize, Serialize}; +use std::ops::Shl; use std::{ collections::HashMap, error::Error, @@ -47,9 +50,6 @@ use std::{ net::SocketAddr, sync::{Arc, Mutex, RwLock}, }; -use std::ops::Shl; -use serde::de::DeserializeOwned; -use serde::{Deserialize, Serialize}; use subscriber::Subscriber as _; use tarpc::context::{ExtractContext, SharedContext}; use tarpc::{ @@ -140,7 +140,8 @@ struct Subscription { #[derive(Debug)] struct Publisher { clients: Arc>>, - subscriptions: Arc>>>>, + subscriptions: + Arc>>>>, } impl Clone for Publisher { @@ -161,7 +162,17 @@ async fn spawn(fut: impl Future + Send + 'static) { tokio::spawn(fut); } -impl Publisher where ClientCtx: ExtractContext + From + Serialize + DeserializeOwned + Send + Sync + 'static { // TODO: Remove serde bounds here +impl Publisher +where + ClientCtx: ExtractContext + + From + + Serialize + + DeserializeOwned + + Send + + Sync + + 'static, +{ + // TODO: Remove serde bounds here async fn start(self) -> io::Result { let mut connecting_publishers = tcp::listen("localhost:0", Json::default).await?; @@ -178,7 +189,6 @@ impl Publisher where ClientCtx: ExtractContext Publisher where ClientCtx: ExtractContext publisher::Publisher for Publisher where ClientCtx: ExtractContext + From + Send + Sync + 'static { +impl publisher::Publisher for Publisher +where + ClientCtx: ExtractContext + From + Send + Sync + 'static, +{ type Context = ClientCtx; async fn publish(self, _: &mut Self::Context, topic: String, message: String) { info!("received message to publish."); diff --git a/tarpc/examples/readme.rs b/tarpc/examples/readme.rs index 359b4af8b..8c8d6619e 100644 --- a/tarpc/examples/readme.rs +++ b/tarpc/examples/readme.rs @@ -5,7 +5,7 @@ // https://opensource.org/licenses/MIT. use futures::prelude::*; -use tarpc::context::{SharedContext}; +use tarpc::context::SharedContext; use tarpc::{ ClientMessage, client, context, server::{self, Channel}, @@ -37,8 +37,7 @@ async fn spawn(fut: impl Future + Send + 'static) { #[tokio::main] async fn main() -> anyhow::Result<()> { - let (client_transport, server_transport) = - transport::channel::unbounded(); + let (client_transport, server_transport) = transport::channel::unbounded(); let server = server::BaseChannel::with_defaults(server_transport); tokio::spawn(server.execute(HelloServer.serve()).for_each(spawn)); diff --git a/tarpc/examples/tls_over_tcp.rs b/tarpc/examples/tls_over_tcp.rs index c203bf0b8..d67340449 100644 --- a/tarpc/examples/tls_over_tcp.rs +++ b/tarpc/examples/tls_over_tcp.rs @@ -10,7 +10,7 @@ use std::io::{self, BufReader, Cursor}; use std::net::{IpAddr, Ipv4Addr}; use std::sync::Arc; -use tarpc::context::{SharedContext}; +use tarpc::context::SharedContext; use tarpc::serde_transport as transport; use tarpc::server::{BaseChannel, Channel}; use tarpc::tokio_serde::formats::Bincode; diff --git a/tarpc/examples/tracing.rs b/tarpc/examples/tracing.rs index 525a16a47..77b19ba46 100644 --- a/tarpc/examples/tracing.rs +++ b/tarpc/examples/tracing.rs @@ -12,6 +12,7 @@ use crate::{ }; use futures::{future, prelude::*}; use opentelemetry::trace::TracerProvider as _; +use std::marker::PhantomData; use std::{ io, sync::{ @@ -19,7 +20,6 @@ use std::{ atomic::{AtomicBool, Ordering}, }, }; -use std::marker::PhantomData; use tarpc::context::{ExtractContext, SharedContext}; use tarpc::{ ClientMessage, RequestName, Response, ServerError, Transport, @@ -67,18 +67,22 @@ impl AddService for AddServer { #[derive(Clone)] struct DoubleServer { add_client: add::AddClient, - ghost: PhantomData + ghost: PhantomData, } impl DoubleService for DoubleServer where Stub: AddStub + Clone + Send + Sync + 'static, - ClientCtx: From + Send + Sync + 'static + ClientCtx: From + Send + Sync + 'static, { type Context = SharedContext; async fn double(self, _: &mut Self::Context, x: i32) -> Result { self.add_client - .add(&mut ClientCtx::from(context::SharedContext::current()), x, x) + .add( + &mut ClientCtx::from(context::SharedContext::current()), + x, + x, + ) .await .map_err(|e| e.to_string()) } @@ -141,7 +145,7 @@ fn make_stub( where Req: RequestName + Send + Sync + 'static, Resp: Send + Sync + 'static, - ClientCtx: ExtractContext + From + Send + Sync + 'static + ClientCtx: ExtractContext + From + Send + Sync + 'static, { let stub = load_balance::RoundRobin::new( backends @@ -196,7 +200,11 @@ async fn main() -> anyhow::Result<()> { .filter_map(|r| future::ready(r.ok())); let addr = double_listener.get_ref().local_addr(); let double_server = double_listener.map(BaseChannel::with_defaults).take(1); - let server = DoubleServer::<_, SharedContext> { add_client, ghost: PhantomData }.serve(); + let server = DoubleServer::<_, SharedContext> { + add_client, + ghost: PhantomData, + } + .serve(); tokio::spawn(spawn_incoming(double_server.execute(server))); let to_double_server = tarpc::serde_transport::tcp::connect(addr, Json::default).await?; diff --git a/tarpc/src/client.rs b/tarpc/src/client.rs index 40ba7e461..27856c729 100644 --- a/tarpc/src/client.rs +++ b/tarpc/src/client.rs @@ -19,6 +19,7 @@ use crate::{ use futures::{prelude::*, ready, stream::Fuse, task::*}; use in_flight_requests::InFlightRequests; use pin_project::pin_project; +use std::marker::PhantomData; use std::{ any::Any, convert::TryFrom, @@ -30,7 +31,6 @@ use std::{ }, time::SystemTime, }; -use std::marker::PhantomData; use tokio::sync::{mpsc, oneshot}; use tracing::Span; @@ -105,7 +105,7 @@ pub struct Channel { next_request_id: Arc, ///TODO: Document - ghost: PhantomData + ghost: PhantomData, } impl Clone for Channel { @@ -114,7 +114,7 @@ impl Clone for Channel { to_dispatch: self.to_dispatch.clone(), cancellation: self.cancellation.clone(), next_request_id: self.next_request_id.clone(), - ghost: PhantomData + ghost: PhantomData, } } } @@ -122,7 +122,7 @@ impl Clone for Channel { impl Channel where Req: RequestName, - ClientCtx: ExtractContext + ClientCtx: ExtractContext, { /// Sends a request to the dispatch task to forward to the server, returning a [`Future`] that /// resolves to the response. @@ -135,11 +135,7 @@ where otel.kind = "client", otel.name = %request.name()) )] - pub async fn call( - &self, - ctx: &mut ClientCtx, - request: Req, - ) -> Result { + pub async fn call(&self, ctx: &mut ClientCtx, request: Req) -> Result { let span = Span::current(); let mut shared_context = ctx.extract(); shared_context.trace_context = trace::Context::try_from(&span).unwrap_or_else(|_| { @@ -148,7 +144,10 @@ where ); shared_context.trace_context.new_child() }); - span.record("rpc.trace_id", tracing::field::display(shared_context.trace_id())); + span.record( + "rpc.trace_id", + tracing::field::display(shared_context.trace_id()), + ); let (response_completion, mut response) = oneshot::channel(); let request_id = u64::try_from(self.next_request_id.fetch_add(1, Ordering::Relaxed)).unwrap(); @@ -266,7 +265,7 @@ where to_dispatch, cancellation, next_request_id: Arc::new(AtomicUsize::new(0)), - ghost: PhantomData + ghost: PhantomData, }, dispatch: RequestDispatch { config, @@ -275,7 +274,7 @@ where in_flight_requests: InFlightRequests::default(), pending_requests, terminal_error: None, - ghost: PhantomData + ghost: PhantomData, }, } } @@ -309,11 +308,9 @@ pub struct RequestDispatch { impl RequestDispatch where C: Transport, Response>, - ClientCtx: ExtractContext + From + ClientCtx: ExtractContext + From, { - fn in_flight_requests<'a>( - self: &'a mut Pin<&mut Self>, - ) -> &'a mut InFlightRequests { + fn in_flight_requests<'a>(self: &'a mut Pin<&mut Self>) -> &'a mut InFlightRequests { self.as_mut().project().in_flight_requests } @@ -595,7 +592,10 @@ where fn complete(mut self: Pin<&mut Self>, response: Response) -> bool { if let Some(span) = self.in_flight_requests().complete_request( response.request_id, - response.message.map_err(RpcError::Server).map(|m| (response.context.extract(), m)), + response + .message + .map_err(RpcError::Server) + .map(|m| (response.context.extract(), m)), ) { let _entered = span.enter(); tracing::debug!("ReceiveResponse"); @@ -673,7 +673,7 @@ where impl Future for RequestDispatch where C: Transport, Response>, - ClientCtx: ExtractContext + From + ClientCtx: ExtractContext + From, { type Output = Result<(), ChannelError>; @@ -704,7 +704,8 @@ where /// the lifecycle of the request. #[derive(Debug)] struct DispatchRequest { - pub ctx: context::SharedContext, ///TODO: <-- this should be a &mut ClientContext + pub ctx: context::SharedContext, + ///TODO: <-- this should be a &mut ClientContext pub span: Span, pub request_id: u64, pub request: Req, @@ -716,7 +717,7 @@ mod tests { use super::{ Channel, DispatchRequest, RequestDispatch, ResponseGuard, RpcError, cancellations, }; - use crate::context::{SharedContext}; + use crate::context::SharedContext; use crate::{ ChannelError, ClientMessage, Response, client::{Config, in_flight_requests::InFlightRequests}, @@ -790,7 +791,8 @@ mod tests { async fn dispatch_response_doesnt_cancel_after_complete() { let (cancellation, mut canceled_requests) = cancellations(); let (tx, mut response) = oneshot::channel(); - tx.send(Ok((SharedContext::current(), "well done"))).unwrap(); + tx.send(Ok((SharedContext::current(), "well done"))) + .unwrap(); // resp's drop() is run, but should not send a cancel message. ResponseGuard { response: &mut response, @@ -902,7 +904,8 @@ mod tests { #[tokio::test] async fn test_permit_before_transport_error() { let _ = tracing_subscriber::fmt().with_test_writer().try_init(); - let (mut dispatch, mut channel, mut cx) = set_up_always_err::(TransportError::Flush); + let (mut dispatch, mut channel, mut cx) = + set_up_always_err::(TransportError::Flush); let (tx, mut rx) = oneshot::channel(); // reserve succeeds let permit = reserve_for_send(&mut channel, tx, &mut rx).await; @@ -1003,13 +1006,18 @@ mod tests { fn set_up_always_err( cause: TransportError, ) -> ( - Pin>>>, + Pin< + Box< + RequestDispatch>, + >, + >, Channel, Context<'static>, ) { let (to_dispatch, pending_requests) = mpsc::channel(1); let (cancellation, canceled_requests) = cancellations(); - let transport: AlwaysErrorTransport = AlwaysErrorTransport(cause, PhantomData); + let transport: AlwaysErrorTransport = + AlwaysErrorTransport(cause, PhantomData); let dispatch = Box::pin(RequestDispatch:: { transport: transport.fuse(), pending_requests, @@ -1017,19 +1025,19 @@ mod tests { in_flight_requests: InFlightRequests::default(), config: Config::default(), terminal_error: None, - ghost: PhantomData + ghost: PhantomData, }); let channel = Channel { to_dispatch, cancellation, next_request_id: Arc::new(AtomicUsize::new(0)), - ghost: PhantomData + ghost: PhantomData, }; let cx = Context::from_waker(noop_waker_ref()); (dispatch, channel, cx) } - struct AlwaysErrorTransport(TransportError, PhantomData<( I, ClientCtx)>); + struct AlwaysErrorTransport(TransportError, PhantomData<(I, ClientCtx)>); #[derive(Debug, Error, PartialEq, Eq, Clone, Copy)] enum TransportError { @@ -1096,10 +1104,7 @@ mod tests { String, String, ClientCtx, - UnboundedChannel< - Response, - ClientMessage, - >, + UnboundedChannel, ClientMessage>, >, >, >, @@ -1119,14 +1124,14 @@ mod tests { in_flight_requests: InFlightRequests::default(), config: Config::default(), terminal_error: None, - ghost: PhantomData + ghost: PhantomData, }; let channel = Channel { to_dispatch, cancellation, next_request_id: Arc::new(AtomicUsize::new(0)), - ghost: PhantomData + ghost: PhantomData, }; (Box::pin(dispatch), channel, server_channel) @@ -1165,7 +1170,7 @@ mod tests { let permit = channel.to_dispatch.reserve().await.unwrap(); |request| { let request_id = - u64::try_from(channel.next_request_id.fetch_add(1, Ordering::Relaxed)).unwrap(); + u64::try_from(channel.next_request_id.fetch_add(1, Ordering::Relaxed)).unwrap(); let request = DispatchRequest { ctx: SharedContext::current(), span: Span::current(), diff --git a/tarpc/src/client/in_flight_requests.rs b/tarpc/src/client/in_flight_requests.rs index 5b648098b..0ea5ba5ac 100644 --- a/tarpc/src/client/in_flight_requests.rs +++ b/tarpc/src/client/in_flight_requests.rs @@ -1,4 +1,9 @@ -use crate::{trace, util::{Compact, TimeUntil}}; +use crate::client::RpcError; +use crate::context::SharedContext; +use crate::{ + trace, + util::{Compact, TimeUntil}, +}; use fnv::FnvHashMap; use std::time::Instant; use std::{ @@ -8,8 +13,6 @@ use std::{ use tokio::sync::oneshot; use tokio_util::time::delay_queue::{self, DelayQueue}; use tracing::Span; -use crate::client::RpcError; -use crate::context::{SharedContext}; /// Requests already written to the wire that haven't yet received responses. #[derive(Debug)] @@ -78,7 +81,11 @@ impl InFlightRequests { } /// Removes a request without aborting. Returns true if the request was found. - pub fn complete_request(&mut self, request_id: u64, result: Result<(SharedContext, Res), RpcError>) -> Option { + pub fn complete_request( + &mut self, + request_id: u64, + result: Result<(SharedContext, Res), RpcError>, + ) -> Option { if let Some(request_data) = self.request_data.remove(&request_id) { self.request_data.compact(0.1); self.deadlines.remove(&request_data.deadline_key); diff --git a/tarpc/src/client/stub.rs b/tarpc/src/client/stub.rs index 992f6d611..51cececae 100644 --- a/tarpc/src/client/stub.rs +++ b/tarpc/src/client/stub.rs @@ -1,11 +1,11 @@ //! Provides a Stub trait, implemented by types that can call remote services. +use crate::context::{ExtractContext, SharedContext}; use crate::{ RequestName, client::{Channel, RpcError}, server::Serve, }; -use crate::context::{ExtractContext, SharedContext}; pub mod load_balance; pub mod retry; @@ -37,17 +37,13 @@ pub trait Stub { impl Stub for Channel where Req: RequestName, - ClientCtx: ExtractContext + ClientCtx: ExtractContext, { type Req = Req; type Resp = Resp; type ClientCtx = ClientCtx; - async fn call( - &self, - ctx: &mut Self::ClientCtx, - request: Req, - ) -> Result { + async fn call(&self, ctx: &mut Self::ClientCtx, request: Req) -> Result { Self::call(self, ctx, request).await } } diff --git a/tarpc/src/client/stub/load_balance.rs b/tarpc/src/client/stub/load_balance.rs index 60efafc91..9664a2aa7 100644 --- a/tarpc/src/client/stub/load_balance.rs +++ b/tarpc/src/client/stub/load_balance.rs @@ -5,9 +5,7 @@ pub use round_robin::RoundRobin; /// Provides a stub that load-balances with a simple round-robin strategy. mod round_robin { - use crate::{ - client::{RpcError, stub}, - }; + use crate::client::{RpcError, stub}; use cycle::AtomicCycle; impl stub::Stub for RoundRobin @@ -98,9 +96,7 @@ mod round_robin { /// Each request is hashed, then mapped to a stub based on the hash. Equivalent requests will use /// the same stub. mod consistent_hash { - use crate::{ - client::{RpcError, stub} - }; + use crate::client::{RpcError, stub}; use std::{ collections::hash_map::RandomState, hash::{BuildHasher, Hash}, diff --git a/tarpc/src/client/stub/mock.rs b/tarpc/src/client/stub/mock.rs index 577ef5362..171f8918e 100644 --- a/tarpc/src/client/stub/mock.rs +++ b/tarpc/src/client/stub/mock.rs @@ -2,13 +2,13 @@ use crate::{ RequestName, ServerError, client::{RpcError, stub::Stub}, }; -use std::{collections::HashMap, hash::Hash, io}; use std::marker::PhantomData; +use std::{collections::HashMap, hash::Hash, io}; /// A mock stub that returns user-specified responses. pub struct Mock { responses: HashMap, - ghost: PhantomData + ghost: PhantomData, } impl Mock @@ -19,7 +19,7 @@ where pub fn new(responses: [(Req, Resp); N]) -> Self { Self { responses: HashMap::from(responses), - ghost: PhantomData + ghost: PhantomData, } } } @@ -33,11 +33,7 @@ where type Resp = Resp; type ClientCtx = ServerCtx; - async fn call( - &self, - _: &mut Self::ClientCtx, - request: Self::Req, - ) -> Result { + async fn call(&self, _: &mut Self::ClientCtx, request: Self::Req) -> Result { self.responses .get(&request) .cloned() diff --git a/tarpc/src/context.rs b/tarpc/src/context.rs index e89a7f044..bc357e50f 100644 --- a/tarpc/src/context.rs +++ b/tarpc/src/context.rs @@ -34,7 +34,7 @@ pub struct SharedContext { /// When a service handles a request by making requests itself, those requests should /// include the same `trace_id` as that included on the original request. This way, /// users can trace related actions across a distributed system. - pub trace_context: trace::Context + pub trace_context: trace::Context, } ///TODO @@ -45,7 +45,10 @@ pub trait ExtractContext { fn update(&mut self, value: Ctx); } -impl ExtractContext for T where T: Clone { +impl ExtractContext for T +where + T: Clone, +{ fn extract(&self) -> T { self.clone() } diff --git a/tarpc/src/server.rs b/tarpc/src/server.rs index 7d345a203..1ed69fcd8 100644 --- a/tarpc/src/server.rs +++ b/tarpc/src/server.rs @@ -6,10 +6,11 @@ //! Provides a server that concurrently handles many connections sending multiplexed requests. +use crate::context::{ExtractContext, SharedContext}; use crate::{ ChannelError, ClientMessage, Request, RequestName, Response, ServerError, Transport, cancellations::{CanceledRequests, RequestCancellation, cancellations}, - context::{SpanExt}, + context::SpanExt, trace, util::TimeUntil, }; @@ -27,7 +28,6 @@ use std::{ convert::TryFrom, error::Error, fmt, marker::PhantomData, pin::Pin, sync::Arc, time::SystemTime, }; use tracing::{Span, info_span, instrument::Instrument}; -use crate::context::{ExtractContext, SharedContext}; mod in_flight_requests; pub mod request_hook; @@ -59,7 +59,10 @@ impl Default for Config { impl Config { /// Returns a channel backed by `transport` and configured with `self`. - pub fn channel(self, transport: T) -> BaseChannel + pub fn channel( + self, + transport: T, + ) -> BaseChannel where T: Transport, ClientMessage>, ServerCtx: ExtractContext, @@ -113,7 +116,10 @@ impl Copy for ServeFn where F: /// Result>`. pub fn serve(f: F) -> ServeFn where - for<'a> F: FnOnce(&'a mut ServerCtx, Req) -> Pin> + 'a + Send>>, + for<'a> F: FnOnce( + &'a mut ServerCtx, + Req, + ) -> Pin> + 'a + Send>>, { ServeFn { f, @@ -168,7 +174,7 @@ pub struct BaseChannel { impl BaseChannel where T: Transport, ClientMessage>, - ServerCtx: ExtractContext + ServerCtx: ExtractContext, { /// Creates a new channel backed by `transport` and configured with `config`. pub fn new(config: Config, transport: T) -> Self { @@ -443,7 +449,7 @@ where impl Stream for BaseChannel where T: Transport, ClientMessage>, - ServerCtx: ExtractContext + ServerCtx: ExtractContext, { type Item = Result, ChannelError>; @@ -548,11 +554,12 @@ where } } -impl Sink> for BaseChannel +impl Sink> + for BaseChannel where T: Transport, ClientMessage>, T::Error: Error, - ServerCtx: ExtractContext + ServerCtx: ExtractContext, { type Error = ChannelError; @@ -610,7 +617,6 @@ where T: Transport, ClientMessage>, ServerCtx: ExtractContext, { - type Req = Req; type Resp = Resp; type Transport = T; @@ -995,6 +1001,7 @@ mod tests { request_hook::{AfterRequest, BeforeRequest, RequestHook}, serve, }; + use crate::context::{ExtractContext, SharedContext}; use crate::{ ClientMessage, Request, Response, ServerError, context, trace, transport::channel::{self, UnboundedChannel}, @@ -1012,7 +1019,6 @@ mod tests { task::Poll, time::{Duration, Instant}, }; - use crate::context::{ExtractContext, SharedContext}; fn test_channel() -> ( Pin< @@ -1024,7 +1030,7 @@ mod tests { ClientMessage, Response, >, - SharedContext + SharedContext, >, >, >, @@ -1045,9 +1051,8 @@ mod tests { ClientMessage, Response, >, - SharedContext + SharedContext, >, - >, >, >, @@ -1073,7 +1078,7 @@ mod tests { ClientMessage, Response, >, - SharedContext + SharedContext, >, >, >, @@ -1114,12 +1119,11 @@ mod tests { #[tokio::test] async fn serve_before_mutates_context() -> anyhow::Result<()> { struct SetDeadline(Instant); - impl BeforeRequest for SetDeadline where ServerCtx: ExtractContext { - async fn before( - &mut self, - ctx: &mut ServerCtx, - _: &Req, - ) -> Result<(), ServerError> { + impl BeforeRequest for SetDeadline + where + ServerCtx: ExtractContext, + { + async fn before(&mut self, ctx: &mut ServerCtx, _: &Req) -> Result<(), ServerError> { let mut inner = ctx.extract(); inner.deadline = self.0; ctx.update(inner); @@ -1159,21 +1163,13 @@ mod tests { } } impl BeforeRequest for PrintLatency { - async fn before( - &mut self, - _: &mut ServerCtx, - _: &Req, - ) -> Result<(), ServerError> { + async fn before(&mut self, _: &mut ServerCtx, _: &Req) -> Result<(), ServerError> { self.start = Instant::now(); Ok(()) } } impl AfterRequest for PrintLatency { - async fn after( - &mut self, - _: &mut ServerCtx, - _: &mut Result, - ) { + async fn after(&mut self, _: &mut ServerCtx, _: &mut Result) { tracing::debug!("Elapsed: {:?}", self.start.elapsed()); } } diff --git a/tarpc/src/server/limits/requests_per_channel.rs b/tarpc/src/server/limits/requests_per_channel.rs index deb723bda..34b372510 100644 --- a/tarpc/src/server/limits/requests_per_channel.rs +++ b/tarpc/src/server/limits/requests_per_channel.rs @@ -180,6 +180,7 @@ where mod tests { use super::*; + use crate::context::SharedContext; use crate::server::{ TrackedRequest, testing::{self, FakeChannel, PollExt}, @@ -190,7 +191,6 @@ mod tests { time::{Duration, Instant}, }; use tracing::Span; - use crate::context::{SharedContext}; #[tokio::test] async fn throttler_in_flight_requests() { @@ -270,9 +270,10 @@ mod tests { ghost: PhantomData In>, } impl PendingSink<(), ()> { - pub fn default() - -> PendingSink>, Response> - { + pub fn default() -> PendingSink< + io::Result>, + Response, + > { PendingSink { ghost: PhantomData } } } @@ -298,7 +299,10 @@ mod tests { } } impl Channel - for PendingSink>, Response> + for PendingSink< + io::Result>, + Response, + > { type Req = Req; type Resp = Resp; diff --git a/tarpc/src/server/request_hook/after.rs b/tarpc/src/server/request_hook/after.rs index ce6319e25..1fa3cee51 100644 --- a/tarpc/src/server/request_hook/after.rs +++ b/tarpc/src/server/request_hook/after.rs @@ -15,11 +15,7 @@ pub trait AfterRequest { /// The function that is called after request execution. /// /// The hook can modify the request context and the response. - async fn after( - &mut self, - ctx: &mut ServerCtx, - resp: &mut Result, - ); + async fn after(&mut self, ctx: &mut ServerCtx, resp: &mut Result); } impl AfterRequest for F @@ -27,11 +23,7 @@ where F: FnMut(&mut ServerCtx, &mut Result) -> Fut, Fut: Future, { - async fn after( - &mut self, - ctx: &mut ServerCtx, - resp: &mut Result, - ) { + async fn after(&mut self, ctx: &mut ServerCtx, resp: &mut Result) { self(ctx, resp).await } } diff --git a/tarpc/src/server/request_hook/before.rs b/tarpc/src/server/request_hook/before.rs index 3e2e091c8..1552a0b49 100644 --- a/tarpc/src/server/request_hook/before.rs +++ b/tarpc/src/server/request_hook/before.rs @@ -6,9 +6,9 @@ //! Provides a hook that runs before request execution. -use std::marker::PhantomData; use crate::{ServerError, server::Serve}; use futures::prelude::*; +use std::marker::PhantomData; /// A hook that runs before request execution. #[allow(async_fn_in_trait)] @@ -20,11 +20,7 @@ pub trait BeforeRequest { /// /// This function can also modify the request context. This could be used, for example, to /// enforce a maximum deadline on all requests. - async fn before( - &mut self, - ctx: &mut ServerCtx, - req: &Req, - ) -> Result<(), ServerError>; + async fn before(&mut self, ctx: &mut ServerCtx, req: &Req) -> Result<(), ServerError>; } /// A list of hooks that run in order before request execution. @@ -64,11 +60,7 @@ where F: FnMut(&mut ServerCtx, &Req) -> Fut, Fut: Future>, { - async fn before( - &mut self, - ctx: &mut ServerCtx, - req: &Req, - ) -> Result<(), ServerError> { + async fn before(&mut self, ctx: &mut ServerCtx, req: &Req) -> Result<(), ServerError> { self(ctx, req).await } } @@ -77,7 +69,7 @@ where pub struct HookThenServe { serve: Serv, hook: Hook, - ghost: PhantomData + ghost: PhantomData, } impl Clone for HookThenServe { @@ -88,7 +80,11 @@ impl Clone for HookThenServe HookThenServe { pub(crate) fn new(serve: Serv, hook: Hook) -> Self { - Self { serve, hook, ghost: PhantomData } + Self { + serve, + hook, + ghost: PhantomData, + } } } @@ -101,11 +97,7 @@ where type Req = Serv::Req; type Resp = Serv::Resp; - async fn serve( - self, - ctx: &mut ServerCtx, - req: Self::Req, - ) -> Result { + async fn serve(self, ctx: &mut ServerCtx, req: Self::Req) -> Result { let HookThenServe { serve, mut hook, .. } = self; @@ -154,14 +146,10 @@ pub struct BeforeRequestCons(First, Rest); #[derive(Clone, Copy)] pub struct BeforeRequestNil; -impl, Rest: BeforeRequest, ServerCtx> BeforeRequest - for BeforeRequestCons +impl, Rest: BeforeRequest, ServerCtx> + BeforeRequest for BeforeRequestCons { - async fn before( - &mut self, - ctx: &mut ServerCtx, - req: &Req, - ) -> Result<(), ServerError> { + async fn before(&mut self, ctx: &mut ServerCtx, req: &Req) -> Result<(), ServerError> { let BeforeRequestCons(first, rest) = self; first.before(ctx, req).await?; rest.before(ctx, req).await?; @@ -175,8 +163,8 @@ impl BeforeRequest for BeforeRequestNil { } } -impl, Rest: BeforeRequestList, ServerCtx> BeforeRequestList - for BeforeRequestCons +impl, Rest: BeforeRequestList, ServerCtx> + BeforeRequestList for BeforeRequestCons { type Then = BeforeRequestCons> diff --git a/tarpc/src/server/request_hook/before_and_after.rs b/tarpc/src/server/request_hook/before_and_after.rs index 080c53b21..f3653a513 100644 --- a/tarpc/src/server/request_hook/before_and_after.rs +++ b/tarpc/src/server/request_hook/before_and_after.rs @@ -27,7 +27,9 @@ impl HookThenServeThenHook Clone for HookThenServeThenHook { +impl Clone + for HookThenServeThenHook +{ fn clone(&self) -> Self { Self { serve: self.serve.clone(), @@ -37,7 +39,8 @@ impl Clone for HookThenServeThen } } -impl Serve for HookThenServeThenHook +impl Serve + for HookThenServeThenHook where Req: RequestName, Serv: Serve, @@ -47,11 +50,7 @@ where type Resp = Resp; type ServerCtx = ServerCtx; - async fn serve( - self, - ctx: &mut ServerCtx, - req: Req, - ) -> Result { + async fn serve(self, ctx: &mut ServerCtx, req: Req) -> Result { let HookThenServeThenHook { serve, mut hook, .. } = self; diff --git a/tarpc/src/server/testing.rs b/tarpc/src/server/testing.rs index 9a941f711..ce409dd85 100644 --- a/tarpc/src/server/testing.rs +++ b/tarpc/src/server/testing.rs @@ -4,7 +4,7 @@ // license that can be found in the LICENSE file or at // https://opensource.org/licenses/MIT. -use crate::context::{SharedContext}; +use crate::context::SharedContext; use crate::{ Request, Response, cancellations::{CanceledRequests, RequestCancellation, cancellations}, @@ -94,7 +94,9 @@ where } } -impl FakeChannel>, Response> { +impl + FakeChannel>, Response> +{ pub fn push_req(&mut self, id: u64, message: Req) { let (_, abort_registration) = futures::future::AbortHandle::new_pair(); let (request_cancellation, _) = cancellations(); @@ -120,7 +122,8 @@ impl FakeChannel>, Resp impl FakeChannel<(), ()> { pub fn default() - -> FakeChannel>, Response> { + -> FakeChannel>, Response> + { let (request_cancellation, canceled_requests) = cancellations(); let mut x = anymap3::AnyMap::new(); diff --git a/tarpc/src/transport/channel.rs b/tarpc/src/transport/channel.rs index de9a8afdc..a698136f0 100644 --- a/tarpc/src/transport/channel.rs +++ b/tarpc/src/transport/channel.rs @@ -161,6 +161,7 @@ impl Sink for Channel { #[cfg(all(test, feature = "tokio1"))] mod tests { + use crate::context::SharedContext; use crate::{ ServerError, client::{self, RpcError}, @@ -175,7 +176,6 @@ mod tests { use futures::{prelude::*, stream}; use std::io; use tracing::trace; - use crate::context::SharedContext; #[test] fn ensure_is_transport() { @@ -188,8 +188,7 @@ mod tests { async fn integration() -> anyhow::Result<()> { let _ = tracing_subscriber::fmt::try_init(); - let (client_channel, server_channel) = - transport::channel::unbounded(); + let (client_channel, server_channel) = transport::channel::unbounded(); tokio::spawn( stream::once(future::ready(server_channel)) diff --git a/tarpc/tests/dataservice.rs b/tarpc/tests/dataservice.rs index 6bcd255c4..a39922666 100644 --- a/tarpc/tests/dataservice.rs +++ b/tarpc/tests/dataservice.rs @@ -1,6 +1,6 @@ use futures::prelude::*; -use tarpc::context::{SharedContext}; -use tarpc::{serde_transport}; +use tarpc::context::SharedContext; +use tarpc::serde_transport; use tarpc::{ client, context, server::{BaseChannel, incoming::Incoming}, diff --git a/tarpc/tests/service_functional.rs b/tarpc/tests/service_functional.rs index 7d1f96e18..fd54b3db6 100644 --- a/tarpc/tests/service_functional.rs +++ b/tarpc/tests/service_functional.rs @@ -4,6 +4,7 @@ use futures::{ prelude::*, }; use std::time::{Duration, Instant}; +use tarpc::context::SharedContext; use tarpc::{ ClientMessage, client::{self}, @@ -13,7 +14,6 @@ use tarpc::{ transport::channel, }; use tokio::join; -use tarpc::context::SharedContext; #[tarpc_plugins::service] trait Service { From 34a87d65f3f9abaaf27a20bc56838dc06e602a88 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C3=81kos=20Vandra-Meyer?= Date: Wed, 26 Nov 2025 16:35:27 +0100 Subject: [PATCH 17/23] cleanup... --- example-service/src/client.rs | 12 +- example-service/src/server.rs | 4 +- plugins/src/lib.rs | 7 +- plugins/tests/service.rs | 10 +- tarpc/examples/compression.rs | 6 +- tarpc/examples/custom_transport.rs | 8 +- tarpc/examples/pubsub.rs | 24 ++-- tarpc/examples/readme.rs | 6 +- tarpc/examples/tls_over_tcp.rs | 8 +- tarpc/examples/tracing.rs | 16 +-- tarpc/src/client.rs | 67 +++++------ tarpc/src/client/in_flight_requests.rs | 16 +-- tarpc/src/client/stub.rs | 14 +-- tarpc/src/client/stub/load_balance.rs | 6 +- tarpc/src/context.rs | 10 +- tarpc/src/lib.rs | 12 +- tarpc/src/server.rs | 107 +++++++++--------- tarpc/src/server/incoming.rs | 3 +- .../src/server/limits/requests_per_channel.rs | 14 +-- tarpc/src/server/request_hook.rs | 10 +- tarpc/src/server/request_hook/before.rs | 4 +- tarpc/src/server/testing.rs | 17 ++- tarpc/src/transport/channel.rs | 7 +- .../compile_fail/must_use_request_dispatch.rs | 4 +- .../must_use_request_dispatch.stderr | 6 +- tarpc/tests/dataservice.rs | 6 +- tarpc/tests/service_functional.rs | 36 +++--- 27 files changed, 207 insertions(+), 233 deletions(-) diff --git a/example-service/src/client.rs b/example-service/src/client.rs index 627e67504..64b2e0a89 100644 --- a/example-service/src/client.rs +++ b/example-service/src/client.rs @@ -5,11 +5,9 @@ // https://opensource.org/licenses/MIT. use clap::Parser; -use futures::{SinkExt, future}; use service::{WorldClient, init_tracing}; use std::{net::SocketAddr, time::Duration}; -use tarpc::context::SharedContext; -use tarpc::{client, tokio_serde::formats::Json}; +use tarpc::{client, context, tokio_serde::formats::Json}; use tokio::time::sleep; use tracing::Instrument; @@ -31,15 +29,13 @@ async fn main() -> anyhow::Result<()> { let mut transport = tarpc::serde_transport::tcp::connect(flags.server_addr, Json::default); transport.config_mut().max_frame_length(usize::MAX); - let transport = transport.await?; - // WorldClient is generated by the service attribute. It has a constructor `new` that takes a // config and any Transport as input. - let client = WorldClient::new(client::Config::default(), transport).spawn(); + let client = WorldClient::new(client::Config::default(), transport.await?).spawn(); let hello = async move { - let mut context = SharedContext::current(); - let mut context2 = SharedContext::current(); + let mut context = context::Context::current(); + let mut context2 = context::Context::current(); // Send the request twice, just to be safe! ;) tokio::select! { diff --git a/example-service/src/server.rs b/example-service/src/server.rs index 9c9160e17..302336c57 100644 --- a/example-service/src/server.rs +++ b/example-service/src/server.rs @@ -16,7 +16,7 @@ use std::{ net::{IpAddr, Ipv6Addr, SocketAddr}, time::Duration, }; -use tarpc::context::SharedContext; +use tarpc::context::Context; use tarpc::{ ClientMessage, context, server::{self, Channel, incoming::Incoming}, @@ -37,7 +37,7 @@ struct Flags { struct HelloServer(SocketAddr); impl World for HelloServer { - type Context = SharedContext; + type Context = context::Context; async fn hello(self, _: &mut Self::Context, name: String) -> String { let sleep_time = Duration::from_millis(Uniform::new_inclusive(1, 10).sample(&mut thread_rng())); diff --git a/plugins/src/lib.rs b/plugins/src/lib.rs index 1a5b7e6db..b8f1ff826 100644 --- a/plugins/src/lib.rs +++ b/plugins/src/lib.rs @@ -377,8 +377,7 @@ fn collect_cfg_attrs(rpcs: &[RpcMethod]) -> Vec> { /// ```no_run /// use tarpc::{client, transport, service, server::{self, Channel}}; /// use futures_util::{TryStreamExt, sink::SinkExt};/// -/// -/// use tarpc::context::SharedContext; +/// use tarpc::context; /// /// #[service] /// pub trait Calculator { @@ -404,7 +403,7 @@ fn collect_cfg_attrs(rpcs: &[RpcMethod]) -> Vec> { /// #[derive(Clone)] /// struct CalculatorServer; /// impl Calculator for CalculatorServer { -/// type Context = SharedContext; +/// type Context = context::Context; /// async fn add(self, context: &mut Self::Context, a: i32, b: i32) -> i32 { /// a + b /// } @@ -568,7 +567,7 @@ impl ServiceGenerator<'_> { quote! { #( #attrs )* #vis trait #service_ident: ::core::marker::Sized { - type Context: ::tarpc::context::ExtractContext<::tarpc::context::SharedContext>; + type Context: ::tarpc::context::ExtractContext<::tarpc::context::Context>; #( #rpc_fns )* diff --git a/plugins/tests/service.rs b/plugins/tests/service.rs index d8213f4d4..3bd8b4c4e 100644 --- a/plugins/tests/service.rs +++ b/plugins/tests/service.rs @@ -1,7 +1,7 @@ use serde::{Deserialize, Serialize}; use std::hash::Hash; use tarpc::context; -use tarpc::context::SharedContext; +use tarpc::context::Context; #[test] fn att_service_trait() { @@ -13,10 +13,10 @@ fn att_service_trait() { } impl Foo for () { - type Context = SharedContext; + type Context = context::Context; async fn two_part( self, - _: &mut context::SharedContext, + _: &mut context::Context, s: String, i: i32, ) -> (String, i32) { @@ -44,7 +44,7 @@ fn raw_idents() { } impl r#trait for () { - type Context = SharedContext; + type Context = context::Context; async fn r#await( self, _: &mut Self::Context, @@ -72,7 +72,7 @@ fn service_with_cfg_rpc() { } impl Foo for () { - type Context = SharedContext; + type Context = context::Context; async fn foo(self, _: &mut Self::Context) {} } } diff --git a/tarpc/examples/compression.rs b/tarpc/examples/compression.rs index 1a3a7d566..3c5fa6fcf 100644 --- a/tarpc/examples/compression.rs +++ b/tarpc/examples/compression.rs @@ -9,7 +9,7 @@ use futures::{Sink, SinkExt, Stream, StreamExt, TryStreamExt, prelude::*}; use serde::{Deserialize, Serialize}; use serde_bytes::ByteBuf; use std::{io, io::Read, io::Write}; -use tarpc::context::SharedContext; +use tarpc::context::Context; use tarpc::{ client, context, serde_transport::tcp, @@ -109,7 +109,7 @@ pub trait World { struct HelloServer; impl World for HelloServer { - type Context = SharedContext; + type Context = context::Context; async fn hello(self, _: &mut Self::Context, name: String) -> String { format!("Hey, {name}!") } @@ -140,7 +140,7 @@ async fn main() -> anyhow::Result<()> { println!( "{}", client - .hello(&mut context::SharedContext::current(), "friend".into()) + .hello(&mut context::Context::current(), "friend".into()) .await? ); Ok(()) diff --git a/tarpc/examples/custom_transport.rs b/tarpc/examples/custom_transport.rs index 859bed0ed..80f8a03be 100644 --- a/tarpc/examples/custom_transport.rs +++ b/tarpc/examples/custom_transport.rs @@ -6,8 +6,8 @@ use console_subscriber::Server; use futures::prelude::*; -use tarpc::context::SharedContext; -use tarpc::serde_transport as transport; +use tarpc::context::Context; +use tarpc::{context, serde_transport as transport}; use tarpc::server::{BaseChannel, Channel}; use tarpc::tokio_serde::formats::Bincode; use tarpc::tokio_util::codec::length_delimited::LengthDelimitedCodec; @@ -22,7 +22,7 @@ pub trait PingService { struct Service; impl PingService for Service { - type Context = SharedContext; + type Context = context::Context; async fn ping(self, _: &mut Self::Context) {} } #[tokio::main] @@ -53,7 +53,7 @@ async fn main() -> anyhow::Result<()> { let transport = transport::new(codec_builder.new_framed(conn), Bincode::default()); PingServiceClient::new(Default::default(), transport) .spawn() - .ping(&mut SharedContext::current()) + .ping(&mut context::Context::current()) .await?; Ok(()) diff --git a/tarpc/examples/pubsub.rs b/tarpc/examples/pubsub.rs index 5e915e1b0..3cf95b27e 100644 --- a/tarpc/examples/pubsub.rs +++ b/tarpc/examples/pubsub.rs @@ -51,7 +51,7 @@ use std::{ sync::{Arc, Mutex, RwLock}, }; use subscriber::Subscriber as _; -use tarpc::context::{ExtractContext, SharedContext}; +use tarpc::context::{ExtractContext}; use tarpc::{ ClientMessage, client, context, serde_transport::tcp, @@ -84,7 +84,7 @@ struct Subscriber { } impl subscriber::Subscriber for Subscriber { - type Context = SharedContext; + type Context = context::Context; async fn topics(self, _: &mut Self::Context) -> Vec { self.topics.clone() } @@ -164,8 +164,8 @@ async fn spawn(fut: impl Future + Send + 'static) { impl Publisher where - ClientCtx: ExtractContext - + From + ClientCtx: ExtractContext + + From + Serialize + DeserializeOwned + Send @@ -235,7 +235,7 @@ where ) { // Populate the topics if let Ok(topics) = subscriber - .topics(&mut ClientCtx::from(context::SharedContext::current())) + .topics(&mut ClientCtx::from(context::Context::current())) .await { self.clients.lock().unwrap().insert( @@ -291,7 +291,7 @@ where impl publisher::Publisher for Publisher where - ClientCtx: ExtractContext + From + Send + Sync + 'static, + ClientCtx: ExtractContext + From + Send + Sync + 'static, { type Context = ClientCtx; async fn publish(self, _: &mut Self::Context, topic: String, message: String) { @@ -306,7 +306,7 @@ where publications.push(async { client .receive( - &mut ClientCtx::from(context::SharedContext::current()), + &mut ClientCtx::from(context::Context::current()), topic.clone(), message.clone(), ) @@ -356,7 +356,7 @@ pub fn init_tracing( async fn main() -> anyhow::Result<()> { let tracer_provider = init_tracing("Pub/Sub")?; - let addrs = Publisher:: { + let addrs = Publisher:: { clients: Arc::new(Mutex::new(HashMap::new())), subscriptions: Arc::new(RwLock::new(HashMap::new())), } @@ -383,7 +383,7 @@ async fn main() -> anyhow::Result<()> { publisher .publish( - &mut SharedContext::current(), + &mut context::Context::current(), "calculus".into(), "sqrt(2)".into(), ) @@ -391,7 +391,7 @@ async fn main() -> anyhow::Result<()> { publisher .publish( - &mut SharedContext::current(), + &mut context::Context::current(), "cool shorts".into(), "hello to all".into(), ) @@ -399,7 +399,7 @@ async fn main() -> anyhow::Result<()> { publisher .publish( - &mut SharedContext::current(), + &mut context::Context::current(), "history".into(), "napoleon".to_string(), ) @@ -409,7 +409,7 @@ async fn main() -> anyhow::Result<()> { publisher .publish( - &mut SharedContext::current(), + &mut context::Context::current(), "cool shorts".into(), "hello to who?".into(), ) diff --git a/tarpc/examples/readme.rs b/tarpc/examples/readme.rs index 8c8d6619e..ff0307d39 100644 --- a/tarpc/examples/readme.rs +++ b/tarpc/examples/readme.rs @@ -5,7 +5,7 @@ // https://opensource.org/licenses/MIT. use futures::prelude::*; -use tarpc::context::SharedContext; +use tarpc::context::Context; use tarpc::{ ClientMessage, client, context, server::{self, Channel}, @@ -25,7 +25,7 @@ pub trait World { struct HelloServer; impl World for HelloServer { - type Context = SharedContext; + type Context = context::Context; async fn hello(self, _: &mut Self::Context, name: String) -> String { format!("Hello, {name}!") } @@ -50,7 +50,7 @@ async fn main() -> anyhow::Result<()> { // args as defined, with the addition of a Context, which is always the first arg. The Context // specifies a deadline and trace information which can be helpful in debugging requests. let hello = client - .hello(&mut context::SharedContext::current(), "Stim".to_string()) + .hello(&mut context::Context::current(), "Stim".to_string()) .await?; println!("{hello}"); diff --git a/tarpc/examples/tls_over_tcp.rs b/tarpc/examples/tls_over_tcp.rs index d67340449..07eb6a8ef 100644 --- a/tarpc/examples/tls_over_tcp.rs +++ b/tarpc/examples/tls_over_tcp.rs @@ -10,8 +10,8 @@ use std::io::{self, BufReader, Cursor}; use std::net::{IpAddr, Ipv4Addr}; use std::sync::Arc; -use tarpc::context::SharedContext; -use tarpc::serde_transport as transport; +use tarpc::context::Context; +use tarpc::{context, serde_transport as transport}; use tarpc::server::{BaseChannel, Channel}; use tarpc::tokio_serde::formats::Bincode; use tarpc::tokio_util::codec::length_delimited::LengthDelimitedCodec; @@ -32,7 +32,7 @@ pub trait PingService { struct Service; impl PingService for Service { - type Context = SharedContext; + type Context = context::Context; async fn ping(self, _: &mut Self::Context) -> String { "🔒".to_owned() } @@ -146,7 +146,7 @@ async fn main() -> anyhow::Result<()> { let transport = transport::new(codec_builder.new_framed(stream), Bincode::default()); let answer = PingServiceClient::new(Default::default(), transport) .spawn() - .ping(&mut SharedContext::current()) + .ping(&mut context::Context::current()) .await?; println!("ping answer: {answer}"); diff --git a/tarpc/examples/tracing.rs b/tarpc/examples/tracing.rs index 77b19ba46..abf1cbbc1 100644 --- a/tarpc/examples/tracing.rs +++ b/tarpc/examples/tracing.rs @@ -20,7 +20,7 @@ use std::{ atomic::{AtomicBool, Ordering}, }, }; -use tarpc::context::{ExtractContext, SharedContext}; +use tarpc::context::{ExtractContext}; use tarpc::{ ClientMessage, RequestName, Response, ServerError, Transport, client::{ @@ -58,7 +58,7 @@ pub mod double { struct AddServer; impl AddService for AddServer { - type Context = SharedContext; + type Context = context::Context; async fn add(self, _: &mut Self::Context, x: i32, y: i32) -> i32 { x + y } @@ -73,13 +73,13 @@ struct DoubleServer { impl DoubleService for DoubleServer where Stub: AddStub + Clone + Send + Sync + 'static, - ClientCtx: From + Send + Sync + 'static, + ClientCtx: From + Send + Sync + 'static, { - type Context = SharedContext; + type Context = context::Context; async fn double(self, _: &mut Self::Context, x: i32) -> Result { self.add_client .add( - &mut ClientCtx::from(context::SharedContext::current()), + &mut ClientCtx::from(context::Context::current()), x, x, ) @@ -145,7 +145,7 @@ fn make_stub( where Req: RequestName + Send + Sync + 'static, Resp: Send + Sync + 'static, - ClientCtx: ExtractContext + From + Send + Sync + 'static, + ClientCtx: ExtractContext + From + Send + Sync + 'static, { let stub = load_balance::RoundRobin::new( backends @@ -200,7 +200,7 @@ async fn main() -> anyhow::Result<()> { .filter_map(|r| future::ready(r.ok())); let addr = double_listener.get_ref().local_addr(); let double_server = double_listener.map(BaseChannel::with_defaults).take(1); - let server = DoubleServer::<_, SharedContext> { + let server = DoubleServer::<_, context::Context> { add_client, ghost: PhantomData, } @@ -215,7 +215,7 @@ async fn main() -> anyhow::Result<()> { tracing::info!( "{:?}", double_client - .double(&mut context::SharedContext::current(), 1) + .double(&mut context::Context::current(), 1) .await? ); } diff --git a/tarpc/src/client.rs b/tarpc/src/client.rs index 27856c729..fab3b2548 100644 --- a/tarpc/src/client.rs +++ b/tarpc/src/client.rs @@ -9,7 +9,6 @@ mod in_flight_requests; pub mod stub; -use crate::context::{ExtractContext, SharedContext}; use crate::{ ChannelError, ClientMessage, Request, RequestName, Response, ServerError, Transport, cancellations::{CanceledRequests, RequestCancellation, cancellations}, @@ -33,6 +32,7 @@ use std::{ }; use tokio::sync::{mpsc, oneshot}; use tracing::Span; +use crate::context::ExtractContext; /// Settings that control the behavior of the client. #[derive(Clone, Debug)] @@ -122,7 +122,7 @@ impl Clone for Channel { impl Channel where Req: RequestName, - ClientCtx: ExtractContext, + ClientCtx: ExtractContext, { /// Sends a request to the dispatch task to forward to the server, returning a [`Future`] that /// resolves to the response. @@ -184,7 +184,7 @@ where /// A server response that is completed by request dispatch when the corresponding response /// arrives off the wire. struct ResponseGuard<'a, Resp> { - response: &'a mut oneshot::Receiver>, + response: &'a mut oneshot::Receiver>, cancellation: &'a RequestCancellation, request_id: u64, cancel: bool, @@ -212,7 +212,7 @@ pub enum RpcError { } impl ResponseGuard<'_, Resp> { - async fn response(mut self) -> Result<(SharedContext, Resp), RpcError> { + async fn response(mut self) -> Result<(context::Context, Resp), RpcError> { let response = (&mut self.response).await; // Cancel drop logic once a response has been received. self.cancel = false; @@ -308,7 +308,7 @@ pub struct RequestDispatch { impl RequestDispatch where C: Transport, Response>, - ClientCtx: ExtractContext + From, + ClientCtx: ExtractContext + From, { fn in_flight_requests<'a>(self: &'a mut Pin<&mut Self>) -> &'a mut InFlightRequests { self.as_mut().project().in_flight_requests @@ -673,7 +673,7 @@ where impl Future for RequestDispatch where C: Transport, Response>, - ClientCtx: ExtractContext + From, + ClientCtx: ExtractContext + From, { type Output = Result<(), ChannelError>; @@ -704,12 +704,12 @@ where /// the lifecycle of the request. #[derive(Debug)] struct DispatchRequest { - pub ctx: context::SharedContext, + pub ctx: context::Context, ///TODO: <-- this should be a &mut ClientContext pub span: Span, pub request_id: u64, pub request: Req, - pub response_completion: oneshot::Sender>, + pub response_completion: oneshot::Sender>, } #[cfg(test)] @@ -717,12 +717,7 @@ mod tests { use super::{ Channel, DispatchRequest, RequestDispatch, ResponseGuard, RpcError, cancellations, }; - use crate::context::SharedContext; - use crate::{ - ChannelError, ClientMessage, Response, - client::{Config, in_flight_requests::InFlightRequests}, - transport::{self, channel::UnboundedChannel}, - }; + use crate::{ChannelError, ClientMessage, Response, client::{Config, in_flight_requests::InFlightRequests}, transport::{self, channel::UnboundedChannel}, context}; use assert_matches::assert_matches; use futures::{prelude::*, task::*}; use std::{ @@ -748,7 +743,7 @@ mod tests { let cx = &mut Context::from_waker(noop_waker_ref()); let (tx, mut rx) = oneshot::channel(); - let context = SharedContext::current(); + let context = context::Context::current(); dispatch .in_flight_requests @@ -763,7 +758,7 @@ mod tests { server_channel .send(Response { request_id: 0, - context: SharedContext::current(), + context: context::Context::current(), message: Ok("Resp".into()), }) .await @@ -791,7 +786,7 @@ mod tests { async fn dispatch_response_doesnt_cancel_after_complete() { let (cancellation, mut canceled_requests) = cancellations(); let (tx, mut response) = oneshot::channel(); - tx.send(Ok((SharedContext::current(), "well done"))) + tx.send(Ok((context::Context::current(), "well done"))) .unwrap(); // resp's drop() is run, but should not send a cancel message. ResponseGuard { @@ -810,7 +805,7 @@ mod tests { #[tokio::test] async fn stage_request() { - let (mut dispatch, mut channel, _server_channel) = set_up::(); + let (mut dispatch, mut channel, _server_channel) = set_up::(); let cx = &mut Context::from_waker(noop_waker_ref()); let (tx, mut rx) = oneshot::channel(); @@ -840,7 +835,7 @@ mod tests { &mut server_channel, Response { request_id: 0, - context: SharedContext::current(), + context: context::Context::current(), message: Ok("hello".into()), }, ) @@ -851,7 +846,7 @@ mod tests { #[allow(unstable_name_collisions)] #[tokio::test] async fn stage_request_response_future_dropped_is_canceled_before_sending() { - let (mut dispatch, mut channel, _server_channel) = set_up::(); + let (mut dispatch, mut channel, _server_channel) = set_up::(); let cx = &mut Context::from_waker(noop_waker_ref()); let (tx, mut rx) = oneshot::channel(); @@ -867,7 +862,7 @@ mod tests { #[allow(unstable_name_collisions)] #[tokio::test] async fn stage_request_response_future_dropped_is_canceled_after_sending() { - let (mut dispatch, mut channel, _server_channel) = set_up::(); + let (mut dispatch, mut channel, _server_channel) = set_up::(); let cx = &mut Context::from_waker(noop_waker_ref()); let (tx, mut rx) = oneshot::channel(); @@ -888,7 +883,7 @@ mod tests { #[tokio::test] async fn stage_request_response_closed_skipped() { - let (mut dispatch, mut channel, _server_channel) = set_up::(); + let (mut dispatch, mut channel, _server_channel) = set_up::(); let cx = &mut Context::from_waker(noop_waker_ref()); let (tx, mut rx) = oneshot::channel(); @@ -905,7 +900,7 @@ mod tests { async fn test_permit_before_transport_error() { let _ = tracing_subscriber::fmt().with_test_writer().try_init(); let (mut dispatch, mut channel, mut cx) = - set_up_always_err::(TransportError::Flush); + set_up_always_err::(TransportError::Flush); let (tx, mut rx) = oneshot::channel(); // reserve succeeds let permit = reserve_for_send(&mut channel, tx, &mut rx).await; @@ -922,11 +917,11 @@ mod tests { #[tokio::test] async fn test_shutdown() { let _ = tracing_subscriber::fmt().with_test_writer().try_init(); - let (dispatch, channel, _server_channel) = set_up::(); + let (dispatch, channel, _server_channel) = set_up::(); drop(dispatch); // error on send let resp = channel - .call(&mut SharedContext::current(), "hi".to_string()) + .call(&mut context::Context::current(), "hi".to_string()) .await; assert_matches!(resp, Err(RpcError::Shutdown)); } @@ -934,7 +929,7 @@ mod tests { #[tokio::test] async fn test_transport_error_write() { let cause = TransportError::Write; - let (mut dispatch, mut channel, mut cx) = set_up_always_err::(cause); + let (mut dispatch, mut channel, mut cx) = set_up_always_err::(cause); let (tx, mut rx) = oneshot::channel(); let resp = send_request(&mut channel, "hi", tx, &mut rx).await; @@ -957,7 +952,7 @@ mod tests { #[tokio::test] async fn test_transport_error_read() { let cause = TransportError::Read; - let (mut dispatch, mut channel, mut cx) = set_up_always_err::(cause); + let (mut dispatch, mut channel, mut cx) = set_up_always_err::(cause); let (tx, mut rx) = oneshot::channel(); let resp = send_request(&mut channel, "hi", tx, &mut rx).await; assert_eq!( @@ -974,7 +969,7 @@ mod tests { #[tokio::test] async fn test_transport_error_ready() { let cause = TransportError::Ready; - let (mut dispatch, _, mut cx) = set_up_always_err::(cause); + let (mut dispatch, _, mut cx) = set_up_always_err::(cause); assert_eq!( dispatch.as_mut().poll(&mut cx), Poll::Ready(Err(ChannelError::Ready(Arc::new(cause)))) @@ -984,7 +979,7 @@ mod tests { #[tokio::test] async fn test_transport_error_flush() { let cause = TransportError::Flush; - let (mut dispatch, _, mut cx) = set_up_always_err::(cause); + let (mut dispatch, _, mut cx) = set_up_always_err::(cause); assert_eq!( dispatch.as_mut().poll(&mut cx), Poll::Ready(Err(ChannelError::Flush(Arc::new(cause)))) @@ -994,7 +989,7 @@ mod tests { #[tokio::test] async fn test_transport_error_close() { let cause = TransportError::Close; - let (mut dispatch, channel, mut cx) = set_up_always_err::(cause); + let (mut dispatch, channel, mut cx) = set_up_always_err::(cause); drop(channel); assert_eq!( dispatch.as_mut().poll(&mut cx), @@ -1140,13 +1135,13 @@ mod tests { async fn send_request<'a, ClientCtx>( channel: &'a mut Channel, request: &str, - response_completion: oneshot::Sender>, - response: &'a mut oneshot::Receiver>, + response_completion: oneshot::Sender>, + response: &'a mut oneshot::Receiver>, ) -> ResponseGuard<'a, String> { let request_id = u64::try_from(channel.next_request_id.fetch_add(1, Ordering::Relaxed)).unwrap(); let request = DispatchRequest { - ctx: SharedContext::current(), + ctx: context::Context::current(), span: Span::current(), request_id, request: request.to_string(), @@ -1164,15 +1159,15 @@ mod tests { async fn reserve_for_send<'a, ClientCtx>( channel: &'a mut Channel, - response_completion: oneshot::Sender>, - response: &'a mut oneshot::Receiver>, + response_completion: oneshot::Sender>, + response: &'a mut oneshot::Receiver>, ) -> impl FnOnce(&str) -> ResponseGuard<'a, String> { let permit = channel.to_dispatch.reserve().await.unwrap(); |request| { let request_id = u64::try_from(channel.next_request_id.fetch_add(1, Ordering::Relaxed)).unwrap(); let request = DispatchRequest { - ctx: SharedContext::current(), + ctx: context::Context::current(), span: Span::current(), request_id, request: request.to_string(), diff --git a/tarpc/src/client/in_flight_requests.rs b/tarpc/src/client/in_flight_requests.rs index 0ea5ba5ac..cc5091fc6 100644 --- a/tarpc/src/client/in_flight_requests.rs +++ b/tarpc/src/client/in_flight_requests.rs @@ -1,9 +1,5 @@ use crate::client::RpcError; -use crate::context::SharedContext; -use crate::{ - trace, - util::{Compact, TimeUntil}, -}; +use crate::{context, trace, util::{Compact, TimeUntil}}; use fnv::FnvHashMap; use std::time::Instant; use std::{ @@ -34,7 +30,7 @@ impl Default for InFlightRequests { struct RequestData { ctx: trace::Context, span: Span, - response_completion: oneshot::Sender>, + response_completion: oneshot::Sender>, /// The key to remove the timer for the request's deadline. deadline_key: delay_queue::Key, } @@ -62,7 +58,7 @@ impl InFlightRequests { ctx: trace::Context, deadline: Instant, span: Span, - response_completion: oneshot::Sender>, + response_completion: oneshot::Sender>, ) -> Result<(), AlreadyExistsError> { match self.request_data.entry(request_id) { hash_map::Entry::Vacant(vacant) => { @@ -84,7 +80,7 @@ impl InFlightRequests { pub fn complete_request( &mut self, request_id: u64, - result: Result<(SharedContext, Res), RpcError>, + result: Result<(context::Context, Res), RpcError>, ) -> Option { if let Some(request_data) = self.request_data.remove(&request_id) { self.request_data.compact(0.1); @@ -103,7 +99,7 @@ impl InFlightRequests { /// Returns Spans for all completes requests. pub fn complete_all_requests<'a>( &'a mut self, - mut result: impl FnMut() -> Result<(SharedContext, Res), RpcError> + 'a, + mut result: impl FnMut() -> Result<(context::Context, Res), RpcError> + 'a, ) -> impl Iterator + 'a { self.deadlines.clear(); self.request_data.drain().map(move |(_, request_data)| { @@ -129,7 +125,7 @@ impl InFlightRequests { pub fn poll_expired( &mut self, cx: &mut Context, - expired_error: impl Fn() -> Result<(SharedContext, Res), RpcError>, + expired_error: impl Fn() -> Result<(context::Context, Res), RpcError>, ) -> Poll> { self.deadlines.poll_expired(cx).map(|expired| { let request_id = expired?.into_inner(); diff --git a/tarpc/src/client/stub.rs b/tarpc/src/client/stub.rs index 51cececae..1a0dbfff6 100644 --- a/tarpc/src/client/stub.rs +++ b/tarpc/src/client/stub.rs @@ -1,11 +1,7 @@ //! Provides a Stub trait, implemented by types that can call remote services. -use crate::context::{ExtractContext, SharedContext}; -use crate::{ - RequestName, - client::{Channel, RpcError}, - server::Serve, -}; +use crate::context::{ExtractContext}; +use crate::{RequestName, client::{Channel, RpcError}, server::Serve, context}; pub mod load_balance; pub mod retry; @@ -37,7 +33,7 @@ pub trait Stub { impl Stub for Channel where Req: RequestName, - ClientCtx: ExtractContext, + ClientCtx: ExtractContext, { type Req = Req; type Resp = Resp; @@ -50,11 +46,11 @@ where impl Stub for S where - S: Serve + Clone, + S: Serve + Clone, { type Req = S::Req; type Resp = S::Resp; - type ClientCtx = SharedContext; + type ClientCtx = context::Context; async fn call( &self, ctx: &mut Self::ClientCtx, diff --git a/tarpc/src/client/stub/load_balance.rs b/tarpc/src/client/stub/load_balance.rs index 9664a2aa7..4b9d9df3a 100644 --- a/tarpc/src/client/stub/load_balance.rs +++ b/tarpc/src/client/stub/load_balance.rs @@ -197,17 +197,17 @@ mod consistent_hash { for _ in 0..2 { let resp = stub - .call(&mut context::SharedContext::current(), 'a') + .call(&mut context::Context::current(), 'a') .await?; assert_eq!(resp, 1); let resp = stub - .call(&mut context::SharedContext::current(), 'b') + .call(&mut context::Context::current(), 'b') .await?; assert_eq!(resp, 2); let resp = stub - .call(&mut context::SharedContext::current(), 'c') + .call(&mut context::Context::current(), 'c') .await?; assert_eq!(resp, 3); } diff --git a/tarpc/src/context.rs b/tarpc/src/context.rs index bc357e50f..6db79e49b 100644 --- a/tarpc/src/context.rs +++ b/tarpc/src/context.rs @@ -23,7 +23,7 @@ use tracing_opentelemetry::OpenTelemetrySpanExt; /// be different for each request in scope. #[derive(Debug, Clone)] #[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))] -pub struct SharedContext { +pub struct Context { /// When the client expects the request to be complete by. The server should cancel the request /// if it is not complete by this time. #[cfg_attr(feature = "serde1", serde(default = "ten_seconds_from_now"))] @@ -111,7 +111,7 @@ mod absolute_to_relative_time { } } -assert_impl_all!(SharedContext: Send, Sync); +assert_impl_all!(Context: Send, Sync); fn ten_seconds_from_now() -> Instant { Instant::now() + Duration::from_secs(10) @@ -126,7 +126,7 @@ impl Default for Deadline { } } -impl SharedContext { +impl Context { /// Returns the context for the current request, or a default Context if no request is active. pub fn current() -> Self { let span = tracing::Span::current(); @@ -152,11 +152,11 @@ impl SharedContext { pub(crate) trait SpanExt { /// Sets the given context on this span. Newly-created spans will be children of the given /// context's trace context. - fn set_context(&self, context: &SharedContext); + fn set_context(&self, context: &Context); } impl SpanExt for tracing::Span { - fn set_context(&self, context: &SharedContext) { + fn set_context(&self, context: &Context) { self.set_parent( opentelemetry::Context::new() .with_remote_span_context(opentelemetry::trace::SpanContext::new( diff --git a/tarpc/src/lib.rs b/tarpc/src/lib.rs index fc79e3056..cb5a64085 100644 --- a/tarpc/src/lib.rs +++ b/tarpc/src/lib.rs @@ -124,7 +124,7 @@ //! struct HelloServer; //! //! impl World for HelloServer { -//! type Context = context::SharedContext; +//! type Context = context::Context; //! // Each defined rpc generates an async fn that serves the RPC //! async fn hello(self, _: &mut Self::Context, name: String) -> String { //! format!("Hello, {name}!") @@ -145,7 +145,6 @@ //! # use tarpc::{ //! # ClientMessage, //! # client, context, -//! # context::{SharedContext}, //! # transport::channel, //! # server::{self, Channel}, //! # }; @@ -161,7 +160,7 @@ //! # #[derive(Clone)] //! # struct HelloServer; //! # impl World for HelloServer { -//! # type Context = SharedContext; +//! # type Context = context::Context; //! # // Each defined rpc generates an async fn that serves the RPC //! # async fn hello(self, _: &mut Self::Context, name: String) -> String { //! # format!("Hello, {name}!") @@ -184,12 +183,12 @@ //! //! // WorldClient is generated by the #[tarpc::service] attribute. It has a constructor `new` //! // that takes a config and any Transport as input. -//! let mut client = WorldClient::::new(client::Config::default(), client_transport).spawn(); +//! let mut client = WorldClient::::new(client::Config::default(), client_transport).spawn(); //! //! // The client has an RPC method for each RPC defined in the annotated trait. It takes the same //! // args as defined, with the addition of a Context, which is always the first arg. The Context //! // specifies a deadline and trace information which can be helpful in debugging requests. -//! let mut context = context::SharedContext::current(); +//! let mut context = context::Context::current(); //! let hello = client.hello(&mut context, "Stim".to_string()).await?; //! //! println!("{hello}"); @@ -256,7 +255,6 @@ pub(crate) mod util; pub use crate::transport::sealed::Transport; -use crate::context::SharedContext; use std::ops::Deref; use std::{any::Any, error::Error, io, sync::Arc, time::Instant}; @@ -543,7 +541,7 @@ impl ServerError { impl Request where - Ctx: Deref, + Ctx: Deref, { /// Returns the deadline for this request. pub fn deadline(&self) -> &Instant { diff --git a/tarpc/src/server.rs b/tarpc/src/server.rs index 1ed69fcd8..d1a384e32 100644 --- a/tarpc/src/server.rs +++ b/tarpc/src/server.rs @@ -6,11 +6,10 @@ //! Provides a server that concurrently handles many connections sending multiplexed requests. -use crate::context::{ExtractContext, SharedContext}; use crate::{ ChannelError, ClientMessage, Request, RequestName, Response, ServerError, Transport, cancellations::{CanceledRequests, RequestCancellation, cancellations}, - context::SpanExt, + context, context::SpanExt, trace, util::TimeUntil, }; @@ -28,6 +27,7 @@ use std::{ convert::TryFrom, error::Error, fmt, marker::PhantomData, pin::Pin, sync::Arc, time::SystemTime, }; use tracing::{Span, info_span, instrument::Instrument}; +use crate::context::ExtractContext; mod in_flight_requests; pub mod request_hook; @@ -65,7 +65,7 @@ impl Config { ) -> BaseChannel where T: Transport, ClientMessage>, - ServerCtx: ExtractContext, + ServerCtx: ExtractContext, { BaseChannel::new(self, transport) } @@ -174,7 +174,7 @@ pub struct BaseChannel { impl BaseChannel where T: Transport, ClientMessage>, - ServerCtx: ExtractContext, + ServerCtx: ExtractContext, { /// Creates a new channel backed by `transport` and configured with `config`. pub fn new(config: Config, transport: T) -> Self { @@ -369,7 +369,6 @@ where /// use tarpc::{ /// ClientMessage, /// context, - /// context::{SharedContext}, /// client::{self, NewClient}, /// server::{self, BaseChannel, Channel, serve}, /// transport, @@ -389,7 +388,7 @@ where /// tokio::spawn(request.execute(serve(|_, i| async move { Ok(i + 1) }.boxed()))); /// } /// }); - /// let mut context = context::SharedContext::current(); + /// let mut context = context::Context::current(); /// assert_eq!(client.call(&mut context, 1).await.unwrap(), 2); /// } /// ``` @@ -413,7 +412,7 @@ where /// # Example /// /// ```rust - /// use tarpc::{ClientMessage, context, client, server::{self, BaseChannel, Channel, serve}, transport, context::{SharedContext}}; + /// use tarpc::{ClientMessage, context, client, server::{self, BaseChannel, Channel, serve}, transport}; /// use futures::prelude::*; /// use tracing_subscriber::prelude::*; /// @@ -430,7 +429,7 @@ where /// .for_each(|response| async move { /// tokio::spawn(response); /// }.boxed())); - /// let mut context = context::SharedContext::current(); + /// let mut context = context::Context::current(); /// assert_eq!( /// client.call(&mut context, 1).await.unwrap(), /// 2); @@ -449,7 +448,7 @@ where impl Stream for BaseChannel where T: Transport, ClientMessage>, - ServerCtx: ExtractContext, + ServerCtx: ExtractContext, { type Item = Result, ChannelError>; @@ -559,7 +558,7 @@ impl Sink> where T: Transport, ClientMessage>, T::Error: Error, - ServerCtx: ExtractContext, + ServerCtx: ExtractContext, { type Error = ChannelError; @@ -615,7 +614,7 @@ impl AsRef for BaseChannel impl Channel for BaseChannel where T: Transport, ClientMessage>, - ServerCtx: ExtractContext, + ServerCtx: ExtractContext, { type Req = Req; type Resp = Resp; @@ -773,7 +772,6 @@ where /// /// ```rust /// use tarpc::{context, client, server::{self, BaseChannel, Channel, serve}, transport, ClientMessage}; - /// use tarpc::context::{SharedContext}; /// use futures::prelude::*; /// /// # #[cfg(not(feature = "tokio1"))] @@ -789,7 +787,7 @@ where /// .for_each(|response| async move { /// tokio::spawn(response); /// }.boxed())); - /// let mut context = context::SharedContext::current(); + /// let mut context = context::Context::current(); /// assert_eq!(client.call(&mut context, 1).await.unwrap(), 2); /// } /// ``` @@ -878,7 +876,6 @@ impl InFlightRequest { /// use tarpc::{ /// ClientMessage, /// context, - /// context::{SharedContext}, /// client::{self, NewClient}, /// server::{self, BaseChannel, Channel, serve}, /// transport, @@ -898,7 +895,7 @@ impl InFlightRequest { /// in_flight_request.execute(serve(|_, i| async move { Ok(i + 1) }.boxed())).await; /// } /// }); - /// let mut context = context::SharedContext::current(); + /// let mut context = context::Context::current(); /// assert_eq!(client.call(&mut context, 1).await.unwrap(), 2); /// } /// ``` @@ -1001,7 +998,7 @@ mod tests { request_hook::{AfterRequest, BeforeRequest, RequestHook}, serve, }; - use crate::context::{ExtractContext, SharedContext}; + use crate::context::{ExtractContext}; use crate::{ ClientMessage, Request, Response, ServerError, context, trace, transport::channel::{self, UnboundedChannel}, @@ -1027,14 +1024,14 @@ mod tests { Req, Resp, UnboundedChannel< - ClientMessage, - Response, + ClientMessage, + Response, >, - SharedContext, + context::Context, >, >, >, - UnboundedChannel, ClientMessage>, + UnboundedChannel, ClientMessage>, ) { let (tx, rx) = crate::transport::channel::unbounded(); (Box::pin(BaseChannel::new(Config::default(), rx)), tx) @@ -1048,15 +1045,15 @@ mod tests { Req, Resp, UnboundedChannel< - ClientMessage, - Response, + ClientMessage, + Response, >, - SharedContext, + context::Context, >, >, >, >, - UnboundedChannel, ClientMessage>, + UnboundedChannel, ClientMessage>, ) { let (tx, rx) = crate::transport::channel::unbounded(); ( @@ -1075,15 +1072,15 @@ mod tests { Req, Resp, channel::Channel< - ClientMessage, - Response, + ClientMessage, + Response, >, - SharedContext, + context::Context, >, >, >, >, - channel::Channel, ClientMessage>, + channel::Channel, ClientMessage>, ) { let (tx, rx) = crate::transport::channel::bounded(capacity); // Add 1 because capacity 0 is not supported (but is supported by transport::channel::bounded). @@ -1093,9 +1090,9 @@ mod tests { (Box::pin(BaseChannel::new(config, rx).requests()), tx) } - fn fake_request(req: Req) -> ClientMessage { + fn fake_request(req: Req) -> ClientMessage { ClientMessage::Request(Request { - context: context::SharedContext::current(), + context: context::Context::current(), id: 0, message: req, }) @@ -1111,7 +1108,7 @@ mod tests { async fn test_serve() { let serve = serve(|_, i| async move { Ok(i) }.boxed()); assert_matches!( - serve.serve(&mut context::SharedContext::current(), 7).await, + serve.serve(&mut context::Context::current(), 7).await, Ok(7) ); } @@ -1121,7 +1118,7 @@ mod tests { struct SetDeadline(Instant); impl BeforeRequest for SetDeadline where - ServerCtx: ExtractContext, + ServerCtx: ExtractContext, { async fn before(&mut self, ctx: &mut ServerCtx, _: &Req) -> Result<(), ServerError> { let mut inner = ctx.extract(); @@ -1134,7 +1131,7 @@ mod tests { let some_time = Instant::now() + Duration::from_secs(37); let some_other_time = Instant::now() + Duration::from_secs(83); - let serve = serve(move |ctx: &mut context::SharedContext, i| { + let serve = serve(move |ctx: &mut context::Context, i| { async move { assert_eq!(ctx.deadline, some_time); Ok(i) @@ -1142,7 +1139,7 @@ mod tests { .boxed() }); let deadline_hook = serve.before(SetDeadline(some_time)); - let mut ctx = context::SharedContext::current(); + let mut ctx = context::Context::current(); ctx.deadline = some_other_time; deadline_hook.serve(&mut ctx, 7).await?; Ok(()) @@ -1174,10 +1171,10 @@ mod tests { } } - let serve = serve(move |_: &mut context::SharedContext, i| async move { Ok(i) }.boxed()); + let serve = serve(move |_: &mut context::Context, i| async move { Ok(i) }.boxed()); serve .before_and_after(PrintLatency::new()) - .serve(&mut context::SharedContext::current(), 7) + .serve(&mut context::Context::current(), 7) .await?; Ok(()) } @@ -1185,11 +1182,11 @@ mod tests { #[tokio::test] async fn serve_before_error_aborts_request() -> anyhow::Result<()> { let serve = serve(|_, _| async { panic!("Shouldn't get here") }.boxed()); - let deadline_hook = serve.before(|_: &mut context::SharedContext, _: &i32| async { + let deadline_hook = serve.before(|_: &mut context::Context, _: &i32| async { Err(ServerError::new(io::ErrorKind::Other, "oops".into())) }); let resp: Result = deadline_hook - .serve(&mut context::SharedContext::current(), 7) + .serve(&mut context::Context::current(), 7) .await; assert_matches!(resp, Err(_)); Ok(()) @@ -1203,14 +1200,14 @@ mod tests { .as_mut() .start_request(Request { id: 0, - context: context::SharedContext::current(), + context: context::Context::current(), message: (), }) .unwrap(); assert_matches!( channel.as_mut().start_request(Request { id: 0, - context: context::SharedContext::current(), + context: context::Context::current(), message: () }), Err(AlreadyExistsError) @@ -1226,7 +1223,7 @@ mod tests { .as_mut() .start_request(Request { id: 0, - context: context::SharedContext::current(), + context: context::Context::current(), message: (), }) .unwrap(); @@ -1234,7 +1231,7 @@ mod tests { .as_mut() .start_request(Request { id: 1, - context: context::SharedContext::current(), + context: context::Context::current(), message: (), }) .unwrap(); @@ -1257,7 +1254,7 @@ mod tests { .as_mut() .start_request(Request { id: 0, - context: context::SharedContext::current(), + context: context::Context::current(), message: (), }) .unwrap(); @@ -1286,7 +1283,7 @@ mod tests { .as_mut() .start_request(Request { id: 0, - context: context::SharedContext::current(), + context: context::Context::current(), message: (), }) .unwrap(); @@ -1328,7 +1325,7 @@ mod tests { .as_mut() .start_request(Request { id: 0, - context: context::SharedContext::current(), + context: context::Context::current(), message: (), }) .unwrap(); @@ -1351,7 +1348,7 @@ mod tests { .as_mut() .start_request(Request { id: 0, - context: SharedContext::current(), + context: context::Context::current(), message: (), }) .unwrap(); @@ -1360,7 +1357,7 @@ mod tests { .as_mut() .start_send(Response { request_id: 0, - context: SharedContext::current(), + context: context::Context::current(), message: Ok(()), }) .unwrap(); @@ -1419,7 +1416,7 @@ mod tests { .channel_pin_mut() .start_request(Request { id: 0, - context: context::SharedContext::current(), + context: context::Context::current(), message: (), }) .unwrap(); @@ -1428,7 +1425,7 @@ mod tests { .channel_pin_mut() .start_send(Response { request_id: 0, - context: SharedContext::current(), + context: context::Context::current(), message: Ok(()), }) .unwrap(); @@ -1440,7 +1437,7 @@ mod tests { .responses_tx .send(Response { request_id: 1, - context: SharedContext::current(), + context: context::Context::current(), message: Ok(()), }) .await @@ -1451,7 +1448,7 @@ mod tests { .channel_pin_mut() .start_request(Request { id: 1, - context: SharedContext::current(), + context: context::Context::current(), message: (), }) .unwrap(); @@ -1472,7 +1469,7 @@ mod tests { .channel_pin_mut() .start_request(Request { id: 0, - context: context::SharedContext::current(), + context: context::Context::current(), message: (), }) .unwrap(); @@ -1481,7 +1478,7 @@ mod tests { .channel_pin_mut() .start_send(Response { request_id: 0, - context: SharedContext::current(), + context: context::Context::current(), message: Ok(()), }) .unwrap(); @@ -1492,7 +1489,7 @@ mod tests { .channel_pin_mut() .start_request(Request { id: 1, - context: SharedContext::current(), + context: context::Context::current(), message: (), }) .unwrap(); @@ -1502,7 +1499,7 @@ mod tests { .responses_tx .send(Response { request_id: 1, - context: SharedContext::current(), + context: context::Context::current(), message: Ok(()), }) .await diff --git a/tarpc/src/server/incoming.rs b/tarpc/src/server/incoming.rs index 6a71124b1..67d46e330 100644 --- a/tarpc/src/server/incoming.rs +++ b/tarpc/src/server/incoming.rs @@ -50,7 +50,6 @@ where /// use tarpc::{ /// ClientMessage, /// context, -/// context::{SharedContext}, /// client::{self, NewClient}, /// server::{self, BaseChannel, Channel, incoming::{Incoming, spawn_incoming}, serve}, /// transport, @@ -67,7 +66,7 @@ where /// BaseChannel::new(server::Config::default(), rx) /// }).execute(serve(|_, i| async move { Ok(i + 1) }.boxed())); /// tokio::spawn(spawn_incoming(incoming)); -/// let mut context = context::SharedContext::current(); +/// let mut context = context::Context::current(); /// assert_eq!(client.call(&mut context, 1).await.unwrap(), 2); /// } /// ``` diff --git a/tarpc/src/server/limits/requests_per_channel.rs b/tarpc/src/server/limits/requests_per_channel.rs index 34b372510..32b126aa6 100644 --- a/tarpc/src/server/limits/requests_per_channel.rs +++ b/tarpc/src/server/limits/requests_per_channel.rs @@ -180,7 +180,6 @@ where mod tests { use super::*; - use crate::context::SharedContext; use crate::server::{ TrackedRequest, testing::{self, FakeChannel, PollExt}, @@ -191,6 +190,7 @@ mod tests { time::{Duration, Instant}, }; use tracing::Span; + use crate::context; #[tokio::test] async fn throttler_in_flight_requests() { @@ -271,8 +271,8 @@ mod tests { } impl PendingSink<(), ()> { pub fn default() -> PendingSink< - io::Result>, - Response, + io::Result>, + Response, > { PendingSink { ghost: PhantomData } } @@ -300,14 +300,14 @@ mod tests { } impl Channel for PendingSink< - io::Result>, - Response, + io::Result>, + Response, > { type Req = Req; type Resp = Resp; type Transport = (); - type ServerCtx = SharedContext; + type ServerCtx = context::Context; fn config(&self) -> &Config { unimplemented!() } @@ -337,7 +337,7 @@ mod tests { .as_mut() .start_send(Response { request_id: 0, - context: SharedContext::current(), + context: context::Context::current(), message: Ok(1), }) .unwrap(); diff --git a/tarpc/src/server/request_hook.rs b/tarpc/src/server/request_hook.rs index 4f3d60377..090c4a72c 100644 --- a/tarpc/src/server/request_hook.rs +++ b/tarpc/src/server/request_hook.rs @@ -48,7 +48,7 @@ pub trait RequestHook: Serve { /// use std::io; /// /// let serve = serve(|_ctx, i| async move { Ok(i + 1) }.boxed()) - /// .before(|_ctx: &mut context::SharedContext, req: &i32| { + /// .before(|_ctx: &mut context::Context, req: &i32| { /// future::ready( /// if *req == 1 { /// Err(ServerError::new( @@ -58,7 +58,7 @@ pub trait RequestHook: Serve { /// Ok(()) /// }) /// }); - /// let mut context = context::SharedContext::current(); + /// let mut context = context::Context::current(); /// let response = serve.serve(&mut context, 1); /// assert!(block_on(response).is_err()); /// ``` @@ -95,13 +95,13 @@ pub trait RequestHook: Serve { /// Ok(i + 1) /// } /// }.boxed()) - /// .after(|_ctx: &mut context::SharedContext, resp: &mut Result| { + /// .after(|_ctx: &mut context::Context, resp: &mut Result| { /// if let Err(e) = resp { /// eprintln!("server error: {e:?}"); /// } /// future::ready(()) /// }); - /// let mut context = context::SharedContext::current(); + /// let mut context = context::Context::current(); /// let response = serve.serve(&mut context, 1); /// assert!(block_on(response).is_err()); /// ``` @@ -153,7 +153,7 @@ pub trait RequestHook: Serve { /// let serve = serve(|_ctx, i| async move { /// Ok(i + 1) /// }.boxed()).before_and_after(PrintLatency(Instant::now())); - /// let mut context = context::SharedContext::current(); + /// let mut context = context::Context::current(); /// let response = serve.serve(&mut context, 1); /// assert!(block_on(response).is_ok()); /// ``` diff --git a/tarpc/src/server/request_hook/before.rs b/tarpc/src/server/request_hook/before.rs index 1552a0b49..df4873e83 100644 --- a/tarpc/src/server/request_hook/before.rs +++ b/tarpc/src/server/request_hook/before.rs @@ -129,7 +129,7 @@ where /// Ok(()) /// }) /// .serving(serve(|_ctx, i| async move { Ok(i + 1) }.boxed())); -/// let mut context = context::SharedContext::current(); +/// let mut context = context::Context::current(); /// let response = serve.clone().serve(&mut context, 1); /// assert!(block_on(response).is_ok()); /// assert!(i.get() == 2); @@ -219,7 +219,7 @@ fn before_request_list() { Ok(()) }) .serving(serve(|_ctx, i| async move { Ok(i + 1) }.boxed())); - let mut context = crate::context::SharedContext::current(); + let mut context = crate::context::Context::current(); let response = serve.clone().serve(&mut context, 1); assert!(block_on(response).is_ok()); assert!(i.get() == 2); diff --git a/tarpc/src/server/testing.rs b/tarpc/src/server/testing.rs index ce409dd85..047464cf5 100644 --- a/tarpc/src/server/testing.rs +++ b/tarpc/src/server/testing.rs @@ -4,7 +4,6 @@ // license that can be found in the LICENSE file or at // https://opensource.org/licenses/MIT. -use crate::context::SharedContext; use crate::{ Request, Response, cancellations::{CanceledRequests, RequestCancellation, cancellations}, @@ -39,8 +38,8 @@ where } } -impl Sink> - for FakeChannel> +impl Sink> + for FakeChannel> { type Error = io::Error; @@ -50,7 +49,7 @@ impl Sink> fn start_send( mut self: Pin<&mut Self>, - response: Response, + response: Response, ) -> Result<(), Self::Error> { self.as_mut() .project() @@ -72,14 +71,14 @@ impl Sink> } impl Channel - for FakeChannel>, Response> + for FakeChannel>, Response> where Req: Unpin, { type Req = Req; type Resp = Resp; type Transport = (); - type ServerCtx = SharedContext; + type ServerCtx = context::Context; fn config(&self) -> &Config { &self.config @@ -95,14 +94,14 @@ where } impl - FakeChannel>, Response> + FakeChannel>, Response> { pub fn push_req(&mut self, id: u64, message: Req) { let (_, abort_registration) = futures::future::AbortHandle::new_pair(); let (request_cancellation, _) = cancellations(); self.stream.push_back(Ok(TrackedRequest { request: Request { - context: context::SharedContext { + context: context::Context { deadline: Instant::now(), trace_context: Default::default(), }, @@ -122,7 +121,7 @@ impl impl FakeChannel<(), ()> { pub fn default() - -> FakeChannel>, Response> + -> FakeChannel>, Response> { let (request_cancellation, canceled_requests) = cancellations(); diff --git a/tarpc/src/transport/channel.rs b/tarpc/src/transport/channel.rs index a698136f0..1ff75e70d 100644 --- a/tarpc/src/transport/channel.rs +++ b/tarpc/src/transport/channel.rs @@ -161,7 +161,6 @@ impl Sink for Channel { #[cfg(all(test, feature = "tokio1"))] mod tests { - use crate::context::SharedContext; use crate::{ ServerError, client::{self, RpcError}, @@ -193,7 +192,7 @@ mod tests { tokio::spawn( stream::once(future::ready(server_channel)) .map(BaseChannel::with_defaults) - .execute(serve(|_ctx: &mut SharedContext, request: String| { + .execute(serve(|_ctx: &mut context::Context, request: String| { async move { request.parse::().map_err(|_| { ServerError::new( @@ -212,10 +211,10 @@ mod tests { let client = client::new(client::Config::default(), client_channel).spawn(); let response1 = client - .call(&mut context::SharedContext::current(), "123".into()) + .call(&mut context::Context::current(), "123".into()) .await; let response2 = client - .call(&mut context::SharedContext::current(), "abc".into()) + .call(&mut context::Context::current(), "abc".into()) .await; trace!("response1: {:?}, response2: {:?}", response1, response2); diff --git a/tarpc/tests/compile_fail/must_use_request_dispatch.rs b/tarpc/tests/compile_fail/must_use_request_dispatch.rs index a5238fe8b..812fc4ee7 100644 --- a/tarpc/tests/compile_fail/must_use_request_dispatch.rs +++ b/tarpc/tests/compile_fail/must_use_request_dispatch.rs @@ -1,5 +1,5 @@ use tarpc::client; -use tarpc::context::SharedContext; +use tarpc::context::Context; #[tarpc::service] trait World { async fn hello(name: String) -> String; @@ -10,6 +10,6 @@ fn main() { #[deny(unused_must_use)] { - WorldClient::::new(client::Config::default(), client_transport).dispatch; + WorldClient::::new(client::Config::default(), client_transport).dispatch; } } diff --git a/tarpc/tests/compile_fail/must_use_request_dispatch.stderr b/tarpc/tests/compile_fail/must_use_request_dispatch.stderr index e0ec77ff3..4fe34df5f 100644 --- a/tarpc/tests/compile_fail/must_use_request_dispatch.stderr +++ b/tarpc/tests/compile_fail/must_use_request_dispatch.stderr @@ -1,8 +1,8 @@ error: unused `RequestDispatch` that must be used --> tests/compile_fail/must_use_request_dispatch.rs:13:9 | -13 | WorldClient::::new(client::Config::default(), client_transport).dispatch; - | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +13 | WorldClient::::new(client::Config::default(), client_transport).dispatch; + | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ | note: the lint level is defined here --> tests/compile_fail/must_use_request_dispatch.rs:11:12 @@ -11,5 +11,5 @@ note: the lint level is defined here | ^^^^^^^^^^^^^^^ help: use `let _ = ...` to ignore the resulting value | -13 | let _ = WorldClient::::new(client::Config::default(), client_transport).dispatch; +13 | let _ = WorldClient::::new(client::Config::default(), client_transport).dispatch; | +++++++ diff --git a/tarpc/tests/dataservice.rs b/tarpc/tests/dataservice.rs index a39922666..a2e458361 100644 --- a/tarpc/tests/dataservice.rs +++ b/tarpc/tests/dataservice.rs @@ -1,5 +1,5 @@ use futures::prelude::*; -use tarpc::context::SharedContext; +use tarpc::context::Context; use tarpc::serde_transport; use tarpc::{ client, context, @@ -23,7 +23,7 @@ pub trait ColorProtocol { struct ColorServer; impl ColorProtocol for ColorServer { - type Context = SharedContext; + type Context = context::Context; async fn get_opposite_color(self, _: &mut Self::Context, color: TestData) -> TestData { match color { TestData::White => TestData::Black, @@ -55,7 +55,7 @@ async fn test_call() -> anyhow::Result<()> { let client = ColorProtocolClient::new(client::Config::default(), transport).spawn(); let color = client - .get_opposite_color(&mut context::SharedContext::current(), TestData::White) + .get_opposite_color(&mut context::Context::current(), TestData::White) .await?; assert_eq!(color, TestData::Black); diff --git a/tarpc/tests/service_functional.rs b/tarpc/tests/service_functional.rs index fd54b3db6..abe1ba0df 100644 --- a/tarpc/tests/service_functional.rs +++ b/tarpc/tests/service_functional.rs @@ -4,7 +4,7 @@ use futures::{ prelude::*, }; use std::time::{Duration, Instant}; -use tarpc::context::SharedContext; +use tarpc::context::Context; use tarpc::{ ClientMessage, client::{self}, @@ -25,7 +25,7 @@ trait Service { struct Server; impl Service for Server { - type Context = SharedContext; + type Context = context::Context; async fn add(self, _: &mut Self::Context, x: i32, y: i32) -> i32 { x + y } @@ -50,7 +50,7 @@ async fn sequential() { ); assert_eq!( client - .call(&mut context::SharedContext::current(), 1) + .call(&mut context::Context::current(), 1) .await .unwrap(), 2 @@ -68,7 +68,7 @@ async fn dropped_channel_aborts_in_flight_requests() -> anyhow::Result<()> { struct LoopServer; impl Loop for LoopServer { - type Context = SharedContext; + type Context = context::Context; async fn r#loop(self, _: &mut Self::Context) { loop { futures::pending!(); @@ -85,7 +85,7 @@ async fn dropped_channel_aborts_in_flight_requests() -> anyhow::Result<()> { tokio::spawn(async move { let client = LoopClient::new(client::Config::default(), tx).spawn(); - let mut ctx = context::SharedContext::current(); + let mut ctx = context::Context::current(); ctx.deadline = Instant::now() + Duration::from_secs(60 * 60); let _ = client.r#loop(&mut ctx).await; }); @@ -128,12 +128,12 @@ async fn serde_tcp() -> anyhow::Result<()> { assert_matches!( client - .add(&mut context::SharedContext::current(), 1, 2) + .add(&mut context::Context::current(), 1, 2) .await, Ok(3) ); assert_matches!( - client.hey(&mut context::SharedContext::current(), "Tim".to_string()).await, + client.hey(&mut context::Context::current(), "Tim".to_string()).await, Ok(ref s) if s == "Hey, Tim." ); @@ -166,10 +166,10 @@ async fn serde_uds() -> anyhow::Result<()> { // Save results using socket so we can clean the socket even if our test assertions fail let res1 = client - .add(&mut context::SharedContext::current(), 1, 2) + .add(&mut context::Context::current(), 1, 2) .await; let res2 = client - .hey(&mut context::SharedContext::current(), "Tim".to_string()) + .hey(&mut context::Context::current(), "Tim".to_string()) .await; assert_matches!(res1, Ok(3)); @@ -194,7 +194,7 @@ async fn concurrent() -> anyhow::Result<()> { let client = ServiceClient::new(client::Config::default(), tx).spawn(); - let mut context = context::SharedContext::current(); + let mut context = context::Context::current(); let req1 = client.add(&mut context, 1, 2); assert_matches!(req1.await, Ok(3)); @@ -224,9 +224,9 @@ async fn concurrent_join() -> anyhow::Result<()> { let client = ServiceClient::new(client::Config::default(), tx).spawn(); - let mut context1 = context::SharedContext::current(); - let mut context2 = context::SharedContext::current(); - let mut context3 = context::SharedContext::current(); + let mut context1 = context::Context::current(); + let mut context2 = context::Context::current(); + let mut context3 = context::Context::current(); let req1 = client.add(&mut context1, 1, 2); let req2 = client.add(&mut context2, 3, 4); @@ -258,8 +258,8 @@ async fn concurrent_join_all() -> anyhow::Result<()> { let client = ServiceClient::new(client::Config::default(), tx).spawn(); - let mut context1 = context::SharedContext::current(); - let mut context2 = context::SharedContext::current(); + let mut context1 = context::Context::current(); + let mut context2 = context::Context::current(); let req1 = client.add(&mut context1, 1, 2); let req2 = client.add(&mut context2, 3, 4); @@ -281,7 +281,7 @@ async fn counter() -> anyhow::Result<()> { struct CountService(u32); impl Counter for &mut CountService { - type Context = SharedContext; + type Context = context::Context; async fn count(self, _: &mut Self::Context) -> u32 { self.0 += 1; self.0 @@ -301,11 +301,11 @@ async fn counter() -> anyhow::Result<()> { let client = CounterClient::new(client::Config::default(), tx).spawn(); assert_matches!( - client.count(&mut context::SharedContext::current()).await, + client.count(&mut context::Context::current()).await, Ok(1) ); assert_matches!( - client.count(&mut context::SharedContext::current()).await, + client.count(&mut context::Context::current()).await, Ok(2) ); From 116c718178158699aa9b6a15d8cbcf7845eb03ed Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C3=81kos=20Vandra-Meyer?= Date: Wed, 26 Nov 2025 16:41:53 +0100 Subject: [PATCH 18/23] cleanup --- example-service/src/client.rs | 1 + example-service/src/lib.rs | 2 ++ example-service/src/server.rs | 9 ++++----- plugins/Cargo.toml | 1 - plugins/src/lib.rs | 10 ++-------- tarpc/examples/compression.rs | 2 +- tarpc/examples/custom_transport.rs | 3 +-- tarpc/examples/pubsub.rs | 6 +++--- tarpc/examples/readme.rs | 4 ++-- tarpc/examples/tls_over_tcp.rs | 2 +- tarpc/examples/tracing.rs | 2 +- 11 files changed, 18 insertions(+), 24 deletions(-) diff --git a/example-service/src/client.rs b/example-service/src/client.rs index 64b2e0a89..e1d496f59 100644 --- a/example-service/src/client.rs +++ b/example-service/src/client.rs @@ -3,6 +3,7 @@ // Use of this source code is governed by an MIT-style // license that can be found in the LICENSE file or at // https://opensource.org/licenses/MIT. +#![deny(warnings, unused, dead_code)] use clap::Parser; use service::{WorldClient, init_tracing}; diff --git a/example-service/src/lib.rs b/example-service/src/lib.rs index 26b49e2ac..fd5031e75 100644 --- a/example-service/src/lib.rs +++ b/example-service/src/lib.rs @@ -4,6 +4,8 @@ // license that can be found in the LICENSE file or at // https://opensource.org/licenses/MIT. +#![deny(warnings, unused, dead_code)] + use opentelemetry::trace::TracerProvider as _; use tracing_subscriber::{fmt::format::FmtSpan, prelude::*}; diff --git a/example-service/src/server.rs b/example-service/src/server.rs index 302336c57..a8e3324fc 100644 --- a/example-service/src/server.rs +++ b/example-service/src/server.rs @@ -3,6 +3,7 @@ // Use of this source code is governed by an MIT-style // license that can be found in the LICENSE file or at // https://opensource.org/licenses/MIT. +#![deny(warnings, unused, dead_code)] use clap::Parser; use futures::{future, prelude::*}; @@ -11,14 +12,12 @@ use rand::{ thread_rng, }; use service::{World, init_tracing}; -use std::ops::Deref; use std::{ net::{IpAddr, Ipv6Addr, SocketAddr}, time::Duration, }; -use tarpc::context::Context; use tarpc::{ - ClientMessage, context, + context, server::{self, Channel, incoming::Incoming}, tokio_serde::formats::Json, }; @@ -67,11 +66,11 @@ async fn main() -> anyhow::Result<()> { .filter_map(|r| future::ready(r.ok())) .map(server::BaseChannel::with_defaults) // Limit channels to 1 per IP. - .max_channels_per_key(1, |t| t.transport().get_ref().peer_addr().unwrap().ip()) + .max_channels_per_key(1, |t| t.transport().peer_addr().unwrap().ip()) // serve is generated by the service attribute. It takes as input any type implementing // the generated World trait. .map(|channel| { - let server = HelloServer(channel.transport().get_ref().peer_addr().unwrap()); + let server = HelloServer(channel.transport().peer_addr().unwrap()); channel.execute(server.serve()).for_each(spawn) }) // Max 10 channels. diff --git a/plugins/Cargo.toml b/plugins/Cargo.toml index eeab84924..8be746c26 100644 --- a/plugins/Cargo.toml +++ b/plugins/Cargo.toml @@ -30,6 +30,5 @@ proc-macro = true [dev-dependencies] assert-type-eq = "0.1.0" futures = "0.3" -futures-util = "0.3.31" serde = { version = "1.0", features = ["derive"] } tarpc = { path = "../tarpc", features = ["serde1"] } diff --git a/plugins/src/lib.rs b/plugins/src/lib.rs index b8f1ff826..8e35ee49d 100644 --- a/plugins/src/lib.rs +++ b/plugins/src/lib.rs @@ -4,13 +4,9 @@ // license that can be found in the LICENSE file or at // https://opensource.org/licenses/MIT. +#![deny(warnings, unused, dead_code)] #![recursion_limit = "512"] -extern crate proc_macro; -extern crate proc_macro2; -extern crate quote; -extern crate syn; - use proc_macro::TokenStream; use proc_macro2::TokenStream as TokenStream2; use quote::{ToTokens, format_ident, quote}; @@ -375,9 +371,7 @@ fn collect_cfg_attrs(rpcs: &[RpcMethod]) -> Vec> { /// # Example /// /// ```no_run -/// use tarpc::{client, transport, service, server::{self, Channel}}; -/// use futures_util::{TryStreamExt, sink::SinkExt};/// -/// use tarpc::context; +/// use tarpc::{client, transport, service, server::{self, Channel}, context::Context}; /// /// #[service] /// pub trait Calculator { diff --git a/tarpc/examples/compression.rs b/tarpc/examples/compression.rs index 3c5fa6fcf..aa12147e5 100644 --- a/tarpc/examples/compression.rs +++ b/tarpc/examples/compression.rs @@ -3,13 +3,13 @@ // Use of this source code is governed by an MIT-style // license that can be found in the LICENSE file or at // https://opensource.org/licenses/MIT. +#![deny(warnings, unused, dead_code)] use flate2::{Compression, read::DeflateDecoder, write::DeflateEncoder}; use futures::{Sink, SinkExt, Stream, StreamExt, TryStreamExt, prelude::*}; use serde::{Deserialize, Serialize}; use serde_bytes::ByteBuf; use std::{io, io::Read, io::Write}; -use tarpc::context::Context; use tarpc::{ client, context, serde_transport::tcp, diff --git a/tarpc/examples/custom_transport.rs b/tarpc/examples/custom_transport.rs index 80f8a03be..1548d8e80 100644 --- a/tarpc/examples/custom_transport.rs +++ b/tarpc/examples/custom_transport.rs @@ -3,10 +3,9 @@ // Use of this source code is governed by an MIT-style // license that can be found in the LICENSE file or at // https://opensource.org/licenses/MIT. +#![deny(warnings, unused, dead_code)] -use console_subscriber::Server; use futures::prelude::*; -use tarpc::context::Context; use tarpc::{context, serde_transport as transport}; use tarpc::server::{BaseChannel, Channel}; use tarpc::tokio_serde::formats::Bincode; diff --git a/tarpc/examples/pubsub.rs b/tarpc/examples/pubsub.rs index 3cf95b27e..f33e38833 100644 --- a/tarpc/examples/pubsub.rs +++ b/tarpc/examples/pubsub.rs @@ -3,6 +3,7 @@ // Use of this source code is governed by an MIT-style // license that can be found in the LICENSE file or at // https://opensource.org/licenses/MIT. +#![deny(warnings, unused, dead_code)] /// - The PubSub server sets up TCP listeners on 2 ports, the "subscriber" port and the "publisher" /// port. Because both publishers and subscribers initiate their connections to the PubSub @@ -41,8 +42,6 @@ use futures::{ use opentelemetry::trace::TracerProvider as _; use publisher::Publisher as _; use serde::de::DeserializeOwned; -use serde::{Deserialize, Serialize}; -use std::ops::Shl; use std::{ collections::HashMap, error::Error, @@ -50,10 +49,11 @@ use std::{ net::SocketAddr, sync::{Arc, Mutex, RwLock}, }; +use serde::Serialize; use subscriber::Subscriber as _; use tarpc::context::{ExtractContext}; use tarpc::{ - ClientMessage, client, context, + client, context, serde_transport::tcp, server::{self, Channel}, tokio_serde::formats::Json, diff --git a/tarpc/examples/readme.rs b/tarpc/examples/readme.rs index ff0307d39..f2a98cc87 100644 --- a/tarpc/examples/readme.rs +++ b/tarpc/examples/readme.rs @@ -3,11 +3,11 @@ // Use of this source code is governed by an MIT-style // license that can be found in the LICENSE file or at // https://opensource.org/licenses/MIT. +#![deny(warnings, unused, dead_code)] use futures::prelude::*; -use tarpc::context::Context; use tarpc::{ - ClientMessage, client, context, + client, context, server::{self, Channel}, transport, }; diff --git a/tarpc/examples/tls_over_tcp.rs b/tarpc/examples/tls_over_tcp.rs index 07eb6a8ef..b970a6dac 100644 --- a/tarpc/examples/tls_over_tcp.rs +++ b/tarpc/examples/tls_over_tcp.rs @@ -3,6 +3,7 @@ // Use of this source code is governed by an MIT-style // license that can be found in the LICENSE file or at // https://opensource.org/licenses/MIT. +#![deny(warnings, unused, dead_code)] use futures::prelude::*; use rustls_pemfile::certs; @@ -10,7 +11,6 @@ use std::io::{self, BufReader, Cursor}; use std::net::{IpAddr, Ipv4Addr}; use std::sync::Arc; -use tarpc::context::Context; use tarpc::{context, serde_transport as transport}; use tarpc::server::{BaseChannel, Channel}; use tarpc::tokio_serde::formats::Bincode; diff --git a/tarpc/examples/tracing.rs b/tarpc/examples/tracing.rs index abf1cbbc1..b67d98fe4 100644 --- a/tarpc/examples/tracing.rs +++ b/tarpc/examples/tracing.rs @@ -3,7 +3,7 @@ // Use of this source code is governed by an MIT-style // license that can be found in the LICENSE file or at // https://opensource.org/licenses/MIT. - +#![deny(warnings, unused, dead_code)] #![allow(clippy::type_complexity)] use crate::{ From 044629d651d7b931120890206867521efa54d996 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C3=81kos=20Vandra-Meyer?= Date: Wed, 26 Nov 2025 17:06:39 +0100 Subject: [PATCH 19/23] cleanup --- example-service/src/client.rs | 4 +- plugins/tests/service.rs | 8 +-- tarpc/Cargo.toml | 2 - tarpc/examples/compression.rs | 11 ++-- tarpc/examples/custom_transport.rs | 6 ++- tarpc/examples/pubsub.rs | 50 ++++-------------- tarpc/examples/readme.rs | 12 ++--- tarpc/examples/tls_over_tcp.rs | 12 +++-- tarpc/examples/tracing.rs | 19 ++----- tarpc/src/client.rs | 14 ++--- tarpc/src/client/in_flight_requests.rs | 15 +++--- tarpc/src/client/stub.rs | 25 ++++----- tarpc/src/client/stub/load_balance.rs | 6 +-- tarpc/src/context.rs | 5 ++ tarpc/src/lib.rs | 2 +- tarpc/src/server.rs | 52 +++++++++---------- tarpc/src/server/incoming.rs | 2 +- .../src/server/limits/requests_per_channel.rs | 2 +- tarpc/src/server/request_hook.rs | 6 +-- tarpc/src/server/request_hook/before.rs | 4 +- tarpc/src/server/testing.rs | 4 -- tarpc/src/transport/channel.rs | 4 +- tarpc/tests/dataservice.rs | 3 +- tarpc/tests/service_functional.rs | 30 +++++------ 24 files changed, 118 insertions(+), 180 deletions(-) diff --git a/example-service/src/client.rs b/example-service/src/client.rs index e1d496f59..6f3930343 100644 --- a/example-service/src/client.rs +++ b/example-service/src/client.rs @@ -35,8 +35,8 @@ async fn main() -> anyhow::Result<()> { let client = WorldClient::new(client::Config::default(), transport.await?).spawn(); let hello = async move { - let mut context = context::Context::current(); - let mut context2 = context::Context::current(); + let mut context = context::current(); + let mut context2 = context::current(); // Send the request twice, just to be safe! ;) tokio::select! { diff --git a/plugins/tests/service.rs b/plugins/tests/service.rs index 3bd8b4c4e..2e450095c 100644 --- a/plugins/tests/service.rs +++ b/plugins/tests/service.rs @@ -1,7 +1,6 @@ use serde::{Deserialize, Serialize}; use std::hash::Hash; use tarpc::context; -use tarpc::context::Context; #[test] fn att_service_trait() { @@ -14,12 +13,7 @@ fn att_service_trait() { impl Foo for () { type Context = context::Context; - async fn two_part( - self, - _: &mut context::Context, - s: String, - i: i32, - ) -> (String, i32) { + async fn two_part(self, _: &mut Self::Context, s: String, i: i32) -> (String, i32) { (s, i) } diff --git a/tarpc/Cargo.toml b/tarpc/Cargo.toml index 0a5efc137..778eb0938 100644 --- a/tarpc/Cargo.toml +++ b/tarpc/Cargo.toml @@ -61,8 +61,6 @@ tracing = { version = "0.1", default-features = false, features = [ tracing-opentelemetry = { version = "0.31.0", default-features = false } opentelemetry = { version = "0.30.0", default-features = false } opentelemetry-semantic-conventions = "0.30.0" -anymap3 = "1.0.1" -serde-value = "0.7" [dev-dependencies] assert_matches = "1.4" diff --git a/tarpc/examples/compression.rs b/tarpc/examples/compression.rs index aa12147e5..c96014eea 100644 --- a/tarpc/examples/compression.rs +++ b/tarpc/examples/compression.rs @@ -122,26 +122,21 @@ async fn spawn(fut: impl Future + Send + 'static) { #[tokio::main] async fn main() -> anyhow::Result<()> { let mut incoming = tcp::listen("localhost:0", Bincode::default).await?; - let addr = incoming.local_addr(); tokio::spawn(async move { let transport = incoming.next().await.unwrap().unwrap(); - let transport = add_compression(transport); - BaseChannel::with_defaults(transport) + BaseChannel::with_defaults(add_compression(transport)) .execute(HelloServer.serve()) .for_each(spawn) .await; }); let transport = tcp::connect(addr, Bincode::default).await?; - let transport = add_compression(transport); - let client = WorldClient::new(client::Config::default(), transport).spawn(); + let client = WorldClient::new(client::Config::default(), add_compression(transport)).spawn(); println!( "{}", - client - .hello(&mut context::Context::current(), "friend".into()) - .await? + client.hello(&mut context::current(), "friend".into()).await? ); Ok(()) } diff --git a/tarpc/examples/custom_transport.rs b/tarpc/examples/custom_transport.rs index 1548d8e80..7fe32bfa7 100644 --- a/tarpc/examples/custom_transport.rs +++ b/tarpc/examples/custom_transport.rs @@ -6,7 +6,8 @@ #![deny(warnings, unused, dead_code)] use futures::prelude::*; -use tarpc::{context, serde_transport as transport}; +use tarpc::{context}; +use tarpc::serde_transport as transport; use tarpc::server::{BaseChannel, Channel}; use tarpc::tokio_serde::formats::Bincode; use tarpc::tokio_util::codec::length_delimited::LengthDelimitedCodec; @@ -24,6 +25,7 @@ impl PingService for Service { type Context = context::Context; async fn ping(self, _: &mut Self::Context) {} } + #[tokio::main] async fn main() -> anyhow::Result<()> { let bind_addr = "/tmp/tarpc_on_unix_example.sock"; @@ -52,7 +54,7 @@ async fn main() -> anyhow::Result<()> { let transport = transport::new(codec_builder.new_framed(conn), Bincode::default()); PingServiceClient::new(Default::default(), transport) .spawn() - .ping(&mut context::Context::current()) + .ping(&mut context::current()) .await?; Ok(()) diff --git a/tarpc/examples/pubsub.rs b/tarpc/examples/pubsub.rs index f33e38833..70b41fdb3 100644 --- a/tarpc/examples/pubsub.rs +++ b/tarpc/examples/pubsub.rs @@ -162,6 +162,7 @@ async fn spawn(fut: impl Future + Send + 'static) { tokio::spawn(fut); } +// TODO: Remove serde bounds here impl Publisher where ClientCtx: ExtractContext @@ -172,7 +173,6 @@ where + Sync + 'static, { - // TODO: Remove serde bounds here async fn start(self) -> io::Result { let mut connecting_publishers = tcp::listen("localhost:0", Json::default).await?; @@ -234,9 +234,7 @@ where subscriber: subscriber::SubscriberClient, ) { // Populate the topics - if let Ok(topics) = subscriber - .topics(&mut ClientCtx::from(context::Context::current())) - .await + if let Ok(topics) = subscriber.topics(&mut ClientCtx::from(context::current())).await { self.clients.lock().unwrap().insert( subscriber_addr, @@ -301,16 +299,10 @@ where Some(subscriptions) => subscriptions.clone(), }; let mut publications = Vec::new(); - for client in subscribers.values_mut() { publications.push(async { - client - .receive( - &mut ClientCtx::from(context::Context::current()), - topic.clone(), - message.clone(), - ) - .await + let mut context = ClientCtx::from(context::current()); + client.receive(&mut context, topic.clone(), message.clone(), ).await }); } // Ignore failing subscribers. In a real pubsub, you'd want to continually retry until @@ -366,14 +358,12 @@ async fn main() -> anyhow::Result<()> { let _subscriber0 = Subscriber::connect( addrs.subscriptions, vec!["calculus".into(), "cool shorts".into()], - ) - .await?; + ).await?; let _subscriber1 = Subscriber::connect( addrs.subscriptions, vec!["cool shorts".into(), "history".into()], - ) - .await?; + ).await?; let publisher = publisher::PublisherClient::new( client::Config::default(), @@ -382,38 +372,18 @@ async fn main() -> anyhow::Result<()> { .spawn(); publisher - .publish( - &mut context::Context::current(), - "calculus".into(), - "sqrt(2)".into(), - ) - .await?; + .publish(&mut context::current(), "calculus".into(), "sqrt(2)".into()).await?; publisher - .publish( - &mut context::Context::current(), - "cool shorts".into(), - "hello to all".into(), - ) - .await?; + .publish(&mut context::current(), "cool shorts".into(), "hello to all".into()).await?; publisher - .publish( - &mut context::Context::current(), - "history".into(), - "napoleon".to_string(), - ) - .await?; + .publish(&mut context::current(), "history".into(), "napoleon".to_string()).await?; drop(_subscriber0); publisher - .publish( - &mut context::Context::current(), - "cool shorts".into(), - "hello to who?".into(), - ) - .await?; + .publish(&mut context::current(), "cool shorts".into(), "hello to who?".into(), ).await?; tracer_provider.shutdown()?; info!("done."); diff --git a/tarpc/examples/readme.rs b/tarpc/examples/readme.rs index f2a98cc87..f8f298921 100644 --- a/tarpc/examples/readme.rs +++ b/tarpc/examples/readme.rs @@ -6,11 +6,7 @@ #![deny(warnings, unused, dead_code)] use futures::prelude::*; -use tarpc::{ - client, context, - server::{self, Channel}, - transport, -}; +use tarpc::{client, context, server::{self, Channel}}; /// This is the service definition. It looks a lot like a trait definition. /// It defines one RPC, hello, which takes one arg, name, and returns a String. @@ -37,7 +33,7 @@ async fn spawn(fut: impl Future + Send + 'static) { #[tokio::main] async fn main() -> anyhow::Result<()> { - let (client_transport, server_transport) = transport::channel::unbounded(); + let (client_transport, server_transport) = tarpc::transport::channel::unbounded(); let server = server::BaseChannel::with_defaults(server_transport); tokio::spawn(server.execute(HelloServer.serve()).for_each(spawn)); @@ -49,9 +45,7 @@ async fn main() -> anyhow::Result<()> { // The client has an RPC method for each RPC defined in the annotated trait. It takes the same // args as defined, with the addition of a Context, which is always the first arg. The Context // specifies a deadline and trace information which can be helpful in debugging requests. - let hello = client - .hello(&mut context::Context::current(), "Stim".to_string()) - .await?; + let hello = client.hello(&mut context::current(), "Stim".to_string()).await?; println!("{hello}"); diff --git a/tarpc/examples/tls_over_tcp.rs b/tarpc/examples/tls_over_tcp.rs index b970a6dac..0ba8f2581 100644 --- a/tarpc/examples/tls_over_tcp.rs +++ b/tarpc/examples/tls_over_tcp.rs @@ -11,10 +11,6 @@ use std::io::{self, BufReader, Cursor}; use std::net::{IpAddr, Ipv4Addr}; use std::sync::Arc; -use tarpc::{context, serde_transport as transport}; -use tarpc::server::{BaseChannel, Channel}; -use tarpc::tokio_serde::formats::Bincode; -use tarpc::tokio_util::codec::length_delimited::LengthDelimitedCodec; use tokio::net::TcpListener; use tokio::net::TcpStream; use tokio_rustls::rustls::{ @@ -23,6 +19,12 @@ use tokio_rustls::rustls::{ }; use tokio_rustls::{TlsAcceptor, TlsConnector}; +use tarpc::context; +use tarpc::serde_transport as transport; +use tarpc::server::{BaseChannel, Channel}; +use tarpc::tokio_serde::formats::Bincode; +use tarpc::tokio_util::codec::length_delimited::LengthDelimitedCodec; + #[tarpc::service] pub trait PingService { async fn ping() -> String; @@ -146,7 +148,7 @@ async fn main() -> anyhow::Result<()> { let transport = transport::new(codec_builder.new_framed(stream), Bincode::default()); let answer = PingServiceClient::new(Default::default(), transport) .spawn() - .ping(&mut context::Context::current()) + .ping(&mut context::current()) .await?; println!("ping answer: {answer}"); diff --git a/tarpc/examples/tracing.rs b/tarpc/examples/tracing.rs index b67d98fe4..f36db524e 100644 --- a/tarpc/examples/tracing.rs +++ b/tarpc/examples/tracing.rs @@ -78,11 +78,7 @@ where type Context = context::Context; async fn double(self, _: &mut Self::Context, x: i32) -> Result { self.add_client - .add( - &mut ClientCtx::from(context::Context::current()), - x, - x, - ) + .add(&mut ClientCtx::from(context::current()), x, x) .await .map_err(|e| e.to_string()) } @@ -134,10 +130,7 @@ where } fn make_stub( - backends: [impl Transport>, Response> - + Send - + Sync - + 'static; N], + backends: [impl Transport>, Response> + Send + Sync + 'static; N], ) -> retry::Retry< impl Fn(&Result, u32) -> bool + Clone, load_balance::RoundRobin, Resp, ClientCtx>>, @@ -200,11 +193,7 @@ async fn main() -> anyhow::Result<()> { .filter_map(|r| future::ready(r.ok())); let addr = double_listener.get_ref().local_addr(); let double_server = double_listener.map(BaseChannel::with_defaults).take(1); - let server = DoubleServer::<_, context::Context> { - add_client, - ghost: PhantomData, - } - .serve(); + let server = DoubleServer::<_, context::Context> { add_client, ghost: PhantomData }.serve(); tokio::spawn(spawn_incoming(double_server.execute(server))); let to_double_server = tarpc::serde_transport::tcp::connect(addr, Json::default).await?; @@ -215,7 +204,7 @@ async fn main() -> anyhow::Result<()> { tracing::info!( "{:?}", double_client - .double(&mut context::Context::current(), 1) + .double(&mut context::current(), 1) .await? ); } diff --git a/tarpc/src/client.rs b/tarpc/src/client.rs index fab3b2548..90f7cac45 100644 --- a/tarpc/src/client.rs +++ b/tarpc/src/client.rs @@ -743,7 +743,7 @@ mod tests { let cx = &mut Context::from_waker(noop_waker_ref()); let (tx, mut rx) = oneshot::channel(); - let context = context::Context::current(); + let context = context::current(); dispatch .in_flight_requests @@ -758,7 +758,7 @@ mod tests { server_channel .send(Response { request_id: 0, - context: context::Context::current(), + context: context::current(), message: Ok("Resp".into()), }) .await @@ -786,7 +786,7 @@ mod tests { async fn dispatch_response_doesnt_cancel_after_complete() { let (cancellation, mut canceled_requests) = cancellations(); let (tx, mut response) = oneshot::channel(); - tx.send(Ok((context::Context::current(), "well done"))) + tx.send(Ok((context::current(), "well done"))) .unwrap(); // resp's drop() is run, but should not send a cancel message. ResponseGuard { @@ -835,7 +835,7 @@ mod tests { &mut server_channel, Response { request_id: 0, - context: context::Context::current(), + context: context::current(), message: Ok("hello".into()), }, ) @@ -921,7 +921,7 @@ mod tests { drop(dispatch); // error on send let resp = channel - .call(&mut context::Context::current(), "hi".to_string()) + .call(&mut context::current(), "hi".to_string()) .await; assert_matches!(resp, Err(RpcError::Shutdown)); } @@ -1141,7 +1141,7 @@ mod tests { let request_id = u64::try_from(channel.next_request_id.fetch_add(1, Ordering::Relaxed)).unwrap(); let request = DispatchRequest { - ctx: context::Context::current(), + ctx: context::current(), span: Span::current(), request_id, request: request.to_string(), @@ -1167,7 +1167,7 @@ mod tests { let request_id = u64::try_from(channel.next_request_id.fetch_add(1, Ordering::Relaxed)).unwrap(); let request = DispatchRequest { - ctx: context::Context::current(), + ctx: context::current(), span: Span::current(), request_id, request: request.to_string(), diff --git a/tarpc/src/client/in_flight_requests.rs b/tarpc/src/client/in_flight_requests.rs index cc5091fc6..d6424c564 100644 --- a/tarpc/src/client/in_flight_requests.rs +++ b/tarpc/src/client/in_flight_requests.rs @@ -1,14 +1,17 @@ -use crate::client::RpcError; -use crate::{context, trace, util::{Compact, TimeUntil}}; +use crate::{ + context, trace, + util::{Compact, TimeUntil} +}; use fnv::FnvHashMap; -use std::time::Instant; use std::{ collections::hash_map, task::{Context, Poll}, }; +use std::time::Instant; use tokio::sync::oneshot; use tokio_util::time::delay_queue::{self, DelayQueue}; use tracing::Span; +use crate::client::RpcError; /// Requests already written to the wire that haven't yet received responses. #[derive(Debug)] @@ -77,11 +80,7 @@ impl InFlightRequests { } /// Removes a request without aborting. Returns true if the request was found. - pub fn complete_request( - &mut self, - request_id: u64, - result: Result<(context::Context, Res), RpcError>, - ) -> Option { + pub fn complete_request(&mut self, request_id: u64, result: Result<(context::Context, Res), RpcError>) -> Option { if let Some(request_data) = self.request_data.remove(&request_id) { self.request_data.compact(0.1); self.deadlines.remove(&request_data.deadline_key); diff --git a/tarpc/src/client/stub.rs b/tarpc/src/client/stub.rs index 1a0dbfff6..2aa6908e3 100644 --- a/tarpc/src/client/stub.rs +++ b/tarpc/src/client/stub.rs @@ -1,7 +1,12 @@ //! Provides a Stub trait, implemented by types that can call remote services. -use crate::context::{ExtractContext}; -use crate::{RequestName, client::{Channel, RpcError}, server::Serve, context}; +use crate::{ + RequestName, + client::{Channel, RpcError}, + context, + context::ExtractContext, + server::Serve, +}; pub mod load_balance; pub mod retry; @@ -23,11 +28,7 @@ pub trait Stub { type ClientCtx; /// Calls a remote service. - async fn call( - &self, - ctx: &mut Self::ClientCtx, - request: Self::Req, - ) -> Result; + async fn call(&self, ctx: &mut Self::ClientCtx, request: Self::Req) -> Result; } impl Stub for Channel @@ -46,26 +47,22 @@ where impl Stub for S where - S: Serve + Clone, + S: Serve + Clone, { type Req = S::Req; type Resp = S::Resp; - type ClientCtx = context::Context; + type ClientCtx = S::ServerCtx; async fn call( &self, ctx: &mut Self::ClientCtx, req: Self::Req, ) -> Result { - let mut server_ctx = ctx.clone(); - let res = self .clone() - .serve(&mut server_ctx, req) + .serve(ctx, req) .await .map_err(RpcError::Server); - *ctx = server_ctx; - res } } diff --git a/tarpc/src/client/stub/load_balance.rs b/tarpc/src/client/stub/load_balance.rs index 4b9d9df3a..43c1c8b23 100644 --- a/tarpc/src/client/stub/load_balance.rs +++ b/tarpc/src/client/stub/load_balance.rs @@ -197,17 +197,17 @@ mod consistent_hash { for _ in 0..2 { let resp = stub - .call(&mut context::Context::current(), 'a') + .call(&mut context::current(), 'a') .await?; assert_eq!(resp, 1); let resp = stub - .call(&mut context::Context::current(), 'b') + .call(&mut context::current(), 'b') .await?; assert_eq!(resp, 2); let resp = stub - .call(&mut context::Context::current(), 'c') + .call(&mut context::current(), 'c') .await?; assert_eq!(resp, 3); } diff --git a/tarpc/src/context.rs b/tarpc/src/context.rs index 6db79e49b..423084c61 100644 --- a/tarpc/src/context.rs +++ b/tarpc/src/context.rs @@ -148,6 +148,11 @@ impl Context { } } +///TODO: Document +pub fn current() -> Context { + Context::current() +} + /// An extension trait for [`tracing::Span`] for propagating tarpc Contexts. pub(crate) trait SpanExt { /// Sets the given context on this span. Newly-created spans will be children of the given diff --git a/tarpc/src/lib.rs b/tarpc/src/lib.rs index cb5a64085..e34722b6d 100644 --- a/tarpc/src/lib.rs +++ b/tarpc/src/lib.rs @@ -188,7 +188,7 @@ //! // The client has an RPC method for each RPC defined in the annotated trait. It takes the same //! // args as defined, with the addition of a Context, which is always the first arg. The Context //! // specifies a deadline and trace information which can be helpful in debugging requests. -//! let mut context = context::Context::current(); +//! let mut context = context::current(); //! let hello = client.hello(&mut context, "Stim".to_string()).await?; //! //! println!("{hello}"); diff --git a/tarpc/src/server.rs b/tarpc/src/server.rs index d1a384e32..a7560132b 100644 --- a/tarpc/src/server.rs +++ b/tarpc/src/server.rs @@ -388,7 +388,7 @@ where /// tokio::spawn(request.execute(serve(|_, i| async move { Ok(i + 1) }.boxed()))); /// } /// }); - /// let mut context = context::Context::current(); + /// let mut context = context::current(); /// assert_eq!(client.call(&mut context, 1).await.unwrap(), 2); /// } /// ``` @@ -429,7 +429,7 @@ where /// .for_each(|response| async move { /// tokio::spawn(response); /// }.boxed())); - /// let mut context = context::Context::current(); + /// let mut context = context::current(); /// assert_eq!( /// client.call(&mut context, 1).await.unwrap(), /// 2); @@ -787,7 +787,7 @@ where /// .for_each(|response| async move { /// tokio::spawn(response); /// }.boxed())); - /// let mut context = context::Context::current(); + /// let mut context = context::current(); /// assert_eq!(client.call(&mut context, 1).await.unwrap(), 2); /// } /// ``` @@ -895,7 +895,7 @@ impl InFlightRequest { /// in_flight_request.execute(serve(|_, i| async move { Ok(i + 1) }.boxed())).await; /// } /// }); - /// let mut context = context::Context::current(); + /// let mut context = context::current(); /// assert_eq!(client.call(&mut context, 1).await.unwrap(), 2); /// } /// ``` @@ -1092,7 +1092,7 @@ mod tests { fn fake_request(req: Req) -> ClientMessage { ClientMessage::Request(Request { - context: context::Context::current(), + context: context::current(), id: 0, message: req, }) @@ -1108,7 +1108,7 @@ mod tests { async fn test_serve() { let serve = serve(|_, i| async move { Ok(i) }.boxed()); assert_matches!( - serve.serve(&mut context::Context::current(), 7).await, + serve.serve(&mut context::current(), 7).await, Ok(7) ); } @@ -1139,7 +1139,7 @@ mod tests { .boxed() }); let deadline_hook = serve.before(SetDeadline(some_time)); - let mut ctx = context::Context::current(); + let mut ctx = context::current(); ctx.deadline = some_other_time; deadline_hook.serve(&mut ctx, 7).await?; Ok(()) @@ -1174,7 +1174,7 @@ mod tests { let serve = serve(move |_: &mut context::Context, i| async move { Ok(i) }.boxed()); serve .before_and_after(PrintLatency::new()) - .serve(&mut context::Context::current(), 7) + .serve(&mut context::current(), 7) .await?; Ok(()) } @@ -1186,7 +1186,7 @@ mod tests { Err(ServerError::new(io::ErrorKind::Other, "oops".into())) }); let resp: Result = deadline_hook - .serve(&mut context::Context::current(), 7) + .serve(&mut context::current(), 7) .await; assert_matches!(resp, Err(_)); Ok(()) @@ -1200,14 +1200,14 @@ mod tests { .as_mut() .start_request(Request { id: 0, - context: context::Context::current(), + context: context::current(), message: (), }) .unwrap(); assert_matches!( channel.as_mut().start_request(Request { id: 0, - context: context::Context::current(), + context: context::current(), message: () }), Err(AlreadyExistsError) @@ -1223,7 +1223,7 @@ mod tests { .as_mut() .start_request(Request { id: 0, - context: context::Context::current(), + context: context::current(), message: (), }) .unwrap(); @@ -1231,7 +1231,7 @@ mod tests { .as_mut() .start_request(Request { id: 1, - context: context::Context::current(), + context: context::current(), message: (), }) .unwrap(); @@ -1254,7 +1254,7 @@ mod tests { .as_mut() .start_request(Request { id: 0, - context: context::Context::current(), + context: context::current(), message: (), }) .unwrap(); @@ -1283,7 +1283,7 @@ mod tests { .as_mut() .start_request(Request { id: 0, - context: context::Context::current(), + context: context::current(), message: (), }) .unwrap(); @@ -1325,7 +1325,7 @@ mod tests { .as_mut() .start_request(Request { id: 0, - context: context::Context::current(), + context: context::current(), message: (), }) .unwrap(); @@ -1348,7 +1348,7 @@ mod tests { .as_mut() .start_request(Request { id: 0, - context: context::Context::current(), + context: context::current(), message: (), }) .unwrap(); @@ -1357,7 +1357,7 @@ mod tests { .as_mut() .start_send(Response { request_id: 0, - context: context::Context::current(), + context: context::current(), message: Ok(()), }) .unwrap(); @@ -1416,7 +1416,7 @@ mod tests { .channel_pin_mut() .start_request(Request { id: 0, - context: context::Context::current(), + context: context::current(), message: (), }) .unwrap(); @@ -1425,7 +1425,7 @@ mod tests { .channel_pin_mut() .start_send(Response { request_id: 0, - context: context::Context::current(), + context: context::current(), message: Ok(()), }) .unwrap(); @@ -1437,7 +1437,7 @@ mod tests { .responses_tx .send(Response { request_id: 1, - context: context::Context::current(), + context: context::current(), message: Ok(()), }) .await @@ -1448,7 +1448,7 @@ mod tests { .channel_pin_mut() .start_request(Request { id: 1, - context: context::Context::current(), + context: context::current(), message: (), }) .unwrap(); @@ -1469,7 +1469,7 @@ mod tests { .channel_pin_mut() .start_request(Request { id: 0, - context: context::Context::current(), + context: context::current(), message: (), }) .unwrap(); @@ -1478,7 +1478,7 @@ mod tests { .channel_pin_mut() .start_send(Response { request_id: 0, - context: context::Context::current(), + context: context::current(), message: Ok(()), }) .unwrap(); @@ -1489,7 +1489,7 @@ mod tests { .channel_pin_mut() .start_request(Request { id: 1, - context: context::Context::current(), + context: context::current(), message: (), }) .unwrap(); @@ -1499,7 +1499,7 @@ mod tests { .responses_tx .send(Response { request_id: 1, - context: context::Context::current(), + context: context::current(), message: Ok(()), }) .await diff --git a/tarpc/src/server/incoming.rs b/tarpc/src/server/incoming.rs index 67d46e330..36e942f62 100644 --- a/tarpc/src/server/incoming.rs +++ b/tarpc/src/server/incoming.rs @@ -66,7 +66,7 @@ where /// BaseChannel::new(server::Config::default(), rx) /// }).execute(serve(|_, i| async move { Ok(i + 1) }.boxed())); /// tokio::spawn(spawn_incoming(incoming)); -/// let mut context = context::Context::current(); +/// let mut context = context::current(); /// assert_eq!(client.call(&mut context, 1).await.unwrap(), 2); /// } /// ``` diff --git a/tarpc/src/server/limits/requests_per_channel.rs b/tarpc/src/server/limits/requests_per_channel.rs index 32b126aa6..4c7c8dbcc 100644 --- a/tarpc/src/server/limits/requests_per_channel.rs +++ b/tarpc/src/server/limits/requests_per_channel.rs @@ -337,7 +337,7 @@ mod tests { .as_mut() .start_send(Response { request_id: 0, - context: context::Context::current(), + context: context::current(), message: Ok(1), }) .unwrap(); diff --git a/tarpc/src/server/request_hook.rs b/tarpc/src/server/request_hook.rs index 090c4a72c..cce5998ee 100644 --- a/tarpc/src/server/request_hook.rs +++ b/tarpc/src/server/request_hook.rs @@ -58,7 +58,7 @@ pub trait RequestHook: Serve { /// Ok(()) /// }) /// }); - /// let mut context = context::Context::current(); + /// let mut context = context::current(); /// let response = serve.serve(&mut context, 1); /// assert!(block_on(response).is_err()); /// ``` @@ -101,7 +101,7 @@ pub trait RequestHook: Serve { /// } /// future::ready(()) /// }); - /// let mut context = context::Context::current(); + /// let mut context = context::current(); /// let response = serve.serve(&mut context, 1); /// assert!(block_on(response).is_err()); /// ``` @@ -153,7 +153,7 @@ pub trait RequestHook: Serve { /// let serve = serve(|_ctx, i| async move { /// Ok(i + 1) /// }.boxed()).before_and_after(PrintLatency(Instant::now())); - /// let mut context = context::Context::current(); + /// let mut context = context::current(); /// let response = serve.serve(&mut context, 1); /// assert!(block_on(response).is_ok()); /// ``` diff --git a/tarpc/src/server/request_hook/before.rs b/tarpc/src/server/request_hook/before.rs index df4873e83..adfac8e79 100644 --- a/tarpc/src/server/request_hook/before.rs +++ b/tarpc/src/server/request_hook/before.rs @@ -129,7 +129,7 @@ where /// Ok(()) /// }) /// .serving(serve(|_ctx, i| async move { Ok(i + 1) }.boxed())); -/// let mut context = context::Context::current(); +/// let mut context = context::current(); /// let response = serve.clone().serve(&mut context, 1); /// assert!(block_on(response).is_ok()); /// assert!(i.get() == 2); @@ -219,7 +219,7 @@ fn before_request_list() { Ok(()) }) .serving(serve(|_ctx, i| async move { Ok(i + 1) }.boxed())); - let mut context = crate::context::Context::current(); + let mut context = crate::context::current(); let response = serve.clone().serve(&mut context, 1); assert!(block_on(response).is_ok()); assert!(i.get() == 2); diff --git a/tarpc/src/server/testing.rs b/tarpc/src/server/testing.rs index 047464cf5..76df940ce 100644 --- a/tarpc/src/server/testing.rs +++ b/tarpc/src/server/testing.rs @@ -125,10 +125,6 @@ impl FakeChannel<(), ()> { { let (request_cancellation, canceled_requests) = cancellations(); - let mut x = anymap3::AnyMap::new(); - - x.entry::<&str>(); - FakeChannel { stream: Default::default(), sink: Default::default(), diff --git a/tarpc/src/transport/channel.rs b/tarpc/src/transport/channel.rs index 1ff75e70d..4a4e216c0 100644 --- a/tarpc/src/transport/channel.rs +++ b/tarpc/src/transport/channel.rs @@ -211,10 +211,10 @@ mod tests { let client = client::new(client::Config::default(), client_channel).spawn(); let response1 = client - .call(&mut context::Context::current(), "123".into()) + .call(&mut context::current(), "123".into()) .await; let response2 = client - .call(&mut context::Context::current(), "abc".into()) + .call(&mut context::current(), "abc".into()) .await; trace!("response1: {:?}, response2: {:?}", response1, response2); diff --git a/tarpc/tests/dataservice.rs b/tarpc/tests/dataservice.rs index a2e458361..5a5b2f8e7 100644 --- a/tarpc/tests/dataservice.rs +++ b/tarpc/tests/dataservice.rs @@ -1,5 +1,4 @@ use futures::prelude::*; -use tarpc::context::Context; use tarpc::serde_transport; use tarpc::{ client, context, @@ -55,7 +54,7 @@ async fn test_call() -> anyhow::Result<()> { let client = ColorProtocolClient::new(client::Config::default(), transport).spawn(); let color = client - .get_opposite_color(&mut context::Context::current(), TestData::White) + .get_opposite_color(&mut context::current(), TestData::White) .await?; assert_eq!(color, TestData::Black); diff --git a/tarpc/tests/service_functional.rs b/tarpc/tests/service_functional.rs index abe1ba0df..e716437c7 100644 --- a/tarpc/tests/service_functional.rs +++ b/tarpc/tests/service_functional.rs @@ -4,9 +4,7 @@ use futures::{ prelude::*, }; use std::time::{Duration, Instant}; -use tarpc::context::Context; use tarpc::{ - ClientMessage, client::{self}, context, server::{BaseChannel, Channel, incoming::Incoming}, @@ -50,7 +48,7 @@ async fn sequential() { ); assert_eq!( client - .call(&mut context::Context::current(), 1) + .call(&mut context::current(), 1) .await .unwrap(), 2 @@ -85,7 +83,7 @@ async fn dropped_channel_aborts_in_flight_requests() -> anyhow::Result<()> { tokio::spawn(async move { let client = LoopClient::new(client::Config::default(), tx).spawn(); - let mut ctx = context::Context::current(); + let mut ctx = context::current(); ctx.deadline = Instant::now() + Duration::from_secs(60 * 60); let _ = client.r#loop(&mut ctx).await; }); @@ -128,12 +126,12 @@ async fn serde_tcp() -> anyhow::Result<()> { assert_matches!( client - .add(&mut context::Context::current(), 1, 2) + .add(&mut context::current(), 1, 2) .await, Ok(3) ); assert_matches!( - client.hey(&mut context::Context::current(), "Tim".to_string()).await, + client.hey(&mut context::current(), "Tim".to_string()).await, Ok(ref s) if s == "Hey, Tim." ); @@ -166,10 +164,10 @@ async fn serde_uds() -> anyhow::Result<()> { // Save results using socket so we can clean the socket even if our test assertions fail let res1 = client - .add(&mut context::Context::current(), 1, 2) + .add(&mut context::current(), 1, 2) .await; let res2 = client - .hey(&mut context::Context::current(), "Tim".to_string()) + .hey(&mut context::current(), "Tim".to_string()) .await; assert_matches!(res1, Ok(3)); @@ -194,7 +192,7 @@ async fn concurrent() -> anyhow::Result<()> { let client = ServiceClient::new(client::Config::default(), tx).spawn(); - let mut context = context::Context::current(); + let mut context = context::current(); let req1 = client.add(&mut context, 1, 2); assert_matches!(req1.await, Ok(3)); @@ -224,9 +222,9 @@ async fn concurrent_join() -> anyhow::Result<()> { let client = ServiceClient::new(client::Config::default(), tx).spawn(); - let mut context1 = context::Context::current(); - let mut context2 = context::Context::current(); - let mut context3 = context::Context::current(); + let mut context1 = context::current(); + let mut context2 = context::current(); + let mut context3 = context::current(); let req1 = client.add(&mut context1, 1, 2); let req2 = client.add(&mut context2, 3, 4); @@ -258,8 +256,8 @@ async fn concurrent_join_all() -> anyhow::Result<()> { let client = ServiceClient::new(client::Config::default(), tx).spawn(); - let mut context1 = context::Context::current(); - let mut context2 = context::Context::current(); + let mut context1 = context::current(); + let mut context2 = context::current(); let req1 = client.add(&mut context1, 1, 2); let req2 = client.add(&mut context2, 3, 4); @@ -301,11 +299,11 @@ async fn counter() -> anyhow::Result<()> { let client = CounterClient::new(client::Config::default(), tx).spawn(); assert_matches!( - client.count(&mut context::Context::current()).await, + client.count(&mut context::current()).await, Ok(1) ); assert_matches!( - client.count(&mut context::Context::current()).await, + client.count(&mut context::current()).await, Ok(2) ); From b1120173085399d4db1a9f4ea37116a02a754d84 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C3=81kos=20Vandra-Meyer?= Date: Wed, 26 Nov 2025 17:11:43 +0100 Subject: [PATCH 20/23] more --- tarpc/examples/pubsub.rs | 18 ++++++++++++------ tarpc/examples/tracing.rs | 7 +------ tarpc/src/client/stub.rs | 17 ++++------------- 3 files changed, 17 insertions(+), 25 deletions(-) diff --git a/tarpc/examples/pubsub.rs b/tarpc/examples/pubsub.rs index 70b41fdb3..6c0099a97 100644 --- a/tarpc/examples/pubsub.rs +++ b/tarpc/examples/pubsub.rs @@ -358,12 +358,14 @@ async fn main() -> anyhow::Result<()> { let _subscriber0 = Subscriber::connect( addrs.subscriptions, vec!["calculus".into(), "cool shorts".into()], - ).await?; + ) + .await?; let _subscriber1 = Subscriber::connect( addrs.subscriptions, vec!["cool shorts".into(), "history".into()], - ).await?; + ) + .await?; let publisher = publisher::PublisherClient::new( client::Config::default(), @@ -372,18 +374,22 @@ async fn main() -> anyhow::Result<()> { .spawn(); publisher - .publish(&mut context::current(), "calculus".into(), "sqrt(2)".into()).await?; + .publish(&mut context::current(), "calculus".into(), "sqrt(2)".into()) + .await?; publisher - .publish(&mut context::current(), "cool shorts".into(), "hello to all".into()).await?; + .publish(&mut context::current(), "cool shorts".into(), "hello to all".into()) + .await?; publisher - .publish(&mut context::current(), "history".into(), "napoleon".to_string()).await?; + .publish(&mut context::current(), "history".into(), "napoleon".to_string()) + .await?; drop(_subscriber0); publisher - .publish(&mut context::current(), "cool shorts".into(), "hello to who?".into(), ).await?; + .publish(&mut context::current(), "cool shorts".into(), "hello to who?".into(), ) + .await?; tracer_provider.shutdown()?; info!("done."); diff --git a/tarpc/examples/tracing.rs b/tarpc/examples/tracing.rs index f36db524e..0789d0a43 100644 --- a/tarpc/examples/tracing.rs +++ b/tarpc/examples/tracing.rs @@ -201,12 +201,7 @@ async fn main() -> anyhow::Result<()> { double::DoubleClient::new(client::Config::default(), to_double_server).spawn(); for _ in 1..=5 { - tracing::info!( - "{:?}", - double_client - .double(&mut context::current(), 1) - .await? - ); + tracing::info!("{:?}", double_client.double(&mut context::current(), 1).await?); } tracer_provider.shutdown()?; diff --git a/tarpc/src/client/stub.rs b/tarpc/src/client/stub.rs index 2aa6908e3..5e473566c 100644 --- a/tarpc/src/client/stub.rs +++ b/tarpc/src/client/stub.rs @@ -28,7 +28,8 @@ pub trait Stub { type ClientCtx; /// Calls a remote service. - async fn call(&self, ctx: &mut Self::ClientCtx, request: Self::Req) -> Result; + async fn call(&self, ctx: &mut Self::ClientCtx, request: Self::Req) + -> Result; } impl Stub for Channel @@ -52,17 +53,7 @@ where type Req = S::Req; type Resp = S::Resp; type ClientCtx = S::ServerCtx; - async fn call( - &self, - ctx: &mut Self::ClientCtx, - req: Self::Req, - ) -> Result { - let res = self - .clone() - .serve(ctx, req) - .await - .map_err(RpcError::Server); - - res + async fn call(&self, ctx: &mut Self::ClientCtx, req: Self::Req) -> Result { + self.clone().serve(ctx, req).await.map_err(RpcError::Server) } } From b988a2d39edff521df47d86a1923b96dbc3d6252 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C3=81kos=20Vandra-Meyer?= Date: Wed, 26 Nov 2025 18:42:19 +0100 Subject: [PATCH 21/23] cleanup --- plugins/src/lib.rs | 2 +- tarpc/src/client/stub/load_balance.rs | 20 +++--- tarpc/src/context.rs | 11 +-- tarpc/src/lib.rs | 68 ++----------------- tarpc/src/server/incoming.rs | 3 +- .../src/server/limits/requests_per_channel.rs | 12 +--- tarpc/src/server/request_hook/before.rs | 12 ++-- .../server/request_hook/before_and_after.rs | 3 +- tarpc/src/server/testing.rs | 15 ++-- tarpc/src/transport/channel.rs | 17 ++--- tarpc/tests/service_functional.rs | 48 ++++--------- 11 files changed, 54 insertions(+), 157 deletions(-) diff --git a/plugins/src/lib.rs b/plugins/src/lib.rs index 8e35ee49d..61a2e32a0 100644 --- a/plugins/src/lib.rs +++ b/plugins/src/lib.rs @@ -371,7 +371,7 @@ fn collect_cfg_attrs(rpcs: &[RpcMethod]) -> Vec> { /// # Example /// /// ```no_run -/// use tarpc::{client, transport, service, server::{self, Channel}, context::Context}; +/// use tarpc::{client, context, transport, service, server::{self, Channel}, context::Context}; /// /// #[service] /// pub trait Calculator { diff --git a/tarpc/src/client/stub/load_balance.rs b/tarpc/src/client/stub/load_balance.rs index 43c1c8b23..eb605ecf9 100644 --- a/tarpc/src/client/stub/load_balance.rs +++ b/tarpc/src/client/stub/load_balance.rs @@ -5,7 +5,9 @@ pub use round_robin::RoundRobin; /// Provides a stub that load-balances with a simple round-robin strategy. mod round_robin { - use crate::client::{RpcError, stub}; + use crate::{ + client::{RpcError, stub}, + }; use cycle::AtomicCycle; impl stub::Stub for RoundRobin @@ -96,7 +98,9 @@ mod round_robin { /// Each request is hashed, then mapped to a stub based on the hash. Equivalent requests will use /// the same stub. mod consistent_hash { - use crate::client::{RpcError, stub}; + use crate::{ + client::{RpcError, stub} + }; use std::{ collections::hash_map::RandomState, hash::{BuildHasher, Hash}, @@ -196,19 +200,13 @@ mod consistent_hash { )?; for _ in 0..2 { - let resp = stub - .call(&mut context::current(), 'a') - .await?; + let resp = stub.call(&mut context::current(), 'a').await?; assert_eq!(resp, 1); - let resp = stub - .call(&mut context::current(), 'b') - .await?; + let resp = stub.call(&mut context::current(), 'b').await?; assert_eq!(resp, 2); - let resp = stub - .call(&mut context::current(), 'c') - .await?; + let resp = stub.call(&mut context::current(), 'c').await?; assert_eq!(resp, 3); } diff --git a/tarpc/src/context.rs b/tarpc/src/context.rs index 423084c61..a1b50c72e 100644 --- a/tarpc/src/context.rs +++ b/tarpc/src/context.rs @@ -111,12 +111,18 @@ mod absolute_to_relative_time { } } + assert_impl_all!(Context: Send, Sync); fn ten_seconds_from_now() -> Instant { Instant::now() + Duration::from_secs(10) } +/// Returns the context for the current request, or a default Context if no request is active. +pub fn current() -> Context { + Context::current() +} + #[derive(Clone)] struct Deadline(Instant); @@ -148,11 +154,6 @@ impl Context { } } -///TODO: Document -pub fn current() -> Context { - Context::current() -} - /// An extension trait for [`tracing::Span`] for propagating tarpc Contexts. pub(crate) trait SpanExt { /// Sets the given context on this span. Newly-created spans will be children of the given diff --git a/tarpc/src/lib.rs b/tarpc/src/lib.rs index e34722b6d..06385b15c 100644 --- a/tarpc/src/lib.rs +++ b/tarpc/src/lib.rs @@ -143,9 +143,7 @@ //! # prelude::*, //! # }; //! # use tarpc::{ -//! # ClientMessage, //! # client, context, -//! # transport::channel, //! # server::{self, Channel}, //! # }; //! # // This is the service definition. It looks a lot like a trait definition. @@ -161,6 +159,7 @@ //! # struct HelloServer; //! # impl World for HelloServer { //! # type Context = context::Context; +//! # //! # // Each defined rpc generates an async fn that serves the RPC //! # async fn hello(self, _: &mut Self::Context, name: String) -> String { //! # format!("Hello, {name}!") @@ -171,8 +170,7 @@ //! # #[cfg(feature = "tokio1")] //! #[tokio::main] //! async fn main() -> anyhow::Result<()> { -//! use futures::future::Shared; -//! let (client_transport, server_transport) = channel::unbounded(); +//! let (client_transport, server_transport) = tarpc::transport::channel::unbounded(); //! let server = server::BaseChannel::with_defaults(server_transport); //! tokio::spawn( //! server.execute(HelloServer.serve()) @@ -255,8 +253,7 @@ pub(crate) mod util; pub use crate::transport::sealed::Transport; -use std::ops::Deref; -use std::{any::Any, error::Error, io, sync::Arc, time::Instant}; +use std::{any::Any, error::Error, io, sync::Arc}; /// A message from a client to a server. #[derive(Debug)] @@ -284,35 +281,8 @@ pub enum ClientMessage { }, } -impl ClientMessage { - /// Creates a new ClientMessage by mapping the context using the provided function. - pub fn map_context(self, f: F) -> ClientMessage - where - F: FnOnce(Ctx) -> Ctx2, - { - match self { - ClientMessage::Request(Request { - context, - id, - message, - }) => ClientMessage::Request(Request { - context: f(context), - id, - message, - }), - ClientMessage::Cancel { - trace_context, - request_id, - } => ClientMessage::Cancel { - trace_context, - request_id, - }, - } - } -} - /// A request from a client to a server. -#[derive(Debug)] +#[derive(Clone, Debug)] #[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))] pub struct Request { /// Trace context, deadline, and other cross-cutting concerns. @@ -333,9 +303,7 @@ impl RequestName for Arc where Req: RequestName, { - fn name(&self) -> &str { - self.as_ref().name() - } + fn name(&self) -> &str { self.as_ref().name() } } impl RequestName for Box @@ -401,21 +369,6 @@ pub struct Response { /// The response body, or an error if the request failed. pub message: Result, } - -impl Response { - /// Creates a modified Response by mapping the context using the provided function. - pub fn map_context(self, f: F) -> Response - where - F: FnOnce(Ctx) -> Ctx2, - { - Response { - request_id: self.request_id, - context: f(self.context), - message: self.message, - } - } -} - /// An error indicating the server aborted the request early, e.g., due to request throttling. #[derive(thiserror::Error, Clone, Debug, PartialEq, Eq, Hash)] #[error("{kind:?}: {detail}")] @@ -538,17 +491,6 @@ impl ServerError { Self { kind, detail } } } - -impl Request -where - Ctx: Deref, -{ - /// Returns the deadline for this request. - pub fn deadline(&self) -> &Instant { - &self.context.deadline - } -} - #[test] fn test_channel_any_casts() { use assert_matches::assert_matches; diff --git a/tarpc/src/server/incoming.rs b/tarpc/src/server/incoming.rs index 36e942f62..568ae4495 100644 --- a/tarpc/src/server/incoming.rs +++ b/tarpc/src/server/incoming.rs @@ -58,7 +58,8 @@ where /// /// #[tokio::main] /// async fn main() { -/// let (tx, rx) = transport::channel::unbounded(); +/// use tracing_subscriber::filter::FilterExt; +/// let (tx, rx) = transport::channel::unbounded(); /// let NewClient { client, dispatch } = client::new(client::Config::default(), tx); /// tokio::spawn(dispatch); /// diff --git a/tarpc/src/server/limits/requests_per_channel.rs b/tarpc/src/server/limits/requests_per_channel.rs index 4c7c8dbcc..527cb6f98 100644 --- a/tarpc/src/server/limits/requests_per_channel.rs +++ b/tarpc/src/server/limits/requests_per_channel.rs @@ -270,10 +270,8 @@ mod tests { ghost: PhantomData In>, } impl PendingSink<(), ()> { - pub fn default() -> PendingSink< - io::Result>, - Response, - > { + pub fn default() + -> PendingSink>, Response, > { PendingSink { ghost: PhantomData } } } @@ -299,11 +297,7 @@ mod tests { } } impl Channel - for PendingSink< - io::Result>, - Response, - > - { + for PendingSink>, Response> { type Req = Req; type Resp = Resp; type Transport = (); diff --git a/tarpc/src/server/request_hook/before.rs b/tarpc/src/server/request_hook/before.rs index adfac8e79..13fc18509 100644 --- a/tarpc/src/server/request_hook/before.rs +++ b/tarpc/src/server/request_hook/before.rs @@ -80,11 +80,7 @@ impl Clone for HookThenServe HookThenServe { pub(crate) fn new(serve: Serv, hook: Hook) -> Self { - Self { - serve, - hook, - ghost: PhantomData, - } + Self { serve, hook, ghost: PhantomData } } } @@ -97,7 +93,11 @@ where type Req = Serv::Req; type Resp = Serv::Resp; - async fn serve(self, ctx: &mut ServerCtx, req: Self::Req) -> Result { + async fn serve( + self, + ctx: &mut ServerCtx, + req: Self::Req + ) -> Result { let HookThenServe { serve, mut hook, .. } = self; diff --git a/tarpc/src/server/request_hook/before_and_after.rs b/tarpc/src/server/request_hook/before_and_after.rs index f3653a513..934d82ad5 100644 --- a/tarpc/src/server/request_hook/before_and_after.rs +++ b/tarpc/src/server/request_hook/before_and_after.rs @@ -39,8 +39,7 @@ impl Clone } } -impl Serve - for HookThenServeThenHook +impl Serve for HookThenServeThenHook where Req: RequestName, Serv: Serve, diff --git a/tarpc/src/server/testing.rs b/tarpc/src/server/testing.rs index 76df940ce..a92b50fc2 100644 --- a/tarpc/src/server/testing.rs +++ b/tarpc/src/server/testing.rs @@ -38,8 +38,7 @@ where } } -impl Sink> - for FakeChannel> +impl Sink> for FakeChannel> { type Error = io::Error; @@ -47,10 +46,7 @@ impl Sink> self.project().sink.poll_ready(cx).map_err(|e| match e {}) } - fn start_send( - mut self: Pin<&mut Self>, - response: Response, - ) -> Result<(), Self::Error> { + fn start_send(mut self: Pin<&mut Self>, response: Response) -> Result<(), Self::Error> { self.as_mut() .project() .in_flight_requests @@ -70,8 +66,7 @@ impl Sink> } } -impl Channel - for FakeChannel>, Response> +impl Channel for FakeChannel>, Response> where Req: Unpin, { @@ -93,8 +88,7 @@ where } } -impl - FakeChannel>, Response> +impl FakeChannel>, Response> { pub fn push_req(&mut self, id: u64, message: Req) { let (_, abort_registration) = futures::future::AbortHandle::new_pair(); @@ -124,7 +118,6 @@ impl FakeChannel<(), ()> { -> FakeChannel>, Response> { let (request_cancellation, canceled_requests) = cancellations(); - FakeChannel { stream: Default::default(), sink: Default::default(), diff --git a/tarpc/src/transport/channel.rs b/tarpc/src/transport/channel.rs index 4a4e216c0..47f3e4928 100644 --- a/tarpc/src/transport/channel.rs +++ b/tarpc/src/transport/channel.rs @@ -188,21 +188,18 @@ mod tests { let _ = tracing_subscriber::fmt::try_init(); let (client_channel, server_channel) = transport::channel::unbounded(); - tokio::spawn( stream::once(future::ready(server_channel)) .map(BaseChannel::with_defaults) - .execute(serve(|_ctx: &mut context::Context, request: String| { - async move { + .execute(serve(|_ctx: &mut context::Context, request: String| async move { request.parse::().map_err(|_| { ServerError::new( io::ErrorKind::InvalidInput, format!("{request:?} is not an int"), ) }) - } - .boxed() - })) + }.boxed() + )) .for_each(|channel| async move { tokio::spawn(channel.for_each(|response| response)); }), @@ -210,12 +207,8 @@ mod tests { let client = client::new(client::Config::default(), client_channel).spawn(); - let response1 = client - .call(&mut context::current(), "123".into()) - .await; - let response2 = client - .call(&mut context::current(), "abc".into()) - .await; + let response1 = client.call(&mut context::current(), "123".into()).await; + let response2 = client.call(&mut context::current(), "abc".into()).await; trace!("response1: {:?}, response2: {:?}", response1, response2); diff --git a/tarpc/tests/service_functional.rs b/tarpc/tests/service_functional.rs index e716437c7..9d3c70e37 100644 --- a/tarpc/tests/service_functional.rs +++ b/tarpc/tests/service_functional.rs @@ -35,24 +35,16 @@ impl Service for Server { #[tokio::test] async fn sequential() { - let (tx, rx) = channel::unbounded(); + let (tx, rx) = tarpc::transport::channel::unbounded(); let client = client::new(client::Config::default(), tx).spawn(); let channel = BaseChannel::with_defaults(rx); tokio::spawn( channel - .execute(tarpc::server::serve(|_, i: u32| { - async move { Ok(i + 1) }.boxed() - })) + .execute(tarpc::server::serve(|_, i: u32| async move { Ok(i + 1) }.boxed())) .for_each(|response| response), ); - assert_eq!( - client - .call(&mut context::current(), 1) - .await - .unwrap(), - 2 - ); + assert_eq!(client.call(&mut context::current(), 1).await.unwrap(), 2); } #[tokio::test] @@ -76,7 +68,7 @@ async fn dropped_channel_aborts_in_flight_requests() -> anyhow::Result<()> { let _ = tracing_subscriber::fmt::try_init(); - let (tx, rx) = transport::channel::unbounded(); + let (tx, rx) = tarpc::transport::channel::unbounded(); // Set up a client that initiates a long-lived request. // The request will complete in error when the server drops the connection. @@ -124,12 +116,7 @@ async fn serde_tcp() -> anyhow::Result<()> { let transport = serde_transport::tcp::connect(addr, Json::default).await?; let client = ServiceClient::new(client::Config::default(), transport).spawn(); - assert_matches!( - client - .add(&mut context::current(), 1, 2) - .await, - Ok(3) - ); + assert_matches!(client.add(&mut context::current(), 1, 2).await, Ok(3)); assert_matches!( client.hey(&mut context::current(), "Tim".to_string()).await, Ok(ref s) if s == "Hey, Tim." @@ -159,16 +146,11 @@ async fn serde_uds() -> anyhow::Result<()> { ); let transport = serde_transport::unix::connect(&sock, Json::default).await?; - let client = ServiceClient::new(client::Config::default(), transport).spawn(); // Save results using socket so we can clean the socket even if our test assertions fail - let res1 = client - .add(&mut context::current(), 1, 2) - .await; - let res2 = client - .hey(&mut context::current(), "Tim".to_string()) - .await; + let res1 = client.add(&mut context::current(), 1, 2).await; + let res2 = client.hey(&mut context::current(), "Tim".to_string()).await; assert_matches!(res1, Ok(3)); assert_matches!(res2, Ok(ref s) if s == "Hey, Tim."); @@ -180,7 +162,7 @@ async fn serde_uds() -> anyhow::Result<()> { async fn concurrent() -> anyhow::Result<()> { let _ = tracing_subscriber::fmt::try_init(); - let (tx, rx) = transport::channel::unbounded(); + let (tx, rx) = tarpc::transport::channel::unbounded(); tokio::spawn( stream::once(ready(rx)) @@ -210,7 +192,7 @@ async fn concurrent() -> anyhow::Result<()> { async fn concurrent_join() -> anyhow::Result<()> { let _ = tracing_subscriber::fmt::try_init(); - let (tx, rx) = transport::channel::unbounded(); + let (tx, rx) = tarpc::transport::channel::unbounded(); tokio::spawn( stream::once(ready(rx)) @@ -247,7 +229,7 @@ async fn spawn(fut: impl Future + Send + 'static) { async fn concurrent_join_all() -> anyhow::Result<()> { let _ = tracing_subscriber::fmt::try_init(); - let (tx, rx) = transport::channel::unbounded(); + let (tx, rx) = tarpc::transport::channel::unbounded(); tokio::spawn( BaseChannel::with_defaults(rx) .execute(Server.serve()) @@ -298,14 +280,8 @@ async fn counter() -> anyhow::Result<()> { }); let client = CounterClient::new(client::Config::default(), tx).spawn(); - assert_matches!( - client.count(&mut context::current()).await, - Ok(1) - ); - assert_matches!( - client.count(&mut context::current()).await, - Ok(2) - ); + assert_matches!(client.count(&mut context::current()).await, Ok(1)); + assert_matches!(client.count(&mut context::current()).await, Ok(2)); Ok(()) } From 21e9223e494749eb2270e50f349abea6416bc4f4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C3=81kos=20Vandra-Meyer?= Date: Wed, 26 Nov 2025 18:47:22 +0100 Subject: [PATCH 22/23] cleanup --- plugins/src/lib.rs | 27 +++++++++++++++------------ tarpc/src/context.rs | 1 - tarpc/src/lib.rs | 4 +++- tarpc/src/server/testing.rs | 3 +-- tarpc/src/transport/channel.rs | 15 +++++++-------- tarpc/tests/service_functional.rs | 3 +-- 6 files changed, 27 insertions(+), 26 deletions(-) diff --git a/plugins/src/lib.rs b/plugins/src/lib.rs index 61a2e32a0..e7c325d42 100644 --- a/plugins/src/lib.rs +++ b/plugins/src/lib.rs @@ -543,18 +543,21 @@ impl ServiceGenerator<'_> { .. } = self; - let rpc_fns = rpcs.iter().zip(return_types.iter()).map( - |( - RpcMethod { - attrs, ident, args, .. - }, - output, - )| { - quote! { - #( #attrs )* - async fn #ident(self, context: &mut Self::Context, #( #args ),*) -> #output; - } - }, + let rpc_fns = rpcs + .iter() + .zip(return_types.iter()) + .map( + |( + RpcMethod { + attrs, ident, args, .. + }, + output, + )| { + quote! { + #( #attrs )* + async fn #ident(self, context: &mut Self::Context, #( #args ),*) -> #output; + } + }, ); let stub_doc = format!("The stub trait for service [`{service_ident}`]."); diff --git a/tarpc/src/context.rs b/tarpc/src/context.rs index a1b50c72e..d4a6611e0 100644 --- a/tarpc/src/context.rs +++ b/tarpc/src/context.rs @@ -111,7 +111,6 @@ mod absolute_to_relative_time { } } - assert_impl_all!(Context: Send, Sync); fn ten_seconds_from_now() -> Instant { diff --git a/tarpc/src/lib.rs b/tarpc/src/lib.rs index 06385b15c..76d9a1815 100644 --- a/tarpc/src/lib.rs +++ b/tarpc/src/lib.rs @@ -303,7 +303,9 @@ impl RequestName for Arc where Req: RequestName, { - fn name(&self) -> &str { self.as_ref().name() } + fn name(&self) -> &str { + self.as_ref().name() + } } impl RequestName for Box diff --git a/tarpc/src/server/testing.rs b/tarpc/src/server/testing.rs index a92b50fc2..39eabdaf5 100644 --- a/tarpc/src/server/testing.rs +++ b/tarpc/src/server/testing.rs @@ -114,8 +114,7 @@ impl FakeChannel>, R } impl FakeChannel<(), ()> { - pub fn default() - -> FakeChannel>, Response> + pub fn default() -> FakeChannel>, Response> { let (request_cancellation, canceled_requests) = cancellations(); FakeChannel { diff --git a/tarpc/src/transport/channel.rs b/tarpc/src/transport/channel.rs index 47f3e4928..35c81fb1e 100644 --- a/tarpc/src/transport/channel.rs +++ b/tarpc/src/transport/channel.rs @@ -192,14 +192,13 @@ mod tests { stream::once(future::ready(server_channel)) .map(BaseChannel::with_defaults) .execute(serve(|_ctx: &mut context::Context, request: String| async move { - request.parse::().map_err(|_| { - ServerError::new( - io::ErrorKind::InvalidInput, - format!("{request:?} is not an int"), - ) - }) - }.boxed() - )) + request.parse::().map_err(|_| { + ServerError::new( + io::ErrorKind::InvalidInput, + format!("{request:?} is not an int"), + ) + }) + }.boxed())) .for_each(|channel| async move { tokio::spawn(channel.for_each(|response| response)); }), diff --git a/tarpc/tests/service_functional.rs b/tarpc/tests/service_functional.rs index 9d3c70e37..559521414 100644 --- a/tarpc/tests/service_functional.rs +++ b/tarpc/tests/service_functional.rs @@ -36,7 +36,6 @@ impl Service for Server { #[tokio::test] async fn sequential() { let (tx, rx) = tarpc::transport::channel::unbounded(); - let client = client::new(client::Config::default(), tx).spawn(); let channel = BaseChannel::with_defaults(rx); tokio::spawn( @@ -270,7 +269,7 @@ async fn counter() -> anyhow::Result<()> { let (tx, rx) = channel::unbounded(); - tokio::task::spawn(async move { + tokio::task::spawn(async { let mut requests = BaseChannel::with_defaults(rx).requests(); let mut counter = CountService(0); From 835d92c41689b72e2c9869496b932c2e0077d098 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C3=81kos=20Vandra-Meyer?= Date: Wed, 26 Nov 2025 21:14:39 +0100 Subject: [PATCH 23/23] cleanup --- tarpc/src/client.rs | 116 ++++++++++++++++++-------------------------- tarpc/src/server.rs | 115 +++++++++++-------------------------------- 2 files changed, 77 insertions(+), 154 deletions(-) diff --git a/tarpc/src/client.rs b/tarpc/src/client.rs index 90f7cac45..74031d969 100644 --- a/tarpc/src/client.rs +++ b/tarpc/src/client.rs @@ -144,10 +144,7 @@ where ); shared_context.trace_context.new_child() }); - span.record( - "rpc.trace_id", - tracing::field::display(shared_context.trace_id()), - ); + span.record("rpc.trace_id", tracing::field::display(shared_context.trace_id()), ); let (response_completion, mut response) = oneshot::channel(); let request_id = u64::try_from(self.next_request_id.fetch_add(1, Ordering::Relaxed)).unwrap(); @@ -310,7 +307,9 @@ where C: Transport, Response>, ClientCtx: ExtractContext + From, { - fn in_flight_requests<'a>(self: &'a mut Pin<&mut Self>) -> &'a mut InFlightRequests { + fn in_flight_requests<'a>( + self: &'a mut Pin<&mut Self>, + ) -> &'a mut InFlightRequests { self.as_mut().project().in_flight_requests } @@ -327,10 +326,7 @@ where .map_err(|e| ChannelError::Ready(Arc::new(e))) } - fn start_send( - self: &mut Pin<&mut Self>, - message: ClientMessage, - ) -> Result<(), C::Error> { + fn start_send(self: &mut Pin<&mut Self>, message: ClientMessage) -> Result<(), C::Error> { self.transport_pin_mut().start_send(message) } @@ -539,17 +535,11 @@ where let request = ClientMessage::Request(Request { id: request_id, message: request, - context: ctx.into(), + context: ctx.into(), //TODO: <-- This will actually initialize an empty client context, and the transport will never see the original }); self.in_flight_requests() - .insert_request( - request_id, - trace_context, - deadline, - span.clone(), - response_completion, - ) + .insert_request(request_id, trace_context, deadline, span.clone(), response_completion) .expect("Request IDs should be unique"); match self.start_send(request) { Ok(()) => tracing::debug!("SendRequest"), @@ -571,11 +561,10 @@ where self: &mut Pin<&mut Self>, cx: &mut Context<'_>, ) -> Poll>>> { - let (trace_context, span, request_id) = - match ready!(self.as_mut().poll_next_cancellation(cx)?) { - Some(triple) => triple, - None => return Poll::Ready(None), - }; + let (trace_context, span, request_id) = match ready!(self.as_mut().poll_next_cancellation(cx)?) { + Some(triple) => triple, + None => return Poll::Ready(None), + }; let _entered = span.enter(); let cancel = ClientMessage::Cancel { @@ -704,8 +693,8 @@ where /// the lifecycle of the request. #[derive(Debug)] struct DispatchRequest { + ///TODO: this should be a &mut ClientCtx pub ctx: context::Context, - ///TODO: <-- this should be a &mut ClientContext pub span: Span, pub request_id: u64, pub request: Req, @@ -717,7 +706,12 @@ mod tests { use super::{ Channel, DispatchRequest, RequestDispatch, ResponseGuard, RpcError, cancellations, }; - use crate::{ChannelError, ClientMessage, Response, client::{Config, in_flight_requests::InFlightRequests}, transport::{self, channel::UnboundedChannel}, context}; + use crate::{ + ChannelError, ClientMessage, Response, + client::{Config, in_flight_requests::InFlightRequests}, + context, + transport::{self, channel::UnboundedChannel} + }; use assert_matches::assert_matches; use futures::{prelude::*, task::*}; use std::{ @@ -747,13 +741,7 @@ mod tests { dispatch .in_flight_requests - .insert_request( - 0, - context.trace_context, - context.deadline, - Span::current(), - tx, - ) + .insert_request(0, context.trace_context, context.deadline, Span::current(), tx) .unwrap(); server_channel .send(Response { @@ -899,8 +887,7 @@ mod tests { #[tokio::test] async fn test_permit_before_transport_error() { let _ = tracing_subscriber::fmt().with_test_writer().try_init(); - let (mut dispatch, mut channel, mut cx) = - set_up_always_err::(TransportError::Flush); + let (mut dispatch, mut channel, mut cx) = set_up_always_err::(TransportError::Flush); let (tx, mut rx) = oneshot::channel(); // reserve succeeds let permit = reserve_for_send(&mut channel, tx, &mut rx).await; @@ -920,9 +907,7 @@ mod tests { let (dispatch, channel, _server_channel) = set_up::(); drop(dispatch); // error on send - let resp = channel - .call(&mut context::current(), "hi".to_string()) - .await; + let resp = channel.call(&mut context::current(), "hi".to_string()).await; assert_matches!(resp, Err(RpcError::Shutdown)); } @@ -1001,18 +986,13 @@ mod tests { fn set_up_always_err( cause: TransportError, ) -> ( - Pin< - Box< - RequestDispatch>, - >, - >, + Pin>>>, Channel, Context<'static>, ) { let (to_dispatch, pending_requests) = mpsc::channel(1); let (cancellation, canceled_requests) = cancellations(); - let transport: AlwaysErrorTransport = - AlwaysErrorTransport(cause, PhantomData); + let transport: AlwaysErrorTransport = AlwaysErrorTransport(cause, PhantomData); let dispatch = Box::pin(RequestDispatch:: { transport: transport.fuse(), pending_requests, @@ -1132,31 +1112,6 @@ mod tests { (Box::pin(dispatch), channel, server_channel) } - async fn send_request<'a, ClientCtx>( - channel: &'a mut Channel, - request: &str, - response_completion: oneshot::Sender>, - response: &'a mut oneshot::Receiver>, - ) -> ResponseGuard<'a, String> { - let request_id = - u64::try_from(channel.next_request_id.fetch_add(1, Ordering::Relaxed)).unwrap(); - let request = DispatchRequest { - ctx: context::current(), - span: Span::current(), - request_id, - request: request.to_string(), - response_completion, - }; - let response_guard = ResponseGuard { - response, - cancellation: &channel.cancellation, - request_id, - cancel: true, - }; - channel.to_dispatch.send(request).await.unwrap(); - response_guard - } - async fn reserve_for_send<'a, ClientCtx>( channel: &'a mut Channel, response_completion: oneshot::Sender>, @@ -1183,6 +1138,31 @@ mod tests { } } + async fn send_request<'a, ClientCtx>( + channel: &'a mut Channel, + request: &str, + response_completion: oneshot::Sender>, + response: &'a mut oneshot::Receiver>, + ) -> ResponseGuard<'a, String> { + let request_id = + u64::try_from(channel.next_request_id.fetch_add(1, Ordering::Relaxed)).unwrap(); + let request = DispatchRequest { + ctx: context::current(), + span: Span::current(), + request_id, + request: request.to_string(), + response_completion, + }; + let response_guard = ResponseGuard { + response, + cancellation: &channel.cancellation, + request_id, + cancel: true, + }; + channel.to_dispatch.send(request).await.unwrap(); + response_guard + } + async fn send_response( channel: &mut UnboundedChannel< ClientMessage, diff --git a/tarpc/src/server.rs b/tarpc/src/server.rs index a7560132b..7e08db475 100644 --- a/tarpc/src/server.rs +++ b/tarpc/src/server.rs @@ -9,7 +9,7 @@ use crate::{ ChannelError, ClientMessage, Request, RequestName, Response, ServerError, Transport, cancellations::{CanceledRequests, RequestCancellation, cancellations}, - context, context::SpanExt, + context::{self, SpanExt}, trace, util::TimeUntil, }; @@ -59,10 +59,7 @@ impl Default for Config { impl Config { /// Returns a channel backed by `transport` and configured with `self`. - pub fn channel( - self, - transport: T, - ) -> BaseChannel + pub fn channel(self, transport: T) -> BaseChannel where T: Transport, ClientMessage>, ServerCtx: ExtractContext, @@ -84,11 +81,7 @@ pub trait Serve { type Resp; /// Responds to a single request. - async fn serve( - self, - ctx: &mut Self::ServerCtx, - req: Self::Req, - ) -> Result; + async fn serve(self, ctx: &mut Self::ServerCtx, req: Self::Req) -> Result; } /// A Serve wrapper around a Fn. @@ -116,10 +109,8 @@ impl Copy for ServeFn where F: /// Result>`. pub fn serve(f: F) -> ServeFn where - for<'a> F: FnOnce( - &'a mut ServerCtx, - Req, - ) -> Pin> + 'a + Send>>, + // This should be -> impl Future<...>, but there is no syntax to express the 'a lifetime. + for<'a> F: FnOnce(&'a mut ServerCtx, Req) -> Pin> + 'a + Send>>, { ServeFn { f, @@ -130,6 +121,7 @@ where impl Serve for ServeFn where Req: RequestName, + // This should be -> impl Future<...>, but there is no syntax to express the 'a lifetime. for<'a> F: FnOnce( &'a mut ServerCtx, Req, @@ -314,10 +306,7 @@ pub struct TrackedRequest { /// created by [`BaseChannel`]. pub trait Channel where - Self: Transport< - Response::Resp>, - TrackedRequest::Req>, - >, + Self: Transport::Resp>, TrackedRequest::Req>>, { /// Type of request item. type Req; @@ -553,8 +542,7 @@ where } } -impl Sink> - for BaseChannel +impl Sink> for BaseChannel where T: Transport, ClientMessage>, T::Error: Error, @@ -569,10 +557,7 @@ where .map_err(|e| ChannelError::Ready(Arc::new(e))) } - fn start_send( - mut self: Pin<&mut Self>, - response: Response, - ) -> Result<(), Self::Error> { + fn start_send(mut self: Pin<&mut Self>, response: Response) -> Result<(), Self::Error> { if let Some(span) = self .in_flight_requests_mut() .remove_request(response.request_id) @@ -1018,19 +1003,7 @@ mod tests { }; fn test_channel() -> ( - Pin< - Box< - BaseChannel< - Req, - Resp, - UnboundedChannel< - ClientMessage, - Response, - >, - context::Context, - >, - >, - >, + Pin, Response>, context::Context>>>, UnboundedChannel, ClientMessage>, ) { let (tx, rx) = crate::transport::channel::unbounded(); @@ -1038,21 +1011,7 @@ mod tests { } fn test_requests() -> ( - Pin< - Box< - Requests< - BaseChannel< - Req, - Resp, - UnboundedChannel< - ClientMessage, - Response, - >, - context::Context, - >, - >, - >, - >, + Pin, Response>, context::Context>>>>, UnboundedChannel, ClientMessage>, ) { let (tx, rx) = crate::transport::channel::unbounded(); @@ -1065,21 +1024,7 @@ mod tests { fn test_bounded_requests( capacity: usize, ) -> ( - Pin< - Box< - Requests< - BaseChannel< - Req, - Resp, - channel::Channel< - ClientMessage, - Response, - >, - context::Context, - >, - >, - >, - >, + Pin, Response>, context::Context>>>>, channel::Channel, ClientMessage>, ) { let (tx, rx) = crate::transport::channel::bounded(capacity); @@ -1107,10 +1052,7 @@ mod tests { #[tokio::test] async fn test_serve() { let serve = serve(|_, i| async move { Ok(i) }.boxed()); - assert_matches!( - serve.serve(&mut context::current(), 7).await, - Ok(7) - ); + assert_matches!(serve.serve(&mut context::current(), 7).await, Ok(7)); } #[tokio::test] @@ -1120,7 +1062,11 @@ mod tests { where ServerCtx: ExtractContext, { - async fn before(&mut self, ctx: &mut ServerCtx, _: &Req) -> Result<(), ServerError> { + async fn before( + &mut self, + ctx: &mut ServerCtx, + _: &Req + ) -> Result<(), ServerError> { let mut inner = ctx.extract(); inner.deadline = self.0; ctx.update(inner); @@ -1131,13 +1077,10 @@ mod tests { let some_time = Instant::now() + Duration::from_secs(37); let some_other_time = Instant::now() + Duration::from_secs(83); - let serve = serve(move |ctx: &mut context::Context, i| { - async move { - assert_eq!(ctx.deadline, some_time); - Ok(i) - } - .boxed() - }); + let serve = serve(move |ctx: &mut context::Context, i| async move { + assert_eq!(ctx.deadline, some_time); + Ok(i) + }.boxed()); let deadline_hook = serve.before(SetDeadline(some_time)); let mut ctx = context::current(); ctx.deadline = some_other_time; @@ -1160,7 +1103,11 @@ mod tests { } } impl BeforeRequest for PrintLatency { - async fn before(&mut self, _: &mut ServerCtx, _: &Req) -> Result<(), ServerError> { + async fn before( + &mut self, + _: &mut ServerCtx, + _: &Req + ) -> Result<(), ServerError> { self.start = Instant::now(); Ok(()) } @@ -1185,9 +1132,7 @@ mod tests { let deadline_hook = serve.before(|_: &mut context::Context, _: &i32| async { Err(ServerError::new(io::ErrorKind::Other, "oops".into())) }); - let resp: Result = deadline_hook - .serve(&mut context::current(), 7) - .await; + let resp: Result = deadline_hook.serve(&mut context::current(), 7).await; assert_matches!(resp, Err(_)); Ok(()) } @@ -1393,9 +1338,7 @@ mod tests { Poll::Ready(Some(Ok(request))) => request, result => panic!("Unexpected result: {result:?}"), }; - request - .execute(serve(|_, _| async { Ok(()) }.boxed())) - .await; + request.execute(serve(|_, _| async { Ok(()) }.boxed())).await; assert!( requests .as_mut()