diff --git a/example-service/src/client.rs b/example-service/src/client.rs index 8a4ff72eb..6f3930343 100644 --- a/example-service/src/client.rs +++ b/example-service/src/client.rs @@ -3,6 +3,7 @@ // Use of this source code is governed by an MIT-style // license that can be found in the LICENSE file or at // https://opensource.org/licenses/MIT. +#![deny(warnings, unused, dead_code)] use clap::Parser; use service::{WorldClient, init_tracing}; @@ -34,10 +35,13 @@ async fn main() -> anyhow::Result<()> { let client = WorldClient::new(client::Config::default(), transport.await?).spawn(); let hello = async move { + let mut context = context::current(); + let mut context2 = context::current(); + // Send the request twice, just to be safe! ;) tokio::select! { - hello1 = client.hello(context::current(), format!("{}1", flags.name)) => { hello1 } - hello2 = client.hello(context::current(), format!("{}2", flags.name)) => { hello2 } + hello1 = client.hello(&mut context, format!("{}1", flags.name)) => { hello1 } + hello2 = client.hello(&mut context2, format!("{}2", flags.name)) => { hello2 } } } .instrument(tracing::info_span!("Two Hellos")) diff --git a/example-service/src/lib.rs b/example-service/src/lib.rs index 26b49e2ac..fd5031e75 100644 --- a/example-service/src/lib.rs +++ b/example-service/src/lib.rs @@ -4,6 +4,8 @@ // license that can be found in the LICENSE file or at // https://opensource.org/licenses/MIT. +#![deny(warnings, unused, dead_code)] + use opentelemetry::trace::TracerProvider as _; use tracing_subscriber::{fmt::format::FmtSpan, prelude::*}; diff --git a/example-service/src/server.rs b/example-service/src/server.rs index 896280c3d..a8e3324fc 100644 --- a/example-service/src/server.rs +++ b/example-service/src/server.rs @@ -3,6 +3,7 @@ // Use of this source code is governed by an MIT-style // license that can be found in the LICENSE file or at // https://opensource.org/licenses/MIT. +#![deny(warnings, unused, dead_code)] use clap::Parser; use futures::{future, prelude::*}; @@ -35,7 +36,8 @@ struct Flags { struct HelloServer(SocketAddr); impl World for HelloServer { - async fn hello(self, _: context::Context, name: String) -> String { + type Context = context::Context; + async fn hello(self, _: &mut Self::Context, name: String) -> String { let sleep_time = Duration::from_millis(Uniform::new_inclusive(1, 10).sample(&mut thread_rng())); time::sleep(sleep_time).await; diff --git a/plugins/src/lib.rs b/plugins/src/lib.rs index da6443edf..e7c325d42 100644 --- a/plugins/src/lib.rs +++ b/plugins/src/lib.rs @@ -4,13 +4,9 @@ // license that can be found in the LICENSE file or at // https://opensource.org/licenses/MIT. +#![deny(warnings, unused, dead_code)] #![recursion_limit = "512"] -extern crate proc_macro; -extern crate proc_macro2; -extern crate quote; -extern crate syn; - use proc_macro::TokenStream; use proc_macro2::TokenStream as TokenStream2; use quote::{ToTokens, format_ident, quote}; @@ -375,7 +371,7 @@ fn collect_cfg_attrs(rpcs: &[RpcMethod]) -> Vec> { /// # Example /// /// ```no_run -/// use tarpc::{client, transport, service, server::{self, Channel}, context::Context}; +/// use tarpc::{client, context, transport, service, server::{self, Channel}, context::Context}; /// /// #[service] /// pub trait Calculator { @@ -401,7 +397,8 @@ fn collect_cfg_attrs(rpcs: &[RpcMethod]) -> Vec> { /// #[derive(Clone)] /// struct CalculatorServer; /// impl Calculator for CalculatorServer { -/// async fn add(self, context: Context, a: i32, b: i32) -> i32 { +/// type Context = context::Context; +/// async fn add(self, context: &mut Self::Context, a: i32, b: i32) -> i32 { /// a + b /// } /// } @@ -547,26 +544,28 @@ impl ServiceGenerator<'_> { } = self; let rpc_fns = rpcs - .iter() - .zip(return_types.iter()) - .map( - |( - RpcMethod { - attrs, ident, args, .. - }, - output, - )| { - quote! { - #( #attrs )* - async fn #ident(self, context: ::tarpc::context::Context, #( #args ),*) -> #output; - } - }, - ); + .iter() + .zip(return_types.iter()) + .map( + |( + RpcMethod { + attrs, ident, args, .. + }, + output, + )| { + quote! { + #( #attrs )* + async fn #ident(self, context: &mut Self::Context, #( #args ),*) -> #output; + } + }, + ); let stub_doc = format!("The stub trait for service [`{service_ident}`]."); quote! { #( #attrs )* #vis trait #service_ident: ::core::marker::Sized { + type Context: ::tarpc::context::ExtractContext<::tarpc::context::Context>; + #( #rpc_fns )* /// Returns a serving function to use with @@ -577,11 +576,11 @@ impl ServiceGenerator<'_> { } #[doc = #stub_doc] - #vis trait #client_stub_ident: ::tarpc::client::stub::Stub { + #vis trait #client_stub_ident: ::tarpc::client::stub::Stub { } - impl #client_stub_ident for S - where S: ::tarpc::client::stub::Stub + impl #client_stub_ident for S + where S: ::tarpc::client::stub::Stub { } } @@ -620,9 +619,9 @@ impl ServiceGenerator<'_> { { type Req = #request_ident; type Resp = #response_ident; + type ServerCtx = S::Context; - - async fn serve(self, ctx: ::tarpc::context::Context, req: #request_ident) + async fn serve(self, ctx: &mut Self::ServerCtx, req: #request_ident) -> ::core::result::Result<#response_ident, ::tarpc::ServerError> { match req { #( @@ -711,12 +710,19 @@ impl ServiceGenerator<'_> { quote! { #[allow(unused)] - #[derive(Clone, Debug)] + #[derive(Debug)] /// The client stub that makes RPC calls to the server. All request methods return /// [Futures](::core::future::Future). #vis struct #client_ident< - Stub = ::tarpc::client::Channel<#request_ident, #response_ident> - >(Stub); + ClientCtx, + Stub = ::tarpc::client::Channel<#request_ident, #response_ident, ClientCtx> + >(Stub, ::std::marker::PhantomData); + + impl ::std::clone::Clone for #client_ident { + fn clone(&self) -> Self { + Self(self.0.clone(), ::std::marker::PhantomData) + } + } } } @@ -730,32 +736,33 @@ impl ServiceGenerator<'_> { } = self; quote! { - impl #client_ident { + impl #client_ident { /// Returns a new client stub that sends requests over the given transport. #vis fn new(config: ::tarpc::client::Config, transport: T) -> ::tarpc::client::NewClient< Self, - ::tarpc::client::RequestDispatch<#request_ident, #response_ident, T> + ::tarpc::client::RequestDispatch<#request_ident, #response_ident, ClientCtx, T> > where - T: ::tarpc::Transport<::tarpc::ClientMessage<#request_ident>, ::tarpc::Response<#response_ident>> + T: ::tarpc::Transport<::tarpc::ClientMessage, ::tarpc::Response> { let new_client = ::tarpc::client::new(config, transport); ::tarpc::client::NewClient { - client: #client_ident(new_client.client), + client: #client_ident(new_client.client, ::std::marker::PhantomData), dispatch: new_client.dispatch, } } } - impl ::core::convert::From for #client_ident + impl ::core::convert::From for #client_ident where Stub: ::tarpc::client::stub::Stub< Req = #request_ident, - Resp = #response_ident> + Resp = #response_ident, + ClientCtx = ClientCtx> { /// Returns a new client stub that sends requests over the given transport. fn from(stub: Stub) -> Self { - #client_ident(stub) + #client_ident::(stub, ::std::marker::PhantomData) } } @@ -778,15 +785,16 @@ impl ServiceGenerator<'_> { } = self; quote! { - impl #client_ident + impl #client_ident where Stub: ::tarpc::client::stub::Stub< Req = #request_ident, - Resp = #response_ident> + Resp = #response_ident, + ClientCtx = ClientCtx> { #( #[allow(unused)] #( #method_attrs )* - #vis fn #method_idents(&self, ctx: ::tarpc::context::Context, #( #args ),*) + #vis fn #method_idents<'a>(&'a self, ctx: &'a mut Stub::ClientCtx, #( #args ),*) -> impl ::core::future::Future> + '_ { let request = #request_ident::#camel_case_idents { #( #arg_pats ),* }; let resp = self.0.call(ctx, request); diff --git a/plugins/tests/service.rs b/plugins/tests/service.rs index 26ee1ec39..2e450095c 100644 --- a/plugins/tests/service.rs +++ b/plugins/tests/service.rs @@ -12,15 +12,16 @@ fn att_service_trait() { } impl Foo for () { - async fn two_part(self, _: context::Context, s: String, i: i32) -> (String, i32) { + type Context = context::Context; + async fn two_part(self, _: &mut Self::Context, s: String, i: i32) -> (String, i32) { (s, i) } - async fn bar(self, _: context::Context, s: String) -> String { + async fn bar(self, _: &mut Self::Context, s: String) -> String { s } - async fn baz(self, _: context::Context) {} + async fn baz(self, _: &mut Self::Context) {} } } @@ -37,20 +38,21 @@ fn raw_idents() { } impl r#trait for () { + type Context = context::Context; async fn r#await( self, - _: context::Context, + _: &mut Self::Context, r#struct: r#yield, r#enum: i32, ) -> (r#yield, i32) { (r#struct, r#enum) } - async fn r#fn(self, _: context::Context, r#impl: r#yield) -> r#yield { + async fn r#fn(self, _: &mut Self::Context, r#impl: r#yield) -> r#yield { r#impl } - async fn r#async(self, _: context::Context) {} + async fn r#async(self, _: &mut Self::Context) {} } } @@ -64,7 +66,8 @@ fn service_with_cfg_rpc() { } impl Foo for () { - async fn foo(self, _: context::Context) {} + type Context = context::Context; + async fn foo(self, _: &mut Self::Context) {} } } diff --git a/tarpc/examples/compression.rs b/tarpc/examples/compression.rs index d66261d19..c96014eea 100644 --- a/tarpc/examples/compression.rs +++ b/tarpc/examples/compression.rs @@ -3,6 +3,7 @@ // Use of this source code is governed by an MIT-style // license that can be found in the LICENSE file or at // https://opensource.org/licenses/MIT. +#![deny(warnings, unused, dead_code)] use flate2::{Compression, read::DeflateDecoder, write::DeflateEncoder}; use futures::{Sink, SinkExt, Stream, StreamExt, TryStreamExt, prelude::*}; @@ -108,7 +109,8 @@ pub trait World { struct HelloServer; impl World for HelloServer { - async fn hello(self, _: context::Context, name: String) -> String { + type Context = context::Context; + async fn hello(self, _: &mut Self::Context, name: String) -> String { format!("Hey, {name}!") } } @@ -134,7 +136,7 @@ async fn main() -> anyhow::Result<()> { println!( "{}", - client.hello(context::current(), "friend".into()).await? + client.hello(&mut context::current(), "friend".into()).await? ); Ok(()) } diff --git a/tarpc/examples/custom_transport.rs b/tarpc/examples/custom_transport.rs index 5f5386785..7fe32bfa7 100644 --- a/tarpc/examples/custom_transport.rs +++ b/tarpc/examples/custom_transport.rs @@ -3,9 +3,10 @@ // Use of this source code is governed by an MIT-style // license that can be found in the LICENSE file or at // https://opensource.org/licenses/MIT. +#![deny(warnings, unused, dead_code)] use futures::prelude::*; -use tarpc::context::Context; +use tarpc::{context}; use tarpc::serde_transport as transport; use tarpc::server::{BaseChannel, Channel}; use tarpc::tokio_serde::formats::Bincode; @@ -21,7 +22,8 @@ pub trait PingService { struct Service; impl PingService for Service { - async fn ping(self, _: Context) {} + type Context = context::Context; + async fn ping(self, _: &mut Self::Context) {} } #[tokio::main] @@ -52,7 +54,7 @@ async fn main() -> anyhow::Result<()> { let transport = transport::new(codec_builder.new_framed(conn), Bincode::default()); PingServiceClient::new(Default::default(), transport) .spawn() - .ping(tarpc::context::current()) + .ping(&mut context::current()) .await?; Ok(()) diff --git a/tarpc/examples/pubsub.rs b/tarpc/examples/pubsub.rs index d61f68c48..6c0099a97 100644 --- a/tarpc/examples/pubsub.rs +++ b/tarpc/examples/pubsub.rs @@ -3,6 +3,7 @@ // Use of this source code is governed by an MIT-style // license that can be found in the LICENSE file or at // https://opensource.org/licenses/MIT. +#![deny(warnings, unused, dead_code)] /// - The PubSub server sets up TCP listeners on 2 ports, the "subscriber" port and the "publisher" /// port. Because both publishers and subscribers initiate their connections to the PubSub @@ -40,6 +41,7 @@ use futures::{ }; use opentelemetry::trace::TracerProvider as _; use publisher::Publisher as _; +use serde::de::DeserializeOwned; use std::{ collections::HashMap, error::Error, @@ -47,7 +49,9 @@ use std::{ net::SocketAddr, sync::{Arc, Mutex, RwLock}, }; +use serde::Serialize; use subscriber::Subscriber as _; +use tarpc::context::{ExtractContext}; use tarpc::{ client, context, serde_transport::tcp, @@ -80,11 +84,12 @@ struct Subscriber { } impl subscriber::Subscriber for Subscriber { - async fn topics(self, _: context::Context) -> Vec { + type Context = context::Context; + async fn topics(self, _: &mut Self::Context) -> Vec { self.topics.clone() } - async fn receive(self, _: context::Context, topic: String, message: String) { + async fn receive(self, _: &mut Self::Context, topic: String, message: String) { info!(local_addr = %self.local_addr, %topic, %message, "ReceivedMessage") } } @@ -132,10 +137,20 @@ struct Subscription { topics: Vec, } -#[derive(Clone, Debug)] -struct Publisher { +#[derive(Debug)] +struct Publisher { clients: Arc>>, - subscriptions: Arc>>>, + subscriptions: + Arc>>>>, +} + +impl Clone for Publisher { + fn clone(&self) -> Self { + Publisher { + clients: self.clients.clone(), + subscriptions: self.subscriptions.clone(), + } + } } struct PublisherAddrs { @@ -147,7 +162,17 @@ async fn spawn(fut: impl Future + Send + 'static) { tokio::spawn(fut); } -impl Publisher { +// TODO: Remove serde bounds here +impl Publisher +where + ClientCtx: ExtractContext + + From + + Serialize + + DeserializeOwned + + Send + + Sync + + 'static, +{ async fn start(self) -> io::Result { let mut connecting_publishers = tcp::listen("localhost:0", Json::default).await?; @@ -183,7 +208,6 @@ impl Publisher { tokio::spawn(async move { while let Some(conn) = connecting_subscribers.next().await { let subscriber_addr = conn.peer_addr().unwrap(); - let tarpc::client::NewClient { client: subscriber, dispatch, @@ -207,10 +231,11 @@ impl Publisher { async fn initialize_subscription( &mut self, subscriber_addr: SocketAddr, - subscriber: subscriber::SubscriberClient, + subscriber: subscriber::SubscriberClient, ) { // Populate the topics - if let Ok(topics) = subscriber.topics(context::current()).await { + if let Ok(topics) = subscriber.topics(&mut ClientCtx::from(context::current())).await + { self.clients.lock().unwrap().insert( subscriber_addr, Subscription { @@ -262,8 +287,12 @@ impl Publisher { } } -impl publisher::Publisher for Publisher { - async fn publish(self, _: context::Context, topic: String, message: String) { +impl publisher::Publisher for Publisher +where + ClientCtx: ExtractContext + From + Send + Sync + 'static, +{ + type Context = ClientCtx; + async fn publish(self, _: &mut Self::Context, topic: String, message: String) { info!("received message to publish."); let mut subscribers = match self.subscriptions.read().unwrap().get(&topic) { None => return, @@ -271,7 +300,10 @@ impl publisher::Publisher for Publisher { }; let mut publications = Vec::new(); for client in subscribers.values_mut() { - publications.push(client.receive(context::current(), topic.clone(), message.clone())); + publications.push(async { + let mut context = ClientCtx::from(context::current()); + client.receive(&mut context, topic.clone(), message.clone(), ).await + }); } // Ignore failing subscribers. In a real pubsub, you'd want to continually retry until // subscribers ack. Of course, a lot would be different in a real pubsub :) @@ -316,7 +348,7 @@ pub fn init_tracing( async fn main() -> anyhow::Result<()> { let tracer_provider = init_tracing("Pub/Sub")?; - let addrs = Publisher { + let addrs = Publisher:: { clients: Arc::new(Mutex::new(HashMap::new())), subscriptions: Arc::new(RwLock::new(HashMap::new())), } @@ -342,29 +374,21 @@ async fn main() -> anyhow::Result<()> { .spawn(); publisher - .publish(context::current(), "calculus".into(), "sqrt(2)".into()) + .publish(&mut context::current(), "calculus".into(), "sqrt(2)".into()) .await?; publisher - .publish( - context::current(), - "cool shorts".into(), - "hello to all".into(), - ) + .publish(&mut context::current(), "cool shorts".into(), "hello to all".into()) .await?; publisher - .publish(context::current(), "history".into(), "napoleon".to_string()) + .publish(&mut context::current(), "history".into(), "napoleon".to_string()) .await?; drop(_subscriber0); publisher - .publish( - context::current(), - "cool shorts".into(), - "hello to who?".into(), - ) + .publish(&mut context::current(), "cool shorts".into(), "hello to who?".into(), ) .await?; tracer_provider.shutdown()?; diff --git a/tarpc/examples/readme.rs b/tarpc/examples/readme.rs index c328bd884..f8f298921 100644 --- a/tarpc/examples/readme.rs +++ b/tarpc/examples/readme.rs @@ -3,12 +3,10 @@ // Use of this source code is governed by an MIT-style // license that can be found in the LICENSE file or at // https://opensource.org/licenses/MIT. +#![deny(warnings, unused, dead_code)] use futures::prelude::*; -use tarpc::{ - client, context, - server::{self, Channel}, -}; +use tarpc::{client, context, server::{self, Channel}}; /// This is the service definition. It looks a lot like a trait definition. /// It defines one RPC, hello, which takes one arg, name, and returns a String. @@ -23,7 +21,8 @@ pub trait World { struct HelloServer; impl World for HelloServer { - async fn hello(self, _: context::Context, name: String) -> String { + type Context = context::Context; + async fn hello(self, _: &mut Self::Context, name: String) -> String { format!("Hello, {name}!") } } @@ -46,7 +45,7 @@ async fn main() -> anyhow::Result<()> { // The client has an RPC method for each RPC defined in the annotated trait. It takes the same // args as defined, with the addition of a Context, which is always the first arg. The Context // specifies a deadline and trace information which can be helpful in debugging requests. - let hello = client.hello(context::current(), "Stim".to_string()).await?; + let hello = client.hello(&mut context::current(), "Stim".to_string()).await?; println!("{hello}"); diff --git a/tarpc/examples/tls_over_tcp.rs b/tarpc/examples/tls_over_tcp.rs index 968f76c17..0ba8f2581 100644 --- a/tarpc/examples/tls_over_tcp.rs +++ b/tarpc/examples/tls_over_tcp.rs @@ -3,6 +3,7 @@ // Use of this source code is governed by an MIT-style // license that can be found in the LICENSE file or at // https://opensource.org/licenses/MIT. +#![deny(warnings, unused, dead_code)] use futures::prelude::*; use rustls_pemfile::certs; @@ -18,7 +19,7 @@ use tokio_rustls::rustls::{ }; use tokio_rustls::{TlsAcceptor, TlsConnector}; -use tarpc::context::Context; +use tarpc::context; use tarpc::serde_transport as transport; use tarpc::server::{BaseChannel, Channel}; use tarpc::tokio_serde::formats::Bincode; @@ -33,7 +34,8 @@ pub trait PingService { struct Service; impl PingService for Service { - async fn ping(self, _: Context) -> String { + type Context = context::Context; + async fn ping(self, _: &mut Self::Context) -> String { "🔒".to_owned() } } @@ -146,7 +148,7 @@ async fn main() -> anyhow::Result<()> { let transport = transport::new(codec_builder.new_framed(stream), Bincode::default()); let answer = PingServiceClient::new(Default::default(), transport) .spawn() - .ping(tarpc::context::current()) + .ping(&mut context::current()) .await?; println!("ping answer: {answer}"); diff --git a/tarpc/examples/tracing.rs b/tarpc/examples/tracing.rs index 79a7026c0..0789d0a43 100644 --- a/tarpc/examples/tracing.rs +++ b/tarpc/examples/tracing.rs @@ -3,7 +3,7 @@ // Use of this source code is governed by an MIT-style // license that can be found in the LICENSE file or at // https://opensource.org/licenses/MIT. - +#![deny(warnings, unused, dead_code)] #![allow(clippy::type_complexity)] use crate::{ @@ -12,6 +12,7 @@ use crate::{ }; use futures::{future, prelude::*}; use opentelemetry::trace::TracerProvider as _; +use std::marker::PhantomData; use std::{ io, sync::{ @@ -19,6 +20,7 @@ use std::{ atomic::{AtomicBool, Ordering}, }, }; +use tarpc::context::{ExtractContext}; use tarpc::{ ClientMessage, RequestName, Response, ServerError, Transport, client::{ @@ -56,23 +58,27 @@ pub mod double { struct AddServer; impl AddService for AddServer { - async fn add(self, _: context::Context, x: i32, y: i32) -> i32 { + type Context = context::Context; + async fn add(self, _: &mut Self::Context, x: i32, y: i32) -> i32 { x + y } } #[derive(Clone)] -struct DoubleServer { - add_client: add::AddClient, +struct DoubleServer { + add_client: add::AddClient, + ghost: PhantomData, } -impl DoubleService for DoubleServer +impl DoubleService for DoubleServer where - Stub: AddStub + Clone + Send + Sync + 'static, + Stub: AddStub + Clone + Send + Sync + 'static, + ClientCtx: From + Send + Sync + 'static, { - async fn double(self, _: context::Context, x: i32) -> Result { + type Context = context::Context; + async fn double(self, _: &mut Self::Context, x: i32) -> Result { self.add_client - .add(context::current(), x, x) + .add(&mut ClientCtx::from(context::current()), x, x) .await .map_err(|e| e.to_string()) } @@ -123,15 +129,16 @@ where Ok((listener, addr)) } -fn make_stub( - backends: [impl Transport>, Response> + Send + Sync + 'static; N], +fn make_stub( + backends: [impl Transport>, Response> + Send + Sync + 'static; N], ) -> retry::Retry< impl Fn(&Result, u32) -> bool + Clone, - load_balance::RoundRobin, Resp>>, + load_balance::RoundRobin, Resp, ClientCtx>>, > where Req: RequestName + Send + Sync + 'static, Resp: Send + Sync + 'static, + ClientCtx: ExtractContext + From + Send + Sync + 'static, { let stub = load_balance::RoundRobin::new( backends @@ -186,16 +193,15 @@ async fn main() -> anyhow::Result<()> { .filter_map(|r| future::ready(r.ok())); let addr = double_listener.get_ref().local_addr(); let double_server = double_listener.map(BaseChannel::with_defaults).take(1); - let server = DoubleServer { add_client }.serve(); + let server = DoubleServer::<_, context::Context> { add_client, ghost: PhantomData }.serve(); tokio::spawn(spawn_incoming(double_server.execute(server))); let to_double_server = tarpc::serde_transport::tcp::connect(addr, Json::default).await?; let double_client = double::DoubleClient::new(client::Config::default(), to_double_server).spawn(); - let ctx = context::current(); for _ in 1..=5 { - tracing::info!("{:?}", double_client.double(ctx, 1).await?); + tracing::info!("{:?}", double_client.double(&mut context::current(), 1).await?); } tracer_provider.shutdown()?; diff --git a/tarpc/src/client.rs b/tarpc/src/client.rs index 3cf9ff07a..74031d969 100644 --- a/tarpc/src/client.rs +++ b/tarpc/src/client.rs @@ -18,6 +18,7 @@ use crate::{ use futures::{prelude::*, ready, stream::Fuse, task::*}; use in_flight_requests::InFlightRequests; use pin_project::pin_project; +use std::marker::PhantomData; use std::{ any::Any, convert::TryFrom, @@ -31,6 +32,7 @@ use std::{ }; use tokio::sync::{mpsc, oneshot}; use tracing::Span; +use crate::context::ExtractContext; /// Settings that control the behavior of the client. #[derive(Clone, Debug)] @@ -95,27 +97,32 @@ const _CHECK_USIZE: () = assert!( /// Handles communication from the client to request dispatch. #[derive(Debug)] -pub struct Channel { +pub struct Channel { to_dispatch: mpsc::Sender>, /// Channel to send a cancel message to the dispatcher. cancellation: RequestCancellation, /// The ID to use for the next request to stage. next_request_id: Arc, + + ///TODO: Document + ghost: PhantomData, } -impl Clone for Channel { +impl Clone for Channel { fn clone(&self) -> Self { Self { to_dispatch: self.to_dispatch.clone(), cancellation: self.cancellation.clone(), next_request_id: self.next_request_id.clone(), + ghost: PhantomData, } } } -impl Channel +impl Channel where Req: RequestName, + ClientCtx: ExtractContext, { /// Sends a request to the dispatch task to forward to the server, returning a [`Future`] that /// resolves to the response. @@ -124,19 +131,20 @@ where skip(self, ctx, request), fields( rpc.trace_id = tracing::field::Empty, - rpc.deadline = %humantime::format_rfc3339(SystemTime::now() + ctx.deadline.time_until()), + rpc.deadline = %humantime::format_rfc3339(SystemTime::now() + ctx.extract().deadline.time_until()), otel.kind = "client", otel.name = %request.name()) )] - pub async fn call(&self, mut ctx: context::Context, request: Req) -> Result { + pub async fn call(&self, ctx: &mut ClientCtx, request: Req) -> Result { let span = Span::current(); - ctx.trace_context = trace::Context::try_from(&span).unwrap_or_else(|_| { + let mut shared_context = ctx.extract(); + shared_context.trace_context = trace::Context::try_from(&span).unwrap_or_else(|_| { tracing::trace!( "OpenTelemetry subscriber not installed; making unsampled child context." ); - ctx.trace_context.new_child() + shared_context.trace_context.new_child() }); - span.record("rpc.trace_id", tracing::field::display(ctx.trace_id())); + span.record("rpc.trace_id", tracing::field::display(shared_context.trace_id()), ); let (response_completion, mut response) = oneshot::channel(); let request_id = u64::try_from(self.next_request_id.fetch_add(1, Ordering::Relaxed)).unwrap(); @@ -153,7 +161,7 @@ where }; self.to_dispatch .send(DispatchRequest { - ctx, + ctx: shared_context, span, request_id, request, @@ -161,14 +169,19 @@ where }) .await .map_err(|mpsc::error::SendError(_)| RpcError::Shutdown)?; - response_guard.response().await + + let (response_ctx, r) = response_guard.response().await?; + + ctx.update(response_ctx); + + Ok(r) } } /// A server response that is completed by request dispatch when the corresponding response /// arrives off the wire. struct ResponseGuard<'a, Resp> { - response: &'a mut oneshot::Receiver>, + response: &'a mut oneshot::Receiver>, cancellation: &'a RequestCancellation, request_id: u64, cancel: bool, @@ -196,7 +209,7 @@ pub enum RpcError { } impl ResponseGuard<'_, Resp> { - async fn response(mut self) -> Result { + async fn response(mut self) -> Result<(context::Context, Resp), RpcError> { let response = (&mut self.response).await; // Cancel drop logic once a response has been received. self.cancel = false; @@ -234,12 +247,12 @@ impl Drop for ResponseGuard<'_, Resp> { /// Returns a channel and dispatcher that manages the lifecycle of requests initiated by the /// channel. -pub fn new( +pub fn new( config: Config, transport: C, -) -> NewClient, RequestDispatch> +) -> NewClient, RequestDispatch> where - C: Transport, Response>, + C: Transport, Response>, { let (to_dispatch, pending_requests) = mpsc::channel(config.pending_request_buffer); let (cancellation, canceled_requests) = cancellations(); @@ -249,6 +262,7 @@ where to_dispatch, cancellation, next_request_id: Arc::new(AtomicUsize::new(0)), + ghost: PhantomData, }, dispatch: RequestDispatch { config, @@ -257,6 +271,7 @@ where in_flight_requests: InFlightRequests::default(), pending_requests, terminal_error: None, + ghost: PhantomData, }, } } @@ -266,7 +281,7 @@ where #[must_use] #[pin_project()] #[derive(Debug)] -pub struct RequestDispatch { +pub struct RequestDispatch { /// Writes requests to the wire and reads responses off the wire. #[pin] transport: Fuse, @@ -275,7 +290,7 @@ pub struct RequestDispatch { /// Requests that were dropped. canceled_requests: CanceledRequests, /// Requests already written to the wire that haven't yet received responses. - in_flight_requests: InFlightRequests>, + in_flight_requests: InFlightRequests, /// Configures limits to prevent unlimited resource usage. config: Config, /// Produces errors that can be sent in response to any unprocessed requests at the time @@ -283,15 +298,18 @@ pub struct RequestDispatch { /// RequestDispatch::poll, which relies on downcasting the Any to a concrete error type /// determined within the poll function. terminal_error: Option>, + + ghost: PhantomData, } -impl RequestDispatch +impl RequestDispatch where - C: Transport, Response>, + C: Transport, Response>, + ClientCtx: ExtractContext + From, { fn in_flight_requests<'a>( self: &'a mut Pin<&mut Self>, - ) -> &'a mut InFlightRequests> { + ) -> &'a mut InFlightRequests { self.as_mut().project().in_flight_requests } @@ -308,7 +326,7 @@ where .map_err(|e| ChannelError::Ready(Arc::new(e))) } - fn start_send(self: &mut Pin<&mut Self>, message: ClientMessage) -> Result<(), C::Error> { + fn start_send(self: &mut Pin<&mut Self>, message: ClientMessage) -> Result<(), C::Error> { self.transport_pin_mut().start_send(message) } @@ -457,7 +475,7 @@ where fn poll_next_cancellation( mut self: Pin<&mut Self>, cx: &mut Context<'_>, - ) -> Poll>>> { + ) -> Poll>>> { ready!(self.ensure_writeable(cx)?); loop { @@ -510,16 +528,18 @@ where // poll_next_request only returns Ready if there is room to buffer another request. // Therefore, we can call write_request without fear of erroring due to a full // buffer. + + let trace_context = ctx.trace_context; + let deadline = ctx.deadline; + let request = ClientMessage::Request(Request { id: request_id, message: request, - context: context::Context { - deadline: ctx.deadline, - trace_context: ctx.trace_context, - }, + context: ctx.into(), //TODO: <-- This will actually initialize an empty client context, and the transport will never see the original }); + self.in_flight_requests() - .insert_request(request_id, ctx, span.clone(), response_completion) + .insert_request(request_id, trace_context, deadline, span.clone(), response_completion) .expect("Request IDs should be unique"); match self.start_send(request) { Ok(()) => tracing::debug!("SendRequest"), @@ -541,14 +561,14 @@ where self: &mut Pin<&mut Self>, cx: &mut Context<'_>, ) -> Poll>>> { - let (context, span, request_id) = match ready!(self.as_mut().poll_next_cancellation(cx)?) { + let (trace_context, span, request_id) = match ready!(self.as_mut().poll_next_cancellation(cx)?) { Some(triple) => triple, None => return Poll::Ready(None), }; let _entered = span.enter(); let cancel = ClientMessage::Cancel { - trace_context: context.trace_context, + trace_context, request_id, }; self.start_send(cancel) @@ -558,10 +578,13 @@ where } /// Sends a server response to the client task that initiated the associated request. - fn complete(mut self: Pin<&mut Self>, response: Response) -> bool { + fn complete(mut self: Pin<&mut Self>, response: Response) -> bool { if let Some(span) = self.in_flight_requests().complete_request( response.request_id, - response.message.map_err(RpcError::Server), + response + .message + .map_err(RpcError::Server) + .map(|m| (response.context.extract(), m)), ) { let _entered = span.enter(); tracing::debug!("ReceiveResponse"); @@ -636,9 +659,10 @@ where } } -impl Future for RequestDispatch +impl Future for RequestDispatch where - C: Transport, Response>, + C: Transport, Response>, + ClientCtx: ExtractContext + From, { type Output = Result<(), ChannelError>; @@ -669,11 +693,12 @@ where /// the lifecycle of the request. #[derive(Debug)] struct DispatchRequest { + ///TODO: this should be a &mut ClientCtx pub ctx: context::Context, pub span: Span, pub request_id: u64, pub request: Req, - pub response_completion: oneshot::Sender>, + pub response_completion: oneshot::Sender>, } #[cfg(test)] @@ -684,8 +709,8 @@ mod tests { use crate::{ ChannelError, ClientMessage, Response, client::{Config, in_flight_requests::InFlightRequests}, - context::{self, current}, - transport::{self, channel::UnboundedChannel}, + context, + transport::{self, channel::UnboundedChannel} }; use assert_matches::assert_matches; use futures::{prelude::*, task::*}; @@ -708,23 +733,26 @@ mod tests { #[tokio::test] async fn response_completes_request_future() { - let (mut dispatch, mut _channel, mut server_channel) = set_up(); + let (mut dispatch, _channel, mut server_channel) = set_up(); let cx = &mut Context::from_waker(noop_waker_ref()); let (tx, mut rx) = oneshot::channel(); + let context = context::current(); + dispatch .in_flight_requests - .insert_request(0, context::current(), Span::current(), tx) + .insert_request(0, context.trace_context, context.deadline, Span::current(), tx) .unwrap(); server_channel .send(Response { request_id: 0, + context: context::current(), message: Ok("Resp".into()), }) .await .unwrap(); assert_matches!(dispatch.as_mut().poll(cx), Poll::Pending); - assert_matches!(rx.try_recv(), Ok(Ok(resp)) if resp == "Resp"); + assert_matches!(rx.try_recv(), Ok(Ok((_, resp))) if resp == "Resp"); } #[tokio::test] @@ -746,11 +774,8 @@ mod tests { async fn dispatch_response_doesnt_cancel_after_complete() { let (cancellation, mut canceled_requests) = cancellations(); let (tx, mut response) = oneshot::channel(); - tx.send(Ok(Response { - request_id: 0, - message: Ok("well done"), - })) - .unwrap(); + tx.send(Ok((context::current(), "well done"))) + .unwrap(); // resp's drop() is run, but should not send a cancel message. ResponseGuard { response: &mut response, @@ -768,7 +793,7 @@ mod tests { #[tokio::test] async fn stage_request() { - let (mut dispatch, mut channel, _server_channel) = set_up(); + let (mut dispatch, mut channel, _server_channel) = set_up::(); let cx = &mut Context::from_waker(noop_waker_ref()); let (tx, mut rx) = oneshot::channel(); @@ -798,6 +823,7 @@ mod tests { &mut server_channel, Response { request_id: 0, + context: context::current(), message: Ok("hello".into()), }, ) @@ -808,7 +834,7 @@ mod tests { #[allow(unstable_name_collisions)] #[tokio::test] async fn stage_request_response_future_dropped_is_canceled_before_sending() { - let (mut dispatch, mut channel, _server_channel) = set_up(); + let (mut dispatch, mut channel, _server_channel) = set_up::(); let cx = &mut Context::from_waker(noop_waker_ref()); let (tx, mut rx) = oneshot::channel(); @@ -824,7 +850,7 @@ mod tests { #[allow(unstable_name_collisions)] #[tokio::test] async fn stage_request_response_future_dropped_is_canceled_after_sending() { - let (mut dispatch, mut channel, _server_channel) = set_up(); + let (mut dispatch, mut channel, _server_channel) = set_up::(); let cx = &mut Context::from_waker(noop_waker_ref()); let (tx, mut rx) = oneshot::channel(); @@ -845,7 +871,7 @@ mod tests { #[tokio::test] async fn stage_request_response_closed_skipped() { - let (mut dispatch, mut channel, _server_channel) = set_up(); + let (mut dispatch, mut channel, _server_channel) = set_up::(); let cx = &mut Context::from_waker(noop_waker_ref()); let (tx, mut rx) = oneshot::channel(); @@ -861,7 +887,7 @@ mod tests { #[tokio::test] async fn test_permit_before_transport_error() { let _ = tracing_subscriber::fmt().with_test_writer().try_init(); - let (mut dispatch, mut channel, mut cx) = set_up_always_err(TransportError::Flush); + let (mut dispatch, mut channel, mut cx) = set_up_always_err::(TransportError::Flush); let (tx, mut rx) = oneshot::channel(); // reserve succeeds let permit = reserve_for_send(&mut channel, tx, &mut rx).await; @@ -878,17 +904,17 @@ mod tests { #[tokio::test] async fn test_shutdown() { let _ = tracing_subscriber::fmt().with_test_writer().try_init(); - let (dispatch, channel, _server_channel) = set_up(); + let (dispatch, channel, _server_channel) = set_up::(); drop(dispatch); // error on send - let resp = channel.call(current(), "hi".to_string()).await; + let resp = channel.call(&mut context::current(), "hi".to_string()).await; assert_matches!(resp, Err(RpcError::Shutdown)); } #[tokio::test] async fn test_transport_error_write() { let cause = TransportError::Write; - let (mut dispatch, mut channel, mut cx) = set_up_always_err(cause); + let (mut dispatch, mut channel, mut cx) = set_up_always_err::(cause); let (tx, mut rx) = oneshot::channel(); let resp = send_request(&mut channel, "hi", tx, &mut rx).await; @@ -911,7 +937,7 @@ mod tests { #[tokio::test] async fn test_transport_error_read() { let cause = TransportError::Read; - let (mut dispatch, mut channel, mut cx) = set_up_always_err(cause); + let (mut dispatch, mut channel, mut cx) = set_up_always_err::(cause); let (tx, mut rx) = oneshot::channel(); let resp = send_request(&mut channel, "hi", tx, &mut rx).await; assert_eq!( @@ -928,7 +954,7 @@ mod tests { #[tokio::test] async fn test_transport_error_ready() { let cause = TransportError::Ready; - let (mut dispatch, _, mut cx) = set_up_always_err(cause); + let (mut dispatch, _, mut cx) = set_up_always_err::(cause); assert_eq!( dispatch.as_mut().poll(&mut cx), Poll::Ready(Err(ChannelError::Ready(Arc::new(cause)))) @@ -938,7 +964,7 @@ mod tests { #[tokio::test] async fn test_transport_error_flush() { let cause = TransportError::Flush; - let (mut dispatch, _, mut cx) = set_up_always_err(cause); + let (mut dispatch, _, mut cx) = set_up_always_err::(cause); assert_eq!( dispatch.as_mut().poll(&mut cx), Poll::Ready(Err(ChannelError::Flush(Arc::new(cause)))) @@ -948,7 +974,7 @@ mod tests { #[tokio::test] async fn test_transport_error_close() { let cause = TransportError::Close; - let (mut dispatch, channel, mut cx) = set_up_always_err(cause); + let (mut dispatch, channel, mut cx) = set_up_always_err::(cause); drop(channel); assert_eq!( dispatch.as_mut().poll(&mut cx), @@ -957,34 +983,36 @@ mod tests { } /// Sets up a RequestDispatch with a transport that always errors. - fn set_up_always_err( + fn set_up_always_err( cause: TransportError, ) -> ( - Pin>>>, - Channel, + Pin>>>, + Channel, Context<'static>, ) { let (to_dispatch, pending_requests) = mpsc::channel(1); let (cancellation, canceled_requests) = cancellations(); - let transport: AlwaysErrorTransport = AlwaysErrorTransport(cause, PhantomData); - let dispatch = Box::pin(RequestDispatch:: { + let transport: AlwaysErrorTransport = AlwaysErrorTransport(cause, PhantomData); + let dispatch = Box::pin(RequestDispatch:: { transport: transport.fuse(), pending_requests, canceled_requests, in_flight_requests: InFlightRequests::default(), config: Config::default(), terminal_error: None, + ghost: PhantomData, }); let channel = Channel { to_dispatch, cancellation, next_request_id: Arc::new(AtomicUsize::new(0)), + ghost: PhantomData, }; let cx = Context::from_waker(noop_waker_ref()); (dispatch, channel, cx) } - struct AlwaysErrorTransport(TransportError, PhantomData); + struct AlwaysErrorTransport(TransportError, PhantomData<(I, ClientCtx)>); #[derive(Debug, Error, PartialEq, Eq, Clone, Copy)] enum TransportError { @@ -1001,7 +1029,7 @@ mod tests { } } - impl Sink for AlwaysErrorTransport { + impl Sink for AlwaysErrorTransport { type Error = TransportError; fn poll_ready(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll> { match self.0 { @@ -1033,8 +1061,8 @@ mod tests { } } - impl Stream for AlwaysErrorTransport { - type Item = Result, TransportError>; + impl Stream for AlwaysErrorTransport { + type Item = Result, TransportError>; fn poll_next(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll> { if matches!(self.0, TransportError::Read) { Poll::Ready(Some(Err(self.0))) @@ -1044,18 +1072,19 @@ mod tests { } } - fn set_up() -> ( + fn set_up() -> ( Pin< Box< RequestDispatch< String, String, - UnboundedChannel, ClientMessage>, + ClientCtx, + UnboundedChannel, ClientMessage>, >, >, >, - Channel, - UnboundedChannel, Response>, + Channel, + UnboundedChannel, Response>, ) { let _ = tracing_subscriber::fmt().with_test_writer().try_init(); @@ -1063,28 +1092,30 @@ mod tests { let (cancellation, canceled_requests) = cancellations(); let (client_channel, server_channel) = transport::channel::unbounded(); - let dispatch = RequestDispatch:: { + let dispatch = RequestDispatch:: { transport: client_channel.fuse(), pending_requests, canceled_requests, in_flight_requests: InFlightRequests::default(), config: Config::default(), terminal_error: None, + ghost: PhantomData, }; let channel = Channel { to_dispatch, cancellation, next_request_id: Arc::new(AtomicUsize::new(0)), + ghost: PhantomData, }; (Box::pin(dispatch), channel, server_channel) } - async fn reserve_for_send<'a>( - channel: &'a mut Channel, - response_completion: oneshot::Sender>, - response: &'a mut oneshot::Receiver>, + async fn reserve_for_send<'a, ClientCtx>( + channel: &'a mut Channel, + response_completion: oneshot::Sender>, + response: &'a mut oneshot::Receiver>, ) -> impl FnOnce(&str) -> ResponseGuard<'a, String> { let permit = channel.to_dispatch.reserve().await.unwrap(); |request| { @@ -1107,14 +1138,14 @@ mod tests { } } - async fn send_request<'a>( - channel: &'a mut Channel, + async fn send_request<'a, ClientCtx>( + channel: &'a mut Channel, request: &str, - response_completion: oneshot::Sender>, - response: &'a mut oneshot::Receiver>, + response_completion: oneshot::Sender>, + response: &'a mut oneshot::Receiver>, ) -> ResponseGuard<'a, String> { let request_id = - u64::try_from(channel.next_request_id.fetch_add(1, Ordering::Relaxed)).unwrap(); + u64::try_from(channel.next_request_id.fetch_add(1, Ordering::Relaxed)).unwrap(); let request = DispatchRequest { ctx: context::current(), span: Span::current(), @@ -1132,9 +1163,12 @@ mod tests { response_guard } - async fn send_response( - channel: &mut UnboundedChannel, Response>, - response: Response, + async fn send_response( + channel: &mut UnboundedChannel< + ClientMessage, + Response, + >, + response: Response, ) { channel.send(response).await.unwrap(); } diff --git a/tarpc/src/client/in_flight_requests.rs b/tarpc/src/client/in_flight_requests.rs index 1776a74a0..d6424c564 100644 --- a/tarpc/src/client/in_flight_requests.rs +++ b/tarpc/src/client/in_flight_requests.rs @@ -1,15 +1,17 @@ use crate::{ - context, - util::{Compact, TimeUntil}, + context, trace, + util::{Compact, TimeUntil} }; use fnv::FnvHashMap; use std::{ collections::hash_map, task::{Context, Poll}, }; +use std::time::Instant; use tokio::sync::oneshot; use tokio_util::time::delay_queue::{self, DelayQueue}; use tracing::Span; +use crate::client::RpcError; /// Requests already written to the wire that haven't yet received responses. #[derive(Debug)] @@ -29,9 +31,9 @@ impl Default for InFlightRequests { #[derive(Debug)] struct RequestData { - ctx: context::Context, + ctx: trace::Context, span: Span, - response_completion: oneshot::Sender, + response_completion: oneshot::Sender>, /// The key to remove the timer for the request's deadline. deadline_key: delay_queue::Key, } @@ -56,13 +58,14 @@ impl InFlightRequests { pub fn insert_request( &mut self, request_id: u64, - ctx: context::Context, + ctx: trace::Context, + deadline: Instant, span: Span, - response_completion: oneshot::Sender, + response_completion: oneshot::Sender>, ) -> Result<(), AlreadyExistsError> { match self.request_data.entry(request_id) { hash_map::Entry::Vacant(vacant) => { - let timeout = ctx.deadline.time_until(); + let timeout = deadline.time_until(); let deadline_key = self.deadlines.insert(request_id, timeout); vacant.insert(RequestData { ctx, @@ -76,8 +79,8 @@ impl InFlightRequests { } } - /// Removes a request without aborting. Returns true iff the request was found. - pub fn complete_request(&mut self, request_id: u64, result: Res) -> Option { + /// Removes a request without aborting. Returns true if the request was found. + pub fn complete_request(&mut self, request_id: u64, result: Result<(context::Context, Res), RpcError>) -> Option { if let Some(request_data) = self.request_data.remove(&request_id) { self.request_data.compact(0.1); self.deadlines.remove(&request_data.deadline_key); @@ -95,7 +98,7 @@ impl InFlightRequests { /// Returns Spans for all completes requests. pub fn complete_all_requests<'a>( &'a mut self, - mut result: impl FnMut() -> Res + 'a, + mut result: impl FnMut() -> Result<(context::Context, Res), RpcError> + 'a, ) -> impl Iterator + 'a { self.deadlines.clear(); self.request_data.drain().map(move |(_, request_data)| { @@ -106,7 +109,7 @@ impl InFlightRequests { /// Cancels a request without completing (typically used when a request handle was dropped /// before the request completed). - pub fn cancel_request(&mut self, request_id: u64) -> Option<(context::Context, Span)> { + pub fn cancel_request(&mut self, request_id: u64) -> Option<(trace::Context, Span)> { if let Some(request_data) = self.request_data.remove(&request_id) { self.request_data.compact(0.1); self.deadlines.remove(&request_data.deadline_key); @@ -121,7 +124,7 @@ impl InFlightRequests { pub fn poll_expired( &mut self, cx: &mut Context, - expired_error: impl Fn() -> Res, + expired_error: impl Fn() -> Result<(context::Context, Res), RpcError>, ) -> Poll> { self.deadlines.poll_expired(cx).map(|expired| { let request_id = expired?.into_inner(); diff --git a/tarpc/src/client/stub.rs b/tarpc/src/client/stub.rs index 85746b7f2..5e473566c 100644 --- a/tarpc/src/client/stub.rs +++ b/tarpc/src/client/stub.rs @@ -4,6 +4,7 @@ use crate::{ RequestName, client::{Channel, RpcError}, context, + context::ExtractContext, server::Serve, }; @@ -23,19 +24,24 @@ pub trait Stub { /// The service response type. type Resp; + ///TODO: document + type ClientCtx; + /// Calls a remote service. - async fn call(&self, ctx: context::Context, request: Self::Req) + async fn call(&self, ctx: &mut Self::ClientCtx, request: Self::Req) -> Result; } -impl Stub for Channel +impl Stub for Channel where Req: RequestName, + ClientCtx: ExtractContext, { type Req = Req; type Resp = Resp; + type ClientCtx = ClientCtx; - async fn call(&self, ctx: context::Context, request: Req) -> Result { + async fn call(&self, ctx: &mut Self::ClientCtx, request: Req) -> Result { Self::call(self, ctx, request).await } } @@ -46,7 +52,8 @@ where { type Req = S::Req; type Resp = S::Resp; - async fn call(&self, ctx: context::Context, req: Self::Req) -> Result { + type ClientCtx = S::ServerCtx; + async fn call(&self, ctx: &mut Self::ClientCtx, req: Self::Req) -> Result { self.clone().serve(ctx, req).await.map_err(RpcError::Server) } } diff --git a/tarpc/src/client/stub/load_balance.rs b/tarpc/src/client/stub/load_balance.rs index d28a3c137..eb605ecf9 100644 --- a/tarpc/src/client/stub/load_balance.rs +++ b/tarpc/src/client/stub/load_balance.rs @@ -7,7 +7,6 @@ pub use round_robin::RoundRobin; mod round_robin { use crate::{ client::{RpcError, stub}, - context, }; use cycle::AtomicCycle; @@ -17,10 +16,11 @@ mod round_robin { { type Req = Stub::Req; type Resp = Stub::Resp; + type ClientCtx = Stub::ClientCtx; async fn call( &self, - ctx: context::Context, + ctx: &mut Self::ClientCtx, request: Self::Req, ) -> Result { let next = self.stubs.next(); @@ -99,8 +99,7 @@ mod round_robin { /// the same stub. mod consistent_hash { use crate::{ - client::{RpcError, stub}, - context, + client::{RpcError, stub} }; use std::{ collections::hash_map::RandomState, @@ -116,10 +115,11 @@ mod consistent_hash { { type Req = Stub::Req; type Resp = Stub::Resp; + type ClientCtx = Stub::ClientCtx; async fn call( &self, - ctx: context::Context, + ctx: &mut Self::ClientCtx, request: Self::Req, ) -> Result { let index = usize::try_from(self.hasher.hash_one(&request) % self.stubs_len).expect( @@ -200,13 +200,13 @@ mod consistent_hash { )?; for _ in 0..2 { - let resp = stub.call(context::current(), 'a').await?; + let resp = stub.call(&mut context::current(), 'a').await?; assert_eq!(resp, 1); - let resp = stub.call(context::current(), 'b').await?; + let resp = stub.call(&mut context::current(), 'b').await?; assert_eq!(resp, 2); - let resp = stub.call(context::current(), 'c').await?; + let resp = stub.call(&mut context::current(), 'c').await?; assert_eq!(resp, 3); } diff --git a/tarpc/src/client/stub/mock.rs b/tarpc/src/client/stub/mock.rs index 145c14c1f..171f8918e 100644 --- a/tarpc/src/client/stub/mock.rs +++ b/tarpc/src/client/stub/mock.rs @@ -1,16 +1,17 @@ use crate::{ RequestName, ServerError, client::{RpcError, stub::Stub}, - context, }; +use std::marker::PhantomData; use std::{collections::HashMap, hash::Hash, io}; /// A mock stub that returns user-specified responses. -pub struct Mock { +pub struct Mock { responses: HashMap, + ghost: PhantomData, } -impl Mock +impl Mock where Req: Eq + Hash, { @@ -18,19 +19,21 @@ where pub fn new(responses: [(Req, Resp); N]) -> Self { Self { responses: HashMap::from(responses), + ghost: PhantomData, } } } -impl Stub for Mock +impl Stub for Mock where Req: Eq + Hash + RequestName, Resp: Clone, { type Req = Req; type Resp = Resp; + type ClientCtx = ServerCtx; - async fn call(&self, _: context::Context, request: Self::Req) -> Result { + async fn call(&self, _: &mut Self::ClientCtx, request: Self::Req) -> Result { self.responses .get(&request) .cloned() diff --git a/tarpc/src/client/stub/retry.rs b/tarpc/src/client/stub/retry.rs index a07b05fc5..5499f60e4 100644 --- a/tarpc/src/client/stub/retry.rs +++ b/tarpc/src/client/stub/retry.rs @@ -3,7 +3,6 @@ use crate::{ RequestName, client::{RpcError, stub}, - context, }; use std::sync::Arc; @@ -15,10 +14,11 @@ where { type Req = Req; type Resp = Stub::Resp; + type ClientCtx = Stub::ClientCtx; async fn call( &self, - ctx: context::Context, + ctx: &mut Self::ClientCtx, request: Self::Req, ) -> Result { let request = Arc::new(request); diff --git a/tarpc/src/context.rs b/tarpc/src/context.rs index 8e77cf223..d4a6611e0 100644 --- a/tarpc/src/context.rs +++ b/tarpc/src/context.rs @@ -21,8 +21,7 @@ use tracing_opentelemetry::OpenTelemetrySpanExt; /// /// The context should not be stored directly in a server implementation, because the context will /// be different for each request in scope. -#[derive(Clone, Copy, Debug)] -#[non_exhaustive] +#[derive(Debug, Clone)] #[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))] pub struct Context { /// When the client expects the request to be complete by. The server should cancel the request @@ -38,6 +37,27 @@ pub struct Context { pub trace_context: trace::Context, } +///TODO +pub trait ExtractContext { + ///TODO + fn extract(&self) -> Ctx; + ///TODO + fn update(&mut self, value: Ctx); +} + +impl ExtractContext for T +where + T: Clone, +{ + fn extract(&self) -> T { + self.clone() + } + + fn update(&mut self, value: T) { + *self = value + } +} + #[cfg(feature = "serde1")] mod absolute_to_relative_time { pub use serde::{Deserialize, Deserializer, Serialize, Serializer}; diff --git a/tarpc/src/lib.rs b/tarpc/src/lib.rs index 7e1944305..76d9a1815 100644 --- a/tarpc/src/lib.rs +++ b/tarpc/src/lib.rs @@ -124,8 +124,9 @@ //! struct HelloServer; //! //! impl World for HelloServer { +//! type Context = context::Context; //! // Each defined rpc generates an async fn that serves the RPC -//! async fn hello(self, _: context::Context, name: String) -> String { +//! async fn hello(self, _: &mut Self::Context, name: String) -> String { //! format!("Hello, {name}!") //! } //! } @@ -157,8 +158,10 @@ //! # #[derive(Clone)] //! # struct HelloServer; //! # impl World for HelloServer { -//! // Each defined rpc generates an async fn that serves the RPC -//! # async fn hello(self, _: context::Context, name: String) -> String { +//! # type Context = context::Context; +//! # +//! # // Each defined rpc generates an async fn that serves the RPC +//! # async fn hello(self, _: &mut Self::Context, name: String) -> String { //! # format!("Hello, {name}!") //! # } //! # } @@ -168,7 +171,6 @@ //! #[tokio::main] //! async fn main() -> anyhow::Result<()> { //! let (client_transport, server_transport) = tarpc::transport::channel::unbounded(); -//! //! let server = server::BaseChannel::with_defaults(server_transport); //! tokio::spawn( //! server.execute(HelloServer.serve()) @@ -179,12 +181,13 @@ //! //! // WorldClient is generated by the #[tarpc::service] attribute. It has a constructor `new` //! // that takes a config and any Transport as input. -//! let mut client = WorldClient::new(client::Config::default(), client_transport).spawn(); +//! let mut client = WorldClient::::new(client::Config::default(), client_transport).spawn(); //! //! // The client has an RPC method for each RPC defined in the annotated trait. It takes the same //! // args as defined, with the addition of a Context, which is always the first arg. The Context //! // specifies a deadline and trace information which can be helpful in debugging requests. -//! let hello = client.hello(context::current(), "Stim".to_string()).await?; +//! let mut context = context::current(); +//! let hello = client.hello(&mut context, "Stim".to_string()).await?; //! //! println!("{hello}"); //! @@ -197,7 +200,7 @@ //! Use `cargo doc` as you normally would to see the documentation created for all //! items expanded by a `service!` invocation. -#![deny(missing_docs)] +#![deny(missing_docs, warnings, unused, dead_code)] #![allow(clippy::type_complexity)] #![cfg_attr(docsrs, feature(doc_cfg))] @@ -250,17 +253,17 @@ pub(crate) mod util; pub use crate::transport::sealed::Transport; -use std::{any::Any, error::Error, io, sync::Arc, time::Instant}; +use std::{any::Any, error::Error, io, sync::Arc}; /// A message from a client to a server. #[derive(Debug)] #[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))] #[non_exhaustive] -pub enum ClientMessage { +pub enum ClientMessage { /// A request initiated by a user. The server responds to a request by invoking a /// service-provided request handler. The handler completes with a [`response`](Response), which /// the server sends back to the client. - Request(Request), + Request(Request), /// A command to cancel an in-flight request, automatically sent by the client when a response /// future is dropped. /// @@ -279,15 +282,15 @@ pub enum ClientMessage { } /// A request from a client to a server. -#[derive(Clone, Copy, Debug)] +#[derive(Clone, Debug)] #[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))] -pub struct Request { +pub struct Request { /// Trace context, deadline, and other cross-cutting concerns. - pub context: context::Context, + pub context: Ctx, /// Uniquely identifies the request across all requests sent over a single channel. pub id: u64, /// The request body. - pub message: T, + pub message: Req, } /// Implemented by the request types generated by tarpc::service. @@ -360,13 +363,14 @@ impl RequestName for u64 { /// A response from a server to a client. #[derive(Clone, Debug, PartialEq, Eq, Hash)] #[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))] -pub struct Response { +pub struct Response { /// The ID of the request being responded to. pub request_id: u64, + /// Trace context, deadline, and other cross-cutting concerns. + pub context: Ctx, /// The response body, or an error if the request failed. pub message: Result, } - /// An error indicating the server aborted the request early, e.g., due to request throttling. #[derive(thiserror::Error, Clone, Debug, PartialEq, Eq, Hash)] #[error("{kind:?}: {detail}")] @@ -489,14 +493,6 @@ impl ServerError { Self { kind, detail } } } - -impl Request { - /// Returns the deadline for this request. - pub fn deadline(&self) -> &Instant { - &self.context.deadline - } -} - #[test] fn test_channel_any_casts() { use assert_matches::assert_matches; diff --git a/tarpc/src/server.rs b/tarpc/src/server.rs index da3b3ae21..7e08db475 100644 --- a/tarpc/src/server.rs +++ b/tarpc/src/server.rs @@ -27,6 +27,7 @@ use std::{ convert::TryFrom, error::Error, fmt, marker::PhantomData, pin::Pin, sync::Arc, time::SystemTime, }; use tracing::{Span, info_span, instrument::Instrument}; +use crate::context::ExtractContext; mod in_flight_requests; pub mod request_hook; @@ -58,9 +59,10 @@ impl Default for Config { impl Config { /// Returns a channel backed by `transport` and configured with `self`. - pub fn channel(self, transport: T) -> BaseChannel + pub fn channel(self, transport: T) -> BaseChannel where - T: Transport, ClientMessage>, + T: Transport, ClientMessage>, + ServerCtx: ExtractContext, { BaseChannel::new(self, transport) } @@ -69,6 +71,9 @@ impl Config { /// Equivalent to a `FnOnce(Req) -> impl Future`. #[allow(async_fn_in_trait)] pub trait Serve { + ///TODO document + type ServerCtx; + /// Type of request. type Req: RequestName; @@ -76,17 +81,17 @@ pub trait Serve { type Resp; /// Responds to a single request. - async fn serve(self, ctx: context::Context, req: Self::Req) -> Result; + async fn serve(self, ctx: &mut Self::ServerCtx, req: Self::Req) -> Result; } /// A Serve wrapper around a Fn. #[derive(Debug)] -pub struct ServeFn { +pub struct ServeFn { f: F, - data: PhantomData Resp>, + data: PhantomData<(Req, Resp, ServerCtx)>, } -impl Clone for ServeFn +impl Clone for ServeFn where F: Clone, { @@ -98,14 +103,14 @@ where } } -impl Copy for ServeFn where F: Copy {} +impl Copy for ServeFn where F: Copy {} /// Creates a [`Serve`] wrapper around a `FnOnce(context::Context, Req) -> impl Future>`. -pub fn serve(f: F) -> ServeFn +pub fn serve(f: F) -> ServeFn where - F: FnOnce(context::Context, Req) -> Fut, - Fut: Future>, + // This should be -> impl Future<...>, but there is no syntax to express the 'a lifetime. + for<'a> F: FnOnce(&'a mut ServerCtx, Req) -> Pin> + 'a + Send>>, { ServeFn { f, @@ -113,16 +118,20 @@ where } } -impl Serve for ServeFn +impl Serve for ServeFn where Req: RequestName, - F: FnOnce(context::Context, Req) -> Fut, - Fut: Future>, + // This should be -> impl Future<...>, but there is no syntax to express the 'a lifetime. + for<'a> F: FnOnce( + &'a mut ServerCtx, + Req, + ) -> Pin> + 'a + Send>>, { + type ServerCtx = ServerCtx; type Req = Req; type Resp = Resp; - async fn serve(self, ctx: context::Context, req: Req) -> Result { + async fn serve(self, ctx: &mut ServerCtx, req: Req) -> Result { (self.f)(ctx, req).await } } @@ -138,7 +147,7 @@ where /// messages. Instead, it internally handles them by cancelling corresponding requests (removing /// the corresponding in-flight requests and aborting their handlers). #[pin_project] -pub struct BaseChannel { +pub struct BaseChannel { config: Config, /// Writes responses to the wire and reads requests off the wire. #[pin] @@ -151,12 +160,13 @@ pub struct BaseChannel { /// Holds data necessary to clean up in-flight requests. in_flight_requests: InFlightRequests, /// Types the request and response. - ghost: PhantomData<(fn() -> Req, fn(Resp))>, + ghost: PhantomData<(Req, Resp, ServeCtx)>, } -impl BaseChannel +impl BaseChannel where - T: Transport, ClientMessage>, + T: Transport, ClientMessage>, + ServerCtx: ExtractContext, { /// Creates a new channel backed by `transport` and configured with `config`. pub fn new(config: Config, transport: T) -> Self { @@ -202,28 +212,29 @@ where fn start_request( mut self: Pin<&mut Self>, - mut request: Request, - ) -> Result, AlreadyExistsError> { + request: Request, + ) -> Result, AlreadyExistsError> { + let mut shared_context = request.context.extract(); let span = info_span!( "RPC", - rpc.trace_id = %request.context.trace_id(), - rpc.deadline = %humantime::format_rfc3339(SystemTime::now() + request.context.deadline.time_until()), + rpc.trace_id = %shared_context.trace_id(), + rpc.deadline = %humantime::format_rfc3339(SystemTime::now() + shared_context.deadline.time_until()), otel.kind = "server", otel.name = tracing::field::Empty, ); - span.set_context(&request.context); - request.context.trace_context = trace::Context::try_from(&span).unwrap_or_else(|_| { + span.set_context(&shared_context); + shared_context.trace_context = trace::Context::try_from(&span).unwrap_or_else(|_| { tracing::trace!( "OpenTelemetry subscriber not installed; making unsampled \ child context." ); - request.context.trace_context.new_child() + shared_context.trace_context.new_child() }); let entered = span.enter(); tracing::debug!("ReceiveRequest"); let start = self.in_flight_requests_mut().start_request( request.id, - request.context.deadline, + shared_context.deadline, span.clone(), ); match start { @@ -248,7 +259,7 @@ where } } -impl fmt::Debug for BaseChannel { +impl fmt::Debug for BaseChannel { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { write!(f, "BaseChannel") } @@ -256,9 +267,9 @@ impl fmt::Debug for BaseChannel { /// A request tracked by a [`Channel`]. #[derive(Debug)] -pub struct TrackedRequest { +pub struct TrackedRequest { /// The request sent by the client. - pub request: Request, + pub request: Request, /// A registration to abort a future when the [`Channel`] that produced this request stops /// tracking it. pub abort_registration: AbortRegistration, @@ -295,7 +306,7 @@ pub struct TrackedRequest { /// created by [`BaseChannel`]. pub trait Channel where - Self: Transport::Resp>, TrackedRequest<::Req>>, + Self: Transport::Resp>, TrackedRequest::Req>>, { /// Type of request item. type Req; @@ -305,6 +316,8 @@ where /// The wrapped transport. type Transport; + ///TODO document + type ServerCtx; /// Configuration of the channel. fn config(&self) -> &Config; @@ -343,6 +356,7 @@ where /// /// ```rust /// use tarpc::{ + /// ClientMessage, /// context, /// client::{self, NewClient}, /// server::{self, BaseChannel, Channel, serve}, @@ -360,10 +374,11 @@ where /// let mut requests = server.requests(); /// tokio::spawn(async move { /// while let Some(Ok(request)) = requests.next().await { - /// tokio::spawn(request.execute(serve(|_, i| async move { Ok(i + 1) }))); + /// tokio::spawn(request.execute(serve(|_, i| async move { Ok(i + 1) }.boxed()))); /// } /// }); - /// assert_eq!(client.call(context::current(), 1).await.unwrap(), 2); + /// let mut context = context::current(); + /// assert_eq!(client.call(&mut context, 1).await.unwrap(), 2); /// } /// ``` fn requests(self) -> Requests @@ -386,7 +401,7 @@ where /// # Example /// /// ```rust - /// use tarpc::{context, client, server::{self, BaseChannel, Channel, serve}, transport}; + /// use tarpc::{ClientMessage, context, client, server::{self, BaseChannel, Channel, serve}, transport}; /// use futures::prelude::*; /// use tracing_subscriber::prelude::*; /// @@ -399,12 +414,13 @@ where /// let client = client::new(client::Config::default(), tx).spawn(); /// let channel = BaseChannel::with_defaults(rx); /// tokio::spawn( - /// channel.execute(serve(|_, i: i32| async move { Ok(i + 1) })) + /// channel.execute(serve(|_, i: i32| async move { Ok(i + 1) }.boxed())) /// .for_each(|response| async move { /// tokio::spawn(response); - /// })); + /// }.boxed())); + /// let mut context = context::current(); /// assert_eq!( - /// client.call(context::current(), 1).await.unwrap(), + /// client.call(&mut context, 1).await.unwrap(), /// 2); /// } /// ``` @@ -412,17 +428,18 @@ where where Self: Sized, Self::Req: RequestName, - S: Serve + Clone, + S: Serve + Clone, { self.requests().execute(serve) } } -impl Stream for BaseChannel +impl Stream for BaseChannel where - T: Transport, ClientMessage>, + T: Transport, ClientMessage>, + ServerCtx: ExtractContext, { - type Item = Result, ChannelError>; + type Item = Result, ChannelError>; fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll> { #[derive(Clone, Copy, Debug)] @@ -525,10 +542,11 @@ where } } -impl Sink> for BaseChannel +impl Sink> for BaseChannel where - T: Transport, ClientMessage>, + T: Transport, ClientMessage>, T::Error: Error, + ServerCtx: ExtractContext, { type Error = ChannelError; @@ -539,7 +557,7 @@ where .map_err(|e| ChannelError::Ready(Arc::new(e))) } - fn start_send(mut self: Pin<&mut Self>, response: Response) -> Result<(), Self::Error> { + fn start_send(mut self: Pin<&mut Self>, response: Response) -> Result<(), Self::Error> { if let Some(span) = self .in_flight_requests_mut() .remove_request(response.request_id) @@ -572,19 +590,21 @@ where } } -impl AsRef for BaseChannel { +impl AsRef for BaseChannel { fn as_ref(&self) -> &T { self.transport.get_ref() } } -impl Channel for BaseChannel +impl Channel for BaseChannel where - T: Transport, ClientMessage>, + T: Transport, ClientMessage>, + ServerCtx: ExtractContext, { type Req = Req; type Resp = Resp; type Transport = T; + type ServerCtx = ServerCtx; fn config(&self) -> &Config { &self.config @@ -609,9 +629,9 @@ where #[pin] channel: C, /// Responses waiting to be written to the wire. - pending_responses: mpsc::Receiver>, + pending_responses: mpsc::Receiver>, /// Handed out to request handlers to fan in responses. - responses_tx: mpsc::Sender>, + responses_tx: mpsc::Sender>, } impl Requests @@ -631,14 +651,14 @@ where /// Returns the inner channel over which messages are sent and received. pub fn pending_responses_mut<'a>( self: &'a mut Pin<&mut Self>, - ) -> &'a mut mpsc::Receiver> { + ) -> &'a mut mpsc::Receiver> { self.as_mut().project().pending_responses } fn pump_read( mut self: Pin<&mut Self>, cx: &mut Context<'_>, - ) -> Poll, C::Error>>> { + ) -> Poll, C::Error>>> { self.channel_pin_mut().poll_next(cx).map_ok( |TrackedRequest { request, @@ -703,7 +723,7 @@ where fn poll_next_response( mut self: Pin<&mut Self>, cx: &mut Context<'_>, - ) -> Poll, C::Error>>> { + ) -> Poll, C::Error>>> { ready!(self.ensure_writeable(cx)?); match ready!(self.pending_responses_mut().poll_recv(cx)) { @@ -736,7 +756,7 @@ where /// # Example /// /// ```rust - /// use tarpc::{context, client, server::{self, BaseChannel, Channel, serve}, transport}; + /// use tarpc::{context, client, server::{self, BaseChannel, Channel, serve}, transport, ClientMessage}; /// use futures::prelude::*; /// /// # #[cfg(not(feature = "tokio1"))] @@ -748,17 +768,18 @@ where /// let requests = BaseChannel::new(server::Config::default(), rx).requests(); /// let client = client::new(client::Config::default(), tx).spawn(); /// tokio::spawn( - /// requests.execute(serve(|_, i| async move { Ok(i + 1) })) + /// requests.execute(serve(|_, i| async move { Ok(i + 1) }.boxed())) /// .for_each(|response| async move { /// tokio::spawn(response); - /// })); - /// assert_eq!(client.call(context::current(), 1).await.unwrap(), 2); + /// }.boxed())); + /// let mut context = context::current(); + /// assert_eq!(client.call(&mut context, 1).await.unwrap(), 2); /// } /// ``` pub fn execute(self, serve: S) -> impl Stream> where C::Req: RequestName, - S: Serve + Clone, + S: Serve + Clone, { self.take_while(|result| { if let Err(e) = result { @@ -805,17 +826,17 @@ impl Drop for ResponseGuard { /// If dropped without calling [`execute`](InFlightRequest::execute), a cancellation message will /// be sent to the Channel to clean up associated request state. #[derive(Debug)] -pub struct InFlightRequest { - request: Request, +pub struct InFlightRequest { + request: Request, abort_registration: AbortRegistration, response_guard: ResponseGuard, span: Span, - response_tx: mpsc::Sender>, + response_tx: mpsc::Sender>, } -impl InFlightRequest { +impl InFlightRequest { /// Returns a reference to the request. - pub fn get(&self) -> &Request { + pub fn get(&self) -> &Request { &self.request } @@ -838,6 +859,7 @@ impl InFlightRequest { /// /// ```rust /// use tarpc::{ + /// ClientMessage, /// context, /// client::{self, NewClient}, /// server::{self, BaseChannel, Channel, serve}, @@ -855,18 +877,18 @@ impl InFlightRequest { /// tokio::spawn(async move { /// let mut requests = server.requests(); /// while let Some(Ok(in_flight_request)) = requests.next().await { - /// in_flight_request.execute(serve(|_, i| async move { Ok(i + 1) })).await; + /// in_flight_request.execute(serve(|_, i| async move { Ok(i + 1) }.boxed())).await; /// } - /// /// }); - /// assert_eq!(client.call(context::current(), 1).await.unwrap(), 2); + /// let mut context = context::current(); + /// assert_eq!(client.call(&mut context, 1).await.unwrap(), 2); /// } /// ``` /// pub async fn execute(self, serve: S) where Req: RequestName, - S: Serve, + S: Serve, { let Self { response_tx, @@ -875,7 +897,7 @@ impl InFlightRequest { span, request: Request { - context, + mut context, message, id: request_id, }, @@ -883,10 +905,11 @@ impl InFlightRequest { span.record("otel.name", message.name()); let _ = Abortable::new( async move { - let message = serve.serve(context, message).await; + let message = serve.serve(&mut context, message).await; tracing::debug!("CompleteRequest"); let response = Response { request_id, + context, message, }; let _ = response_tx.send(response).await; @@ -914,7 +937,7 @@ impl Stream for Requests where C: Channel, { - type Item = Result, C::Error>; + type Item = Result, C::Error>; fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { loop { @@ -960,6 +983,7 @@ mod tests { request_hook::{AfterRequest, BeforeRequest, RequestHook}, serve, }; + use crate::context::{ExtractContext}; use crate::{ ClientMessage, Request, Response, ServerError, context, trace, transport::channel::{self, UnboundedChannel}, @@ -979,22 +1003,16 @@ mod tests { }; fn test_channel() -> ( - Pin, Response>>>>, - UnboundedChannel, ClientMessage>, + Pin, Response>, context::Context>>>, + UnboundedChannel, ClientMessage>, ) { let (tx, rx) = crate::transport::channel::unbounded(); (Box::pin(BaseChannel::new(Config::default(), rx)), tx) } fn test_requests() -> ( - Pin< - Box< - Requests< - BaseChannel, Response>>, - >, - >, - >, - UnboundedChannel, ClientMessage>, + Pin, Response>, context::Context>>>>, + UnboundedChannel, ClientMessage>, ) { let (tx, rx) = crate::transport::channel::unbounded(); ( @@ -1006,14 +1024,8 @@ mod tests { fn test_bounded_requests( capacity: usize, ) -> ( - Pin< - Box< - Requests< - BaseChannel, Response>>, - >, - >, - >, - channel::Channel, ClientMessage>, + Pin, Response>, context::Context>>>>, + channel::Channel, ClientMessage>, ) { let (tx, rx) = crate::transport::channel::bounded(capacity); // Add 1 because capacity 0 is not supported (but is supported by transport::channel::bounded). @@ -1023,7 +1035,7 @@ mod tests { (Box::pin(BaseChannel::new(config, rx).requests()), tx) } - fn fake_request(req: Req) -> ClientMessage { + fn fake_request(req: Req) -> ClientMessage { ClientMessage::Request(Request { context: context::current(), id: 0, @@ -1039,20 +1051,25 @@ mod tests { #[tokio::test] async fn test_serve() { - let serve = serve(|_, i| async move { Ok(i) }); - assert_matches!(serve.serve(context::current(), 7).await, Ok(7)); + let serve = serve(|_, i| async move { Ok(i) }.boxed()); + assert_matches!(serve.serve(&mut context::current(), 7).await, Ok(7)); } #[tokio::test] async fn serve_before_mutates_context() -> anyhow::Result<()> { struct SetDeadline(Instant); - impl BeforeRequest for SetDeadline { + impl BeforeRequest for SetDeadline + where + ServerCtx: ExtractContext, + { async fn before( &mut self, - ctx: &mut context::Context, - _: &Req, + ctx: &mut ServerCtx, + _: &Req ) -> Result<(), ServerError> { - ctx.deadline = self.0; + let mut inner = ctx.extract(); + inner.deadline = self.0; + ctx.update(inner); Ok(()) } } @@ -1060,14 +1077,14 @@ mod tests { let some_time = Instant::now() + Duration::from_secs(37); let some_other_time = Instant::now() + Duration::from_secs(83); - let serve = serve(move |ctx: context::Context, i| async move { + let serve = serve(move |ctx: &mut context::Context, i| async move { assert_eq!(ctx.deadline, some_time); Ok(i) - }); + }.boxed()); let deadline_hook = serve.before(SetDeadline(some_time)); let mut ctx = context::current(); ctx.deadline = some_other_time; - deadline_hook.serve(ctx, 7).await?; + deadline_hook.serve(&mut ctx, 7).await?; Ok(()) } @@ -1085,37 +1102,37 @@ mod tests { } } } - impl BeforeRequest for PrintLatency { + impl BeforeRequest for PrintLatency { async fn before( &mut self, - _: &mut context::Context, - _: &Req, + _: &mut ServerCtx, + _: &Req ) -> Result<(), ServerError> { self.start = Instant::now(); Ok(()) } } - impl AfterRequest for PrintLatency { - async fn after(&mut self, _: &mut context::Context, _: &mut Result) { + impl AfterRequest for PrintLatency { + async fn after(&mut self, _: &mut ServerCtx, _: &mut Result) { tracing::debug!("Elapsed: {:?}", self.start.elapsed()); } } - let serve = serve(move |_: context::Context, i| async move { Ok(i) }); + let serve = serve(move |_: &mut context::Context, i| async move { Ok(i) }.boxed()); serve .before_and_after(PrintLatency::new()) - .serve(context::current(), 7) + .serve(&mut context::current(), 7) .await?; Ok(()) } #[tokio::test] async fn serve_before_error_aborts_request() -> anyhow::Result<()> { - let serve = serve(|_, _| async { panic!("Shouldn't get here") }); + let serve = serve(|_, _| async { panic!("Shouldn't get here") }.boxed()); let deadline_hook = serve.before(|_: &mut context::Context, _: &i32| async { Err(ServerError::new(io::ErrorKind::Other, "oops".into())) }); - let resp: Result = deadline_hook.serve(context::current(), 7).await; + let resp: Result = deadline_hook.serve(&mut context::current(), 7).await; assert_matches!(resp, Err(_)); Ok(()) } @@ -1285,6 +1302,7 @@ mod tests { .as_mut() .start_send(Response { request_id: 0, + context: context::current(), message: Ok(()), }) .unwrap(); @@ -1320,7 +1338,7 @@ mod tests { Poll::Ready(Some(Ok(request))) => request, result => panic!("Unexpected result: {result:?}"), }; - request.execute(serve(|_, _| async { Ok(()) })).await; + request.execute(serve(|_, _| async { Ok(()) }.boxed())).await; assert!( requests .as_mut() @@ -1350,6 +1368,7 @@ mod tests { .channel_pin_mut() .start_send(Response { request_id: 0, + context: context::current(), message: Ok(()), }) .unwrap(); @@ -1361,6 +1380,7 @@ mod tests { .responses_tx .send(Response { request_id: 1, + context: context::current(), message: Ok(()), }) .await @@ -1401,6 +1421,7 @@ mod tests { .channel_pin_mut() .start_send(Response { request_id: 0, + context: context::current(), message: Ok(()), }) .unwrap(); @@ -1421,6 +1442,7 @@ mod tests { .responses_tx .send(Response { request_id: 1, + context: context::current(), message: Ok(()), }) .await diff --git a/tarpc/src/server/incoming.rs b/tarpc/src/server/incoming.rs index 428eb1a7d..568ae4495 100644 --- a/tarpc/src/server/incoming.rs +++ b/tarpc/src/server/incoming.rs @@ -33,7 +33,7 @@ where ) -> impl Stream>> where C::Req: RequestName, - S: Serve + Clone, + S: Serve + Clone, { self.map(move |channel| channel.execute(serve.clone())) } @@ -48,6 +48,7 @@ where /// # Example /// ```rust /// use tarpc::{ +/// ClientMessage, /// context, /// client::{self, NewClient}, /// server::{self, BaseChannel, Channel, incoming::{Incoming, spawn_incoming}, serve}, @@ -57,15 +58,17 @@ where /// /// #[tokio::main] /// async fn main() { -/// let (tx, rx) = transport::channel::unbounded(); +/// use tracing_subscriber::filter::FilterExt; +/// let (tx, rx) = transport::channel::unbounded(); /// let NewClient { client, dispatch } = client::new(client::Config::default(), tx); /// tokio::spawn(dispatch); /// /// let incoming = stream::once(async move { /// BaseChannel::new(server::Config::default(), rx) -/// }).execute(serve(|_, i| async move { Ok(i + 1) })); +/// }).execute(serve(|_, i| async move { Ok(i + 1) }.boxed())); /// tokio::spawn(spawn_incoming(incoming)); -/// assert_eq!(client.call(context::current(), 1).await.unwrap(), 2); +/// let mut context = context::current(); +/// assert_eq!(client.call(&mut context, 1).await.unwrap(), 2); /// } /// ``` pub async fn spawn_incoming( diff --git a/tarpc/src/server/limits/channels_per_key.rs b/tarpc/src/server/limits/channels_per_key.rs index 64b644278..3ffdfac89 100644 --- a/tarpc/src/server/limits/channels_per_key.rs +++ b/tarpc/src/server/limits/channels_per_key.rs @@ -107,6 +107,7 @@ where type Req = C::Req; type Resp = C::Resp; type Transport = C::Transport; + type ServerCtx = C::ServerCtx; fn config(&self) -> &server::Config { self.inner.config() diff --git a/tarpc/src/server/limits/requests_per_channel.rs b/tarpc/src/server/limits/requests_per_channel.rs index bd9c103b0..527cb6f98 100644 --- a/tarpc/src/server/limits/requests_per_channel.rs +++ b/tarpc/src/server/limits/requests_per_channel.rs @@ -67,6 +67,7 @@ where self.as_mut().start_send(Response { request_id: r.request.id, + context: r.request.context, message: Err(ServerError { kind: io::ErrorKind::WouldBlock, detail: "server throttled the request.".into(), @@ -80,7 +81,7 @@ where } } -impl Sink::Resp>> for MaxRequests +impl Sink::Resp>> for MaxRequests where C: Channel, { @@ -92,7 +93,7 @@ where fn start_send( self: Pin<&mut Self>, - item: Response<::Resp>, + item: Response::Resp>, ) -> Result<(), Self::Error> { self.project().inner.start_send(item) } @@ -119,6 +120,7 @@ where type Req = ::Req; type Resp = ::Resp; type Transport = ::Transport; + type ServerCtx = ::ServerCtx; fn in_flight_requests(&self) -> usize { self.inner.in_flight_requests() @@ -188,6 +190,7 @@ mod tests { time::{Duration, Instant}, }; use tracing::Span; + use crate::context; #[tokio::test] async fn throttler_in_flight_requests() { @@ -268,7 +271,7 @@ mod tests { } impl PendingSink<(), ()> { pub fn default() - -> PendingSink>, Response> { + -> PendingSink>, Response, > { PendingSink { ghost: PhantomData } } } @@ -293,10 +296,12 @@ mod tests { Poll::Pending } } - impl Channel for PendingSink>, Response> { + impl Channel + for PendingSink>, Response> { type Req = Req; type Resp = Resp; type Transport = (); + type ServerCtx = context::Context; fn config(&self) -> &Config { unimplemented!() } @@ -326,16 +331,16 @@ mod tests { .as_mut() .start_send(Response { request_id: 0, + context: context::current(), message: Ok(1), }) .unwrap(); assert_eq!(throttler.inner.in_flight_requests.len(), 0); - assert_eq!( - throttler.inner.sink.front(), - Some(&Response { - request_id: 0, - message: Ok(1), - }) - ); + + let result = throttler.inner.sink.front(); + + assert_eq!(result.map(|r| r.request_id), Some(0)); + + assert_eq!(result.map(|r| &r.message), Some(&Ok(1))); } } diff --git a/tarpc/src/server/request_hook.rs b/tarpc/src/server/request_hook.rs index 66cf2878c..cce5998ee 100644 --- a/tarpc/src/server/request_hook.rs +++ b/tarpc/src/server/request_hook.rs @@ -43,11 +43,11 @@ pub trait RequestHook: Serve { /// # Example /// /// ```rust - /// use futures::{executor::block_on, future}; + /// use futures::{executor::block_on, future, FutureExt}; /// use tarpc::{context, ServerError, server::{Serve, request_hook::RequestHook, serve}}; /// use std::io; /// - /// let serve = serve(|_ctx, i| async move { Ok(i + 1) }) + /// let serve = serve(|_ctx, i| async move { Ok(i + 1) }.boxed()) /// .before(|_ctx: &mut context::Context, req: &i32| { /// future::ready( /// if *req == 1 { @@ -58,12 +58,13 @@ pub trait RequestHook: Serve { /// Ok(()) /// }) /// }); - /// let response = serve.serve(context::current(), 1); + /// let mut context = context::current(); + /// let response = serve.serve(&mut context, 1); /// assert!(block_on(response).is_err()); /// ``` - fn before(self, hook: Hook) -> HookThenServe + fn before(self, hook: Hook) -> HookThenServe where - Hook: BeforeRequest, + Hook: BeforeRequest, Self: Sized, { HookThenServe::new(self, hook) @@ -80,7 +81,7 @@ pub trait RequestHook: Serve { /// # Example /// /// ```rust - /// use futures::{executor::block_on, future}; + /// use futures::{executor::block_on, future, FutureExt}; /// use tarpc::{context, ServerError, server::{Serve, request_hook::RequestHook, serve}}; /// use std::io; /// @@ -93,20 +94,20 @@ pub trait RequestHook: Serve { /// } else { /// Ok(i + 1) /// } - /// }) + /// }.boxed()) /// .after(|_ctx: &mut context::Context, resp: &mut Result| { /// if let Err(e) = resp { /// eprintln!("server error: {e:?}"); /// } /// future::ready(()) /// }); - /// - /// let response = serve.serve(context::current(), 1); + /// let mut context = context::current(); + /// let response = serve.serve(&mut context, 1); /// assert!(block_on(response).is_err()); /// ``` fn after(self, hook: Hook) -> ServeThenHook where - Hook: AfterRequest, + Hook: AfterRequest, Self: Sized, { ServeThenHook::new(self, hook) @@ -123,7 +124,7 @@ pub trait RequestHook: Serve { /// # Example /// /// ```rust - /// use futures::{executor::block_on, future}; + /// use futures::{executor::block_on, future, FutureExt}; /// use tarpc::{ /// context, ServerError, /// server::{Serve, serve, request_hook::{BeforeRequest, AfterRequest, RequestHook}} @@ -132,17 +133,17 @@ pub trait RequestHook: Serve { /// /// struct PrintLatency(Instant); /// - /// impl BeforeRequest for PrintLatency { - /// async fn before(&mut self, _: &mut context::Context, _: &Req) -> Result<(), ServerError> { + /// impl BeforeRequest for PrintLatency { + /// async fn before(&mut self, _: &mut ServerCtx, _: &Req) -> Result<(), ServerError> { /// self.0 = Instant::now(); /// Ok(()) /// } /// } /// - /// impl AfterRequest for PrintLatency { + /// impl AfterRequest for PrintLatency { /// async fn after( /// &mut self, - /// _: &mut context::Context, + /// _: &mut ServerCtx, /// _: &mut Result, /// ) { /// tracing::info!("Elapsed: {:?}", self.0.elapsed()); @@ -151,16 +152,17 @@ pub trait RequestHook: Serve { /// /// let serve = serve(|_ctx, i| async move { /// Ok(i + 1) - /// }).before_and_after(PrintLatency(Instant::now())); - /// let response = serve.serve(context::current(), 1); + /// }.boxed()).before_and_after(PrintLatency(Instant::now())); + /// let mut context = context::current(); + /// let response = serve.serve(&mut context, 1); /// assert!(block_on(response).is_ok()); /// ``` fn before_and_after( self, hook: Hook, - ) -> HookThenServeThenHook + ) -> HookThenServeThenHook where - Hook: BeforeRequest + AfterRequest, + Hook: BeforeRequest + AfterRequest, Self: Sized, { HookThenServeThenHook::new(self, hook) diff --git a/tarpc/src/server/request_hook/after.rs b/tarpc/src/server/request_hook/after.rs index b2ef9ccbd..1fa3cee51 100644 --- a/tarpc/src/server/request_hook/after.rs +++ b/tarpc/src/server/request_hook/after.rs @@ -6,24 +6,24 @@ //! Provides a hook that runs after request execution. -use crate::{ServerError, context, server::Serve}; +use crate::{ServerError, server::Serve}; use futures::prelude::*; /// A hook that runs after request execution. #[allow(async_fn_in_trait)] -pub trait AfterRequest { +pub trait AfterRequest { /// The function that is called after request execution. /// /// The hook can modify the request context and the response. - async fn after(&mut self, ctx: &mut context::Context, resp: &mut Result); + async fn after(&mut self, ctx: &mut ServerCtx, resp: &mut Result); } -impl AfterRequest for F +impl AfterRequest for F where - F: FnMut(&mut context::Context, &mut Result) -> Fut, + F: FnMut(&mut ServerCtx, &mut Result) -> Fut, Fut: Future, { - async fn after(&mut self, ctx: &mut context::Context, resp: &mut Result) { + async fn after(&mut self, ctx: &mut ServerCtx, resp: &mut Result) { self(ctx, resp).await } } @@ -52,21 +52,22 @@ impl Clone for ServeThenHook { impl Serve for ServeThenHook where Serv: Serve, - Hook: AfterRequest, + Hook: AfterRequest, { type Req = Serv::Req; type Resp = Serv::Resp; + type ServerCtx = Serv::ServerCtx; async fn serve( self, - mut ctx: context::Context, + ctx: &mut Serv::ServerCtx, req: Serv::Req, ) -> Result { let ServeThenHook { serve, mut hook, .. } = self; let mut resp = serve.serve(ctx, req).await; - hook.after(&mut ctx, &mut resp).await; + hook.after(ctx, &mut resp).await; resp } } diff --git a/tarpc/src/server/request_hook/before.rs b/tarpc/src/server/request_hook/before.rs index e72e28a42..13fc18509 100644 --- a/tarpc/src/server/request_hook/before.rs +++ b/tarpc/src/server/request_hook/before.rs @@ -6,12 +6,13 @@ //! Provides a hook that runs before request execution. -use crate::{ServerError, context, server::Serve}; +use crate::{ServerError, server::Serve}; use futures::prelude::*; +use std::marker::PhantomData; /// A hook that runs before request execution. #[allow(async_fn_in_trait)] -pub trait BeforeRequest { +pub trait BeforeRequest { /// The function that is called before request execution. /// /// If this function returns an error, the request will not be executed and the error will be @@ -19,22 +20,22 @@ pub trait BeforeRequest { /// /// This function can also modify the request context. This could be used, for example, to /// enforce a maximum deadline on all requests. - async fn before(&mut self, ctx: &mut context::Context, req: &Req) -> Result<(), ServerError>; + async fn before(&mut self, ctx: &mut ServerCtx, req: &Req) -> Result<(), ServerError>; } /// A list of hooks that run in order before request execution. -pub trait BeforeRequestList: BeforeRequest { +pub trait BeforeRequestList: BeforeRequest { /// The hook returned by `BeforeRequestList::then`. - type Then: BeforeRequest + type Then: BeforeRequest where - Next: BeforeRequest; + Next: BeforeRequest; /// Returns a hook that, when run, runs two hooks, first `self` and then `next`. - fn then>(self, next: Next) -> Self::Then; + fn then>(self, next: Next) -> Self::Then; /// Same as `then`, but helps the compiler with type inference when Next is a closure. fn then_fn< - Next: FnMut(&mut context::Context, &Req) -> Fut, + Next: FnMut(&mut ServerCtx, &Req) -> Fut, Fut: Future>, >( self, @@ -47,53 +48,60 @@ pub trait BeforeRequestList: BeforeRequest { } /// The service fn returned by `BeforeRequestList::serving`. - type Serve>: Serve; + type Serve>: Serve; /// Runs the list of request hooks before execution of the given serve fn. /// This is equivalent to `serve.before(before_request_chain)` but may be syntactically nicer. - fn serving>(self, serve: S) -> Self::Serve; + fn serving>(self, serve: S) -> Self::Serve; } -impl BeforeRequest for F +impl BeforeRequest for F where - F: FnMut(&mut context::Context, &Req) -> Fut, + F: FnMut(&mut ServerCtx, &Req) -> Fut, Fut: Future>, { - async fn before(&mut self, ctx: &mut context::Context, req: &Req) -> Result<(), ServerError> { + async fn before(&mut self, ctx: &mut ServerCtx, req: &Req) -> Result<(), ServerError> { self(ctx, req).await } } /// A Service function that runs a hook before request execution. -#[derive(Clone)] -pub struct HookThenServe { +pub struct HookThenServe { serve: Serv, hook: Hook, + ghost: PhantomData, } -impl HookThenServe { +impl Clone for HookThenServe { + fn clone(&self) -> Self { + Self::new(self.serve.clone(), self.hook.clone()) + } +} + +impl HookThenServe { pub(crate) fn new(serve: Serv, hook: Hook) -> Self { - Self { serve, hook } + Self { serve, hook, ghost: PhantomData } } } -impl Serve for HookThenServe +impl Serve for HookThenServe where - Serv: Serve, - Hook: BeforeRequest, + Serv: Serve, + Hook: BeforeRequest, { + type ServerCtx = ServerCtx; type Req = Serv::Req; type Resp = Serv::Resp; async fn serve( self, - mut ctx: context::Context, - req: Self::Req, + ctx: &mut ServerCtx, + req: Self::Req ) -> Result { let HookThenServe { serve, mut hook, .. } = self; - hook.before(&mut ctx, &req).await?; + hook.before(ctx, &req).await?; serve.serve(ctx, req).await } } @@ -103,7 +111,7 @@ where /// Example /// /// ```rust -/// use futures::{executor::block_on, future}; +/// use futures::{executor::block_on, future, FutureExt}; /// use tarpc::{context, ServerError, server::{Serve, serve, request_hook::{self, /// BeforeRequest, BeforeRequestList}}}; /// use std::{cell::Cell, io}; @@ -120,8 +128,9 @@ where /// i.set(2); /// Ok(()) /// }) -/// .serving(serve(|_ctx, i| async move { Ok(i + 1) })); -/// let response = serve.clone().serve(context::current(), 1); +/// .serving(serve(|_ctx, i| async move { Ok(i + 1) }.boxed())); +/// let mut context = context::current(); +/// let response = serve.clone().serve(&mut context, 1); /// assert!(block_on(response).is_ok()); /// assert!(i.get() == 2); /// ``` @@ -137,10 +146,10 @@ pub struct BeforeRequestCons(First, Rest); #[derive(Clone, Copy)] pub struct BeforeRequestNil; -impl, Rest: BeforeRequest> BeforeRequest - for BeforeRequestCons +impl, Rest: BeforeRequest, ServerCtx> + BeforeRequest for BeforeRequestCons { - async fn before(&mut self, ctx: &mut context::Context, req: &Req) -> Result<(), ServerError> { + async fn before(&mut self, ctx: &mut ServerCtx, req: &Req) -> Result<(), ServerError> { let BeforeRequestCons(first, rest) = self; first.before(ctx, req).await?; rest.before(ctx, req).await?; @@ -148,45 +157,45 @@ impl, Rest: BeforeRequest> BeforeRequest BeforeRequest for BeforeRequestNil { - async fn before(&mut self, _: &mut context::Context, _: &Req) -> Result<(), ServerError> { +impl BeforeRequest for BeforeRequestNil { + async fn before(&mut self, _: &mut ServerCtx, _: &Req) -> Result<(), ServerError> { Ok(()) } } -impl, Rest: BeforeRequestList> BeforeRequestList - for BeforeRequestCons +impl, Rest: BeforeRequestList, ServerCtx> + BeforeRequestList for BeforeRequestCons { type Then = BeforeRequestCons> where - Next: BeforeRequest; + Next: BeforeRequest; - fn then>(self, next: Next) -> Self::Then { + fn then>(self, next: Next) -> Self::Then { let BeforeRequestCons(first, rest) = self; BeforeRequestCons(first, rest.then(next)) } - type Serve> = HookThenServe; + type Serve> = HookThenServe; - fn serving>(self, serve: S) -> Self::Serve { + fn serving>(self, serve: S) -> Self::Serve { HookThenServe::new(serve, self) } } -impl BeforeRequestList for BeforeRequestNil { +impl BeforeRequestList for BeforeRequestNil { type Then = BeforeRequestCons where - Next: BeforeRequest; + Next: BeforeRequest; - fn then>(self, next: Next) -> Self::Then { + fn then>(self, next: Next) -> Self::Then { BeforeRequestCons(next, BeforeRequestNil) } - type Serve> = S; + type Serve> = S; - fn serving>(self, serve: S) -> S { + fn serving>(self, serve: S) -> S { serve } } @@ -209,8 +218,9 @@ fn before_request_list() { i.set(2); Ok(()) }) - .serving(serve(|_ctx, i| async move { Ok(i + 1) })); - let response = serve.clone().serve(context::current(), 1); + .serving(serve(|_ctx, i| async move { Ok(i + 1) }.boxed())); + let mut context = crate::context::current(); + let response = serve.clone().serve(&mut context, 1); assert!(block_on(response).is_ok()); assert!(i.get() == 2); } diff --git a/tarpc/src/server/request_hook/before_and_after.rs b/tarpc/src/server/request_hook/before_and_after.rs index 0761a7df3..934d82ad5 100644 --- a/tarpc/src/server/request_hook/before_and_after.rs +++ b/tarpc/src/server/request_hook/before_and_after.rs @@ -7,17 +7,17 @@ //! Provides a hook that runs both before and after request execution. use super::{after::AfterRequest, before::BeforeRequest}; -use crate::{RequestName, ServerError, context, server::Serve}; +use crate::{RequestName, ServerError, server::Serve}; use std::marker::PhantomData; /// A Service function that runs a hook both before and after request execution. -pub struct HookThenServeThenHook { +pub struct HookThenServeThenHook { serve: Serv, hook: Hook, - fns: PhantomData<(fn(Req), fn(Resp))>, + fns: PhantomData<(Req, Resp, ServerCtx)>, } -impl HookThenServeThenHook { +impl HookThenServeThenHook { pub(crate) fn new(serve: Serv, hook: Hook) -> Self { Self { serve, @@ -27,7 +27,9 @@ impl HookThenServeThenHook { } } -impl Clone for HookThenServeThenHook { +impl Clone + for HookThenServeThenHook +{ fn clone(&self) -> Self { Self { serve: self.serve.clone(), @@ -37,22 +39,23 @@ impl Clone for HookThenServeThenHook Serve for HookThenServeThenHook +impl Serve for HookThenServeThenHook where Req: RequestName, - Serv: Serve, - Hook: BeforeRequest + AfterRequest, + Serv: Serve, + Hook: BeforeRequest + AfterRequest, { type Req = Req; type Resp = Resp; + type ServerCtx = ServerCtx; - async fn serve(self, mut ctx: context::Context, req: Req) -> Result { + async fn serve(self, ctx: &mut ServerCtx, req: Req) -> Result { let HookThenServeThenHook { serve, mut hook, .. } = self; - hook.before(&mut ctx, &req).await?; + hook.before(ctx, &req).await?; let mut resp = serve.serve(ctx, req).await; - hook.after(&mut ctx, &mut resp).await; + hook.after(ctx, &mut resp).await; resp } } diff --git a/tarpc/src/server/testing.rs b/tarpc/src/server/testing.rs index db167c42e..39eabdaf5 100644 --- a/tarpc/src/server/testing.rs +++ b/tarpc/src/server/testing.rs @@ -38,14 +38,15 @@ where } } -impl Sink> for FakeChannel> { +impl Sink> for FakeChannel> +{ type Error = io::Error; fn poll_ready(self: Pin<&mut Self>, cx: &mut Context) -> Poll> { self.project().sink.poll_ready(cx).map_err(|e| match e {}) } - fn start_send(mut self: Pin<&mut Self>, response: Response) -> Result<(), Self::Error> { + fn start_send(mut self: Pin<&mut Self>, response: Response) -> Result<(), Self::Error> { self.as_mut() .project() .in_flight_requests @@ -65,13 +66,14 @@ impl Sink> for FakeChannel> { } } -impl Channel for FakeChannel>, Response> +impl Channel for FakeChannel>, Response> where Req: Unpin, { type Req = Req; type Resp = Resp; type Transport = (); + type ServerCtx = context::Context; fn config(&self) -> &Config { &self.config @@ -86,7 +88,8 @@ where } } -impl FakeChannel>, Response> { +impl FakeChannel>, Response> +{ pub fn push_req(&mut self, id: u64, message: Req) { let (_, abort_registration) = futures::future::AbortHandle::new_pair(); let (request_cancellation, _) = cancellations(); @@ -111,7 +114,8 @@ impl FakeChannel>, Response> { } impl FakeChannel<(), ()> { - pub fn default() -> FakeChannel>, Response> { + pub fn default() -> FakeChannel>, Response> + { let (request_cancellation, canceled_requests) = cancellations(); FakeChannel { stream: Default::default(), diff --git a/tarpc/src/transport/channel.rs b/tarpc/src/transport/channel.rs index 0268300dc..35c81fb1e 100644 --- a/tarpc/src/transport/channel.rs +++ b/tarpc/src/transport/channel.rs @@ -191,14 +191,14 @@ mod tests { tokio::spawn( stream::once(future::ready(server_channel)) .map(BaseChannel::with_defaults) - .execute(serve(|_ctx, request: String| async move { + .execute(serve(|_ctx: &mut context::Context, request: String| async move { request.parse::().map_err(|_| { ServerError::new( io::ErrorKind::InvalidInput, format!("{request:?} is not an int"), ) }) - })) + }.boxed())) .for_each(|channel| async move { tokio::spawn(channel.for_each(|response| response)); }), @@ -206,8 +206,8 @@ mod tests { let client = client::new(client::Config::default(), client_channel).spawn(); - let response1 = client.call(context::current(), "123".into()).await; - let response2 = client.call(context::current(), "abc".into()).await; + let response1 = client.call(&mut context::current(), "123".into()).await; + let response2 = client.call(&mut context::current(), "abc".into()).await; trace!("response1: {:?}, response2: {:?}", response1, response2); diff --git a/tarpc/tests/compile_fail/must_use_request_dispatch.rs b/tarpc/tests/compile_fail/must_use_request_dispatch.rs index 2915d3237..812fc4ee7 100644 --- a/tarpc/tests/compile_fail/must_use_request_dispatch.rs +++ b/tarpc/tests/compile_fail/must_use_request_dispatch.rs @@ -1,5 +1,5 @@ use tarpc::client; - +use tarpc::context::Context; #[tarpc::service] trait World { async fn hello(name: String) -> String; @@ -10,6 +10,6 @@ fn main() { #[deny(unused_must_use)] { - WorldClient::new(client::Config::default(), client_transport).dispatch; + WorldClient::::new(client::Config::default(), client_transport).dispatch; } } diff --git a/tarpc/tests/compile_fail/must_use_request_dispatch.stderr b/tarpc/tests/compile_fail/must_use_request_dispatch.stderr index e652cc8e8..4fe34df5f 100644 --- a/tarpc/tests/compile_fail/must_use_request_dispatch.stderr +++ b/tarpc/tests/compile_fail/must_use_request_dispatch.stderr @@ -1,8 +1,8 @@ error: unused `RequestDispatch` that must be used --> tests/compile_fail/must_use_request_dispatch.rs:13:9 | -13 | WorldClient::new(client::Config::default(), client_transport).dispatch; - | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +13 | WorldClient::::new(client::Config::default(), client_transport).dispatch; + | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ | note: the lint level is defined here --> tests/compile_fail/must_use_request_dispatch.rs:11:12 @@ -11,5 +11,5 @@ note: the lint level is defined here | ^^^^^^^^^^^^^^^ help: use `let _ = ...` to ignore the resulting value | -13 | let _ = WorldClient::new(client::Config::default(), client_transport).dispatch; +13 | let _ = WorldClient::::new(client::Config::default(), client_transport).dispatch; | +++++++ diff --git a/tarpc/tests/dataservice.rs b/tarpc/tests/dataservice.rs index 18bb3a997..5a5b2f8e7 100644 --- a/tarpc/tests/dataservice.rs +++ b/tarpc/tests/dataservice.rs @@ -22,7 +22,8 @@ pub trait ColorProtocol { struct ColorServer; impl ColorProtocol for ColorServer { - async fn get_opposite_color(self, _: context::Context, color: TestData) -> TestData { + type Context = context::Context; + async fn get_opposite_color(self, _: &mut Self::Context, color: TestData) -> TestData { match color { TestData::White => TestData::Black, TestData::Black => TestData::White, @@ -53,7 +54,7 @@ async fn test_call() -> anyhow::Result<()> { let client = ColorProtocolClient::new(client::Config::default(), transport).spawn(); let color = client - .get_opposite_color(context::current(), TestData::White) + .get_opposite_color(&mut context::current(), TestData::White) .await?; assert_eq!(color, TestData::Black); diff --git a/tarpc/tests/service_functional.rs b/tarpc/tests/service_functional.rs index 06542b43b..559521414 100644 --- a/tarpc/tests/service_functional.rs +++ b/tarpc/tests/service_functional.rs @@ -8,6 +8,7 @@ use tarpc::{ client::{self}, context, server::{BaseChannel, Channel, incoming::Incoming}, + transport, transport::channel, }; use tokio::join; @@ -22,11 +23,12 @@ trait Service { struct Server; impl Service for Server { - async fn add(self, _: context::Context, x: i32, y: i32) -> i32 { + type Context = context::Context; + async fn add(self, _: &mut Self::Context, x: i32, y: i32) -> i32 { x + y } - async fn hey(self, _: context::Context, name: String) -> String { + async fn hey(self, _: &mut Self::Context, name: String) -> String { format!("Hey, {name}.") } } @@ -38,10 +40,10 @@ async fn sequential() { let channel = BaseChannel::with_defaults(rx); tokio::spawn( channel - .execute(tarpc::server::serve(|_, i: u32| async move { Ok(i + 1) })) + .execute(tarpc::server::serve(|_, i: u32| async move { Ok(i + 1) }.boxed())) .for_each(|response| response), ); - assert_eq!(client.call(context::current(), 1).await.unwrap(), 2); + assert_eq!(client.call(&mut context::current(), 1).await.unwrap(), 2); } #[tokio::test] @@ -55,7 +57,8 @@ async fn dropped_channel_aborts_in_flight_requests() -> anyhow::Result<()> { struct LoopServer; impl Loop for LoopServer { - async fn r#loop(self, _: context::Context) { + type Context = context::Context; + async fn r#loop(self, _: &mut Self::Context) { loop { futures::pending!(); } @@ -64,7 +67,7 @@ async fn dropped_channel_aborts_in_flight_requests() -> anyhow::Result<()> { let _ = tracing_subscriber::fmt::try_init(); - let (tx, rx) = channel::unbounded(); + let (tx, rx) = tarpc::transport::channel::unbounded(); // Set up a client that initiates a long-lived request. // The request will complete in error when the server drops the connection. @@ -73,7 +76,7 @@ async fn dropped_channel_aborts_in_flight_requests() -> anyhow::Result<()> { let mut ctx = context::current(); ctx.deadline = Instant::now() + Duration::from_secs(60 * 60); - let _ = client.r#loop(ctx).await; + let _ = client.r#loop(&mut ctx).await; }); let mut requests = BaseChannel::with_defaults(rx).requests(); @@ -112,9 +115,9 @@ async fn serde_tcp() -> anyhow::Result<()> { let transport = serde_transport::tcp::connect(addr, Json::default).await?; let client = ServiceClient::new(client::Config::default(), transport).spawn(); - assert_matches!(client.add(context::current(), 1, 2).await, Ok(3)); + assert_matches!(client.add(&mut context::current(), 1, 2).await, Ok(3)); assert_matches!( - client.hey(context::current(), "Tim".to_string()).await, + client.hey(&mut context::current(), "Tim".to_string()).await, Ok(ref s) if s == "Hey, Tim." ); @@ -145,8 +148,8 @@ async fn serde_uds() -> anyhow::Result<()> { let client = ServiceClient::new(client::Config::default(), transport).spawn(); // Save results using socket so we can clean the socket even if our test assertions fail - let res1 = client.add(context::current(), 1, 2).await; - let res2 = client.hey(context::current(), "Tim".to_string()).await; + let res1 = client.add(&mut context::current(), 1, 2).await; + let res2 = client.hey(&mut context::current(), "Tim".to_string()).await; assert_matches!(res1, Ok(3)); assert_matches!(res2, Ok(ref s) if s == "Hey, Tim."); @@ -158,7 +161,8 @@ async fn serde_uds() -> anyhow::Result<()> { async fn concurrent() -> anyhow::Result<()> { let _ = tracing_subscriber::fmt::try_init(); - let (tx, rx) = channel::unbounded(); + let (tx, rx) = tarpc::transport::channel::unbounded(); + tokio::spawn( stream::once(ready(rx)) .map(BaseChannel::with_defaults) @@ -169,12 +173,15 @@ async fn concurrent() -> anyhow::Result<()> { let client = ServiceClient::new(client::Config::default(), tx).spawn(); - let req1 = client.add(context::current(), 1, 2); - let req2 = client.add(context::current(), 3, 4); - let req3 = client.hey(context::current(), "Tim".to_string()); + let mut context = context::current(); + let req1 = client.add(&mut context, 1, 2); assert_matches!(req1.await, Ok(3)); + + let req2 = client.add(&mut context, 3, 4); assert_matches!(req2.await, Ok(7)); + + let req3 = client.hey(&mut context, "Tim".to_string()); assert_matches!(req3.await, Ok(ref s) if s == "Hey, Tim."); Ok(()) @@ -184,7 +191,8 @@ async fn concurrent() -> anyhow::Result<()> { async fn concurrent_join() -> anyhow::Result<()> { let _ = tracing_subscriber::fmt::try_init(); - let (tx, rx) = channel::unbounded(); + let (tx, rx) = tarpc::transport::channel::unbounded(); + tokio::spawn( stream::once(ready(rx)) .map(BaseChannel::with_defaults) @@ -195,9 +203,13 @@ async fn concurrent_join() -> anyhow::Result<()> { let client = ServiceClient::new(client::Config::default(), tx).spawn(); - let req1 = client.add(context::current(), 1, 2); - let req2 = client.add(context::current(), 3, 4); - let req3 = client.hey(context::current(), "Tim".to_string()); + let mut context1 = context::current(); + let mut context2 = context::current(); + let mut context3 = context::current(); + + let req1 = client.add(&mut context1, 1, 2); + let req2 = client.add(&mut context2, 3, 4); + let req3 = client.hey(&mut context3, "Tim".to_string()); let (resp1, resp2, resp3) = join!(req1, req2, req3); assert_matches!(resp1, Ok(3)); @@ -216,7 +228,7 @@ async fn spawn(fut: impl Future + Send + 'static) { async fn concurrent_join_all() -> anyhow::Result<()> { let _ = tracing_subscriber::fmt::try_init(); - let (tx, rx) = channel::unbounded(); + let (tx, rx) = tarpc::transport::channel::unbounded(); tokio::spawn( BaseChannel::with_defaults(rx) .execute(Server.serve()) @@ -225,8 +237,11 @@ async fn concurrent_join_all() -> anyhow::Result<()> { let client = ServiceClient::new(client::Config::default(), tx).spawn(); - let req1 = client.add(context::current(), 1, 2); - let req2 = client.add(context::current(), 3, 4); + let mut context1 = context::current(); + let mut context2 = context::current(); + + let req1 = client.add(&mut context1, 1, 2); + let req2 = client.add(&mut context2, 3, 4); let responses = join_all(vec![req1, req2]).await; assert_matches!(responses[0], Ok(3)); @@ -245,14 +260,16 @@ async fn counter() -> anyhow::Result<()> { struct CountService(u32); impl Counter for &mut CountService { - async fn count(self, _: context::Context) -> u32 { + type Context = context::Context; + async fn count(self, _: &mut Self::Context) -> u32 { self.0 += 1; self.0 } } let (tx, rx) = channel::unbounded(); - tokio::spawn(async { + + tokio::task::spawn(async { let mut requests = BaseChannel::with_defaults(rx).requests(); let mut counter = CountService(0); @@ -262,8 +279,8 @@ async fn counter() -> anyhow::Result<()> { }); let client = CounterClient::new(client::Config::default(), tx).spawn(); - assert_matches!(client.count(context::current()).await, Ok(1)); - assert_matches!(client.count(context::current()).await, Ok(2)); + assert_matches!(client.count(&mut context::current()).await, Ok(1)); + assert_matches!(client.count(&mut context::current()).await, Ok(2)); Ok(()) }