From cc341bee43aa6c9d04d629d831cbf71299e37995 Mon Sep 17 00:00:00 2001 From: Tim Kuehn Date: Sat, 12 Nov 2022 17:24:40 -0800 Subject: [PATCH 01/30] Add back the Client trait, renamed Stub. Also adds a Client stub trait alias for each generated service. Now that generic associated types are stable, it's almost possible to define a trait for Channel that works with async fns on stable. `impl trait in type aliases` is still necessary (and unstable), but we're getting closer. As a proof of concept, three more implementations of Stub are implemented; 1. A load balancer that round-robins requests between different stubs. 2. A load balancer that selects a stub based on a request hash, so that the same requests go to the same stubs. 3. A stub that retries requests based on a configurable policy. The "serde/rc" feature is added to the "full" feature because the Retry stub wraps the request in an Arc, so that the request is reusable for multiple calls. Server implementors commonly need to operate generically across all services or request types. For example, a server throttler may want to return errors telling clients to back off, which is not specific to any one service. --- plugins/src/lib.rs | 38 +++- tarpc/Cargo.toml | 2 +- tarpc/src/client.rs | 1 + tarpc/src/client/stub.rs | 56 +++++ tarpc/src/client/stub/load_balance.rs | 305 ++++++++++++++++++++++++++ tarpc/src/client/stub/mock.rs | 54 +++++ tarpc/src/client/stub/retry.rs | 75 +++++++ tarpc/src/lib.rs | 8 + tarpc/src/server.rs | 30 ++- 9 files changed, 563 insertions(+), 6 deletions(-) create mode 100644 tarpc/src/client/stub.rs create mode 100644 tarpc/src/client/stub/load_balance.rs create mode 100644 tarpc/src/client/stub/mock.rs create mode 100644 tarpc/src/client/stub/retry.rs diff --git a/plugins/src/lib.rs b/plugins/src/lib.rs index 1b83c3247..003efc97b 100644 --- a/plugins/src/lib.rs +++ b/plugins/src/lib.rs @@ -276,6 +276,7 @@ pub fn service(attr: TokenStream, input: TokenStream) -> TokenStream { ServiceGenerator { response_fut_name, service_ident: ident, + client_stub_ident: &format_ident!("{}Stub", ident), server_ident: &format_ident!("Serve{}", ident), response_fut_ident: &Ident::new(response_fut_name, ident.span()), client_ident: &format_ident!("{}Client", ident), @@ -432,6 +433,7 @@ fn verify_types_were_provided( // the client stub. struct ServiceGenerator<'a> { service_ident: &'a Ident, + client_stub_ident: &'a Ident, server_ident: &'a Ident, response_fut_ident: &'a Ident, response_fut_name: &'a str, @@ -461,6 +463,9 @@ impl<'a> ServiceGenerator<'a> { future_types, return_types, service_ident, + client_stub_ident, + request_ident, + response_ident, server_ident, .. } = self; @@ -490,6 +495,7 @@ impl<'a> ServiceGenerator<'a> { }, ); + let stub_doc = format!("The stub trait for service [`{service_ident}`]."); quote! { #( #attrs )* #vis trait #service_ident: Sized { @@ -501,6 +507,15 @@ impl<'a> ServiceGenerator<'a> { #server_ident { service: self } } } + + #[doc = #stub_doc] + #vis trait #client_stub_ident: tarpc::client::stub::Stub { + } + + impl #client_stub_ident for S + where S: tarpc::client::stub::Stub + { + } } } @@ -666,7 +681,7 @@ impl<'a> ServiceGenerator<'a> { #response_fut_ident::#camel_case_idents(resp) => std::pin::Pin::new_unchecked(resp) .poll(cx) - .map(#response_ident::#camel_case_idents), + .map(#response_ident::#camel_case_idents) )* } } @@ -689,7 +704,9 @@ impl<'a> ServiceGenerator<'a> { #[derive(Clone, Debug)] /// The client stub that makes RPC calls to the server. All request methods return /// [Futures](std::future::Future). - #vis struct #client_ident(tarpc::client::Channel<#request_ident, #response_ident>); + #vis struct #client_ident< + Stub = tarpc::client::Channel<#request_ident, #response_ident> + >(Stub); } } @@ -719,6 +736,17 @@ impl<'a> ServiceGenerator<'a> { dispatch: new_client.dispatch, } } + } + + impl From for #client_ident + where Stub: tarpc::client::stub::Stub< + Req = #request_ident, + Resp = #response_ident> + { + /// Returns a new client stub that sends requests over the given transport. + fn from(stub: Stub) -> Self { + #client_ident(stub) + } } } @@ -741,7 +769,11 @@ impl<'a> ServiceGenerator<'a> { } = self; quote! { - impl #client_ident { + impl #client_ident + where Stub: tarpc::client::stub::Stub< + Req = #request_ident, + Resp = #response_ident> + { #( #[allow(unused)] #( #method_attrs )* diff --git a/tarpc/Cargo.toml b/tarpc/Cargo.toml index ff0f59b9b..24da20eeb 100644 --- a/tarpc/Cargo.toml +++ b/tarpc/Cargo.toml @@ -19,7 +19,7 @@ description = "An RPC framework for Rust with a focus on ease of use." [features] default = [] -serde1 = ["tarpc-plugins/serde1", "serde", "serde/derive"] +serde1 = ["tarpc-plugins/serde1", "serde", "serde/derive", "serde/rc"] tokio1 = ["tokio/rt"] serde-transport = ["serde1", "tokio1", "tokio-serde", "tokio-util/codec"] serde-transport-json = ["tokio-serde/json"] diff --git a/tarpc/src/client.rs b/tarpc/src/client.rs index fdf28f693..3d4b1ed6a 100644 --- a/tarpc/src/client.rs +++ b/tarpc/src/client.rs @@ -7,6 +7,7 @@ //! Provides a client that connects to a server and sends multiplexed requests. mod in_flight_requests; +pub mod stub; use crate::{ cancellations::{cancellations, CanceledRequests, RequestCancellation}, diff --git a/tarpc/src/client/stub.rs b/tarpc/src/client/stub.rs new file mode 100644 index 000000000..a8b72a20f --- /dev/null +++ b/tarpc/src/client/stub.rs @@ -0,0 +1,56 @@ +//! Provides a Stub trait, implemented by types that can call remote services. + +use crate::{ + client::{Channel, RpcError}, + context, +}; +use futures::prelude::*; + +pub mod load_balance; +pub mod retry; + +#[cfg(test)] +mod mock; + +/// A connection to a remote service. +/// Calls the service with requests of type `Req` and receives responses of type `Resp`. +pub trait Stub { + /// The service request type. + type Req; + + /// The service response type. + type Resp; + + /// The type of the future returned by `Stub::call`. + type RespFut<'a>: Future> + where + Self: 'a, + Self::Req: 'a, + Self::Resp: 'a; + + /// Calls a remote service. + fn call<'a>( + &'a self, + ctx: context::Context, + request_name: &'static str, + request: Self::Req, + ) -> Self::RespFut<'a>; +} + +impl Stub for Channel { + type Req = Req; + type Resp = Resp; + type RespFut<'a> = RespFut<'a, Req, Resp> + where Self: 'a; + + fn call<'a>( + &'a self, + ctx: context::Context, + request_name: &'static str, + request: Req, + ) -> Self::RespFut<'a> { + Self::call(self, ctx, request_name, request) + } +} + +type RespFut<'a, Req: 'a, Resp: 'a> = impl Future> + 'a; diff --git a/tarpc/src/client/stub/load_balance.rs b/tarpc/src/client/stub/load_balance.rs new file mode 100644 index 000000000..c9005a423 --- /dev/null +++ b/tarpc/src/client/stub/load_balance.rs @@ -0,0 +1,305 @@ +//! Provides load-balancing [Stubs](crate::client::stub::Stub). + +pub use consistent_hash::ConsistentHash; +pub use round_robin::RoundRobin; + +/// Provides a stub that load-balances with a simple round-robin strategy. +mod round_robin { + use crate::{ + client::{stub, RpcError}, + context, + }; + use cycle::AtomicCycle; + use futures::prelude::*; + + impl stub::Stub for RoundRobin + where + Stub: stub::Stub, + { + type Req = Stub::Req; + type Resp = Stub::Resp; + type RespFut<'a> = RespFut<'a, Stub> + where Self: 'a; + + fn call<'a>( + &'a self, + ctx: context::Context, + request_name: &'static str, + request: Self::Req, + ) -> Self::RespFut<'a> { + Self::call(self, ctx, request_name, request) + } + } + + type RespFut<'a, Stub: stub::Stub + 'a> = + impl Future> + 'a; + + /// A Stub that load-balances across backing stubs by round robin. + #[derive(Clone, Debug)] + pub struct RoundRobin { + stubs: AtomicCycle, + } + + impl RoundRobin + where + Stub: stub::Stub, + { + /// Returns a new RoundRobin stub. + pub fn new(stubs: Vec) -> Self { + Self { + stubs: AtomicCycle::new(stubs), + } + } + + async fn call( + &self, + ctx: context::Context, + request_name: &'static str, + request: Stub::Req, + ) -> Result { + let next = self.stubs.next(); + next.call(ctx, request_name, request).await + } + } + + mod cycle { + use std::sync::{ + atomic::{AtomicUsize, Ordering}, + Arc, + }; + + /// Cycles endlessly and atomically over a collection of elements of type T. + #[derive(Clone, Debug)] + pub struct AtomicCycle(Arc>); + + #[derive(Debug)] + struct State { + elements: Vec, + next: AtomicUsize, + } + + impl AtomicCycle { + pub fn new(elements: Vec) -> Self { + Self(Arc::new(State { + elements, + next: Default::default(), + })) + } + + pub fn next(&self) -> &T { + self.0.next() + } + } + + impl State { + pub fn next(&self) -> &T { + let next = self.next.fetch_add(1, Ordering::Relaxed); + &self.elements[next % self.elements.len()] + } + } + + #[test] + fn test_cycle() { + let cycle = AtomicCycle::new(vec![1, 2, 3]); + assert_eq!(cycle.next(), &1); + assert_eq!(cycle.next(), &2); + assert_eq!(cycle.next(), &3); + assert_eq!(cycle.next(), &1); + } + } +} + +/// Provides a stub that load-balances with a consistent hashing strategy. +/// +/// Each request is hashed, then mapped to a stub based on the hash. Equivalent requests will use +/// the same stub. +mod consistent_hash { + use crate::{ + client::{stub, RpcError}, + context, + }; + use futures::prelude::*; + use std::{ + collections::hash_map::RandomState, + hash::{BuildHasher, Hash, Hasher}, + num::TryFromIntError, + }; + + impl stub::Stub for ConsistentHash + where + Stub: stub::Stub, + Stub::Req: Hash, + { + type Req = Stub::Req; + type Resp = Stub::Resp; + type RespFut<'a> = RespFut<'a, Stub> + where Self: 'a; + + fn call<'a>( + &'a self, + ctx: context::Context, + request_name: &'static str, + request: Self::Req, + ) -> Self::RespFut<'a> { + Self::call(self, ctx, request_name, request) + } + } + + type RespFut<'a, Stub: stub::Stub + 'a> = + impl Future> + 'a; + + /// A Stub that load-balances across backing stubs by round robin. + #[derive(Clone, Debug)] + pub struct ConsistentHash { + stubs: Vec, + stubs_len: u64, + hasher: S, + } + + impl ConsistentHash + where + Stub: stub::Stub, + Stub::Req: Hash, + { + /// Returns a new RoundRobin stub. + /// Returns an err if the length of `stubs` overflows a u64. + pub fn new(stubs: Vec) -> Result { + Ok(Self { + stubs_len: stubs.len().try_into()?, + stubs, + hasher: RandomState::new(), + }) + } + } + + impl ConsistentHash + where + Stub: stub::Stub, + Stub::Req: Hash, + S: BuildHasher, + { + /// Returns a new RoundRobin stub. + /// Returns an err if the length of `stubs` overflows a u64. + pub fn with_hasher(stubs: Vec, hasher: S) -> Result { + Ok(Self { + stubs_len: stubs.len().try_into()?, + stubs, + hasher, + }) + } + + async fn call( + &self, + ctx: context::Context, + request_name: &'static str, + request: Stub::Req, + ) -> Result { + let index = usize::try_from(self.hash_request(&request) % self.stubs_len).expect( + "invariant broken: stubs_len is not larger than a usize, \ + so the hash modulo stubs_len should always fit in a usize", + ); + let next = &self.stubs[index]; + next.call(ctx, request_name, request).await + } + + fn hash_request(&self, req: &Stub::Req) -> u64 { + let mut hasher = self.hasher.build_hasher(); + req.hash(&mut hasher); + hasher.finish() + } + } + + #[cfg(test)] + mod tests { + use super::ConsistentHash; + use crate::{client::stub::mock::Mock, context}; + use std::{ + collections::HashMap, + hash::{BuildHasher, Hash, Hasher}, + rc::Rc, + }; + + #[tokio::test] + async fn test() -> anyhow::Result<()> { + let stub = ConsistentHash::with_hasher( + vec![ + // For easier reading of the assertions made in this test, each Mock's response + // value is equal to a hash value that should map to its index: 3 % 3 = 0, 1 % + // 3 = 1, etc. + Mock::new([('a', 3), ('b', 3), ('c', 3)]), + Mock::new([('a', 1), ('b', 1), ('c', 1)]), + Mock::new([('a', 2), ('b', 2), ('c', 2)]), + ], + FakeHasherBuilder::new([('a', 1), ('b', 2), ('c', 3)]), + )?; + + for _ in 0..2 { + let resp = stub.call(context::current(), "", 'a').await?; + assert_eq!(resp, 1); + + let resp = stub.call(context::current(), "", 'b').await?; + assert_eq!(resp, 2); + + let resp = stub.call(context::current(), "", 'c').await?; + assert_eq!(resp, 3); + } + + Ok(()) + } + + struct HashRecorder(Vec); + impl Hasher for HashRecorder { + fn write(&mut self, bytes: &[u8]) { + self.0 = Vec::from(bytes); + } + fn finish(&self) -> u64 { + 0 + } + } + + struct FakeHasherBuilder { + recorded_hashes: Rc, u64>>, + } + + struct FakeHasher { + recorded_hashes: Rc, u64>>, + output: u64, + } + + impl BuildHasher for FakeHasherBuilder { + type Hasher = FakeHasher; + + fn build_hasher(&self) -> Self::Hasher { + FakeHasher { + recorded_hashes: self.recorded_hashes.clone(), + output: 0, + } + } + } + + impl FakeHasherBuilder { + fn new(fake_hashes: [(T, u64); N]) -> Self { + let mut recorded_hashes = HashMap::new(); + for (to_hash, fake_hash) in fake_hashes { + let mut recorder = HashRecorder(vec![]); + to_hash.hash(&mut recorder); + recorded_hashes.insert(recorder.0, fake_hash); + } + Self { + recorded_hashes: Rc::new(recorded_hashes), + } + } + } + + impl Hasher for FakeHasher { + fn write(&mut self, bytes: &[u8]) { + if let Some(hash) = self.recorded_hashes.get(bytes) { + self.output = *hash; + } + } + fn finish(&self) -> u64 { + self.output + } + } + } +} diff --git a/tarpc/src/client/stub/mock.rs b/tarpc/src/client/stub/mock.rs new file mode 100644 index 000000000..99a54422f --- /dev/null +++ b/tarpc/src/client/stub/mock.rs @@ -0,0 +1,54 @@ +use crate::{ + client::{stub::Stub, RpcError}, + context, ServerError, +}; +use futures::future; +use std::{collections::HashMap, hash::Hash, io}; + +/// A mock stub that returns user-specified responses. +pub struct Mock { + responses: HashMap, +} + +impl Mock +where + Req: Eq + Hash, +{ + /// Returns a new mock, mocking the specified (request, response) pairs. + pub fn new(responses: [(Req, Resp); N]) -> Self { + Self { + responses: HashMap::from(responses), + } + } +} + +impl Stub for Mock +where + Req: Eq + Hash, + Resp: Clone, +{ + type Req = Req; + type Resp = Resp; + type RespFut<'a> = future::Ready> + where Self: 'a; + + fn call<'a>( + &'a self, + _: context::Context, + _: &'static str, + request: Self::Req, + ) -> Self::RespFut<'a> { + future::ready( + self.responses + .get(&request) + .cloned() + .map(Ok) + .unwrap_or_else(|| { + Err(RpcError::Server(ServerError { + kind: io::ErrorKind::NotFound, + detail: "mock (request, response) entry not found".into(), + })) + }), + ) + } +} diff --git a/tarpc/src/client/stub/retry.rs b/tarpc/src/client/stub/retry.rs new file mode 100644 index 000000000..46ad09685 --- /dev/null +++ b/tarpc/src/client/stub/retry.rs @@ -0,0 +1,75 @@ +//! Provides a stub that retries requests based on response contents.. + +use crate::{ + client::{stub, RpcError}, + context, +}; +use futures::prelude::*; +use std::sync::Arc; + +impl stub::Stub for Retry +where + Stub: stub::Stub>, + F: Fn(&Result, u32) -> bool, +{ + type Req = Req; + type Resp = Stub::Resp; + type RespFut<'a> = RespFut<'a, Stub, Self::Req, F> + where Self: 'a, + Self::Req: 'a; + + fn call<'a>( + &'a self, + ctx: context::Context, + request_name: &'static str, + request: Self::Req, + ) -> Self::RespFut<'a> { + Self::call(self, ctx, request_name, request) + } +} + +type RespFut<'a, Stub: stub::Stub + 'a, Req: 'a, F: 'a> = + impl Future> + 'a; + +/// A Stub that retries requests based on response contents. +/// Note: to use this stub with Serde serialization, the "rc" feature of Serde needs to be enabled. +#[derive(Clone, Debug)] +pub struct Retry { + should_retry: F, + stub: Stub, +} + +impl Retry +where + Stub: stub::Stub>, + F: Fn(&Result, u32) -> bool, +{ + /// Creates a new Retry stub that delegates calls to the underlying `stub`. + pub fn new(stub: Stub, should_retry: F) -> Self { + Self { stub, should_retry } + } + + async fn call<'a, 'b>( + &'a self, + ctx: context::Context, + request_name: &'static str, + request: Req, + ) -> Result + where + Req: 'b, + { + let request = Arc::new(request); + for i in 1.. { + let result = self + .stub + .call(ctx, request_name, Arc::clone(&request)) + .await; + if (self.should_retry)(&result, i) { + tracing::trace!("Retrying on attempt {i}"); + continue; + } + return result; + } + unreachable!("Wow, that was a lot of attempts!"); + } +} diff --git a/tarpc/src/lib.rs b/tarpc/src/lib.rs index 891efdd9c..280da694e 100644 --- a/tarpc/src/lib.rs +++ b/tarpc/src/lib.rs @@ -200,6 +200,7 @@ //! //! Use `cargo doc` as you normally would to see the documentation created for all //! items expanded by a `service!` invocation. +#![feature(type_alias_impl_trait)] #![deny(missing_docs)] #![allow(clippy::type_complexity)] #![cfg_attr(docsrs, feature(doc_cfg))] @@ -383,6 +384,13 @@ pub struct ServerError { pub detail: String, } +impl ServerError { + /// Returns a new server error with `kind` and `detail`. + pub fn new(kind: io::ErrorKind, detail: String) -> ServerError { + Self { kind, detail } + } +} + impl Request { /// Returns the deadline for this request. pub fn deadline(&self) -> &SystemTime { diff --git a/tarpc/src/server.rs b/tarpc/src/server.rs index 745a3c678..24df8dc60 100644 --- a/tarpc/src/server.rs +++ b/tarpc/src/server.rs @@ -9,7 +9,7 @@ use crate::{ cancellations::{cancellations, CanceledRequests, RequestCancellation}, context::{self, SpanExt}, - trace, ClientMessage, Request, Response, Transport, + trace, ClientMessage, Request, Response, ServerError, Transport, }; use ::tokio::sync::mpsc; use futures::{ @@ -583,6 +583,10 @@ where span, response_guard, }| { + { + let _entered = span.enter(); + tracing::info!("BeginRequest"); + } InFlightRequest { request, abort_registration, @@ -701,6 +705,29 @@ impl InFlightRequest { &self.request } + /// Respond without executing a service function. Useful for early aborts (e.g. for throttling). + pub async fn respond(self, response: Result) { + let Self { + response_tx, + response_guard, + request: Request { id: request_id, .. }, + span, + .. + } = self; + let _entered = span.enter(); + tracing::info!("CompleteRequest"); + let response = Response { + request_id, + message: response, + }; + let _ = response_tx.send(response).await; + tracing::info!("BufferResponse"); + // Request processing has completed, meaning either the channel canceled the request or + // a request was sent back to the channel. Either way, the channel will clean up the + // request data, so the request does not need to be canceled. + mem::forget(response_guard); + } + /// Returns a [future](Future) that executes the request using the given [service /// function](Serve). The service function's output is automatically sent back to the [Channel] /// that yielded this request. The request will be executed in the scope of this request's @@ -738,7 +765,6 @@ impl InFlightRequest { span.record("otel.name", &method.unwrap_or("")); let _ = Abortable::new( async move { - tracing::info!("BeginRequest"); let response = serve.serve(context, message).await; tracing::info!("CompleteRequest"); let response = Response { From 55511084e8afc1939d6105fe0e34397e2063b30f Mon Sep 17 00:00:00 2001 From: Tim Kuehn Date: Mon, 7 Nov 2022 18:37:58 -0800 Subject: [PATCH 02/30] Add request hooks to the Serve trait. This allows plugging in horizontal functionality, such as authorization, throttling, or latency recording, that should run before and/or after execution of every request, regardless of the request type. The tracing example is updated to show off both client stubs as well as server hooks. As part of this change, there were some changes to the Serve trait: 1. Serve's output type is now a Result.. Serve previously did not allow returning ServerErrors, which prevented using Serve for horizontal functionality like throttling or auth. Now, Serve's output type is Result, making Serve a more natural integration point for horizontal capabilities. 2. Serve's generic Request type changed to an associated type. The primary benefit of the generic type is that it allows one type to impl a trait multiple times (for example, u64 impls TryFrom, TryFrom ServiceGenerator<'a> { } = self; quote! { - impl tarpc::server::Serve<#request_ident> for #server_ident + impl tarpc::server::Serve for #server_ident where S: #service_ident { + type Req = #request_ident; type Resp = #response_ident; type Fut = #response_fut_ident; @@ -670,10 +671,10 @@ impl<'a> ServiceGenerator<'a> { quote! { impl std::future::Future for #response_fut_ident { - type Output = #response_ident; + type Output = Result<#response_ident, tarpc::ServerError>; fn poll(self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>) - -> std::task::Poll<#response_ident> + -> std::task::Poll> { unsafe { match std::pin::Pin::get_unchecked_mut(self) { @@ -682,6 +683,7 @@ impl<'a> ServiceGenerator<'a> { std::pin::Pin::new_unchecked(resp) .poll(cx) .map(#response_ident::#camel_case_idents) + .map(Ok), )* } } diff --git a/tarpc/examples/tracing.rs b/tarpc/examples/tracing.rs index 5b4b8fd5b..c2afb3bf1 100644 --- a/tarpc/examples/tracing.rs +++ b/tarpc/examples/tracing.rs @@ -4,13 +4,32 @@ // license that can be found in the LICENSE file or at // https://opensource.org/licenses/MIT. -use crate::{add::Add as AddService, double::Double as DoubleService}; +#![feature(type_alias_impl_trait)] + +use crate::{ + add::{Add as AddService, AddStub}, + double::Double as DoubleService, +}; use futures::{future, prelude::*}; +use std::{ + io, + sync::{ + atomic::{AtomicBool, Ordering}, + Arc, + }, +}; use tarpc::{ - client, context, - server::{incoming::Incoming, BaseChannel}, + client::{ + self, + stub::{load_balance, retry}, + RpcError, + }, + context, serde_transport, + server::{incoming::Incoming, BaseChannel, Serve}, tokio_serde::formats::Json, + ClientMessage, Response, ServerError, Transport, }; +use tokio::net::TcpStream; use tracing_subscriber::prelude::*; pub mod add { @@ -40,12 +59,16 @@ impl AddService for AddServer { } #[derive(Clone)] -struct DoubleServer { - add_client: add::AddClient, +struct DoubleServer { + add_client: add::AddClient, } #[tarpc::server] -impl DoubleService for DoubleServer { +impl DoubleService for DoubleServer +where + Stub: AddStub + Clone + Send + Sync + 'static, + for<'a> Stub::RespFut<'a>: Send, +{ async fn double(self, _: context::Context, x: i32) -> Result { self.add_client .add(context::current(), x, x) @@ -70,22 +93,79 @@ fn init_tracing(service_name: &str) -> anyhow::Result<()> { Ok(()) } +async fn listen_on_random_port() -> anyhow::Result<( + impl Stream>>, + std::net::SocketAddr, +)> +where + Item: for<'de> serde::Deserialize<'de>, + SinkItem: serde::Serialize, +{ + let listener = tarpc::serde_transport::tcp::listen("localhost:0", Json::default) + .await? + .filter_map(|r| future::ready(r.ok())) + .take(1); + let addr = listener.get_ref().get_ref().local_addr(); + Ok((listener, addr)) +} + +fn make_stub( + backends: [impl Transport>, Response> + Send + Sync + 'static; N], +) -> retry::Retry< + impl Fn(&Result, u32) -> bool + Clone, + load_balance::RoundRobin, Resp>>, +> +where + Req: Send + Sync + 'static, + Resp: Send + Sync + 'static, +{ + let stub = load_balance::RoundRobin::new( + backends + .into_iter() + .map(|transport| tarpc::client::new(client::Config::default(), transport).spawn()) + .collect(), + ); + let stub = retry::Retry::new(stub, |resp, attempts| { + if let Err(e) = resp { + tracing::warn!("Got an error: {e:?}"); + attempts < 3 + } else { + false + } + }); + stub +} + #[tokio::main] async fn main() -> anyhow::Result<()> { init_tracing("tarpc_tracing_example")?; - let add_listener = tarpc::serde_transport::tcp::listen("localhost:0", Json::default) - .await? - .filter_map(|r| future::ready(r.ok())); - let addr = add_listener.get_ref().local_addr(); - let add_server = add_listener + let (add_listener1, addr1) = listen_on_random_port().await?; + let (add_listener2, addr2) = listen_on_random_port().await?; + let something_bad_happened = Arc::new(AtomicBool::new(false)); + let server = AddServer.serve().before(move |_: &mut _, _: &_| { + let something_bad_happened = something_bad_happened.clone(); + async move { + if something_bad_happened.fetch_xor(true, Ordering::Relaxed) { + Err(ServerError::new( + io::ErrorKind::NotFound, + "Gamma Ray!".into(), + )) + } else { + Ok(()) + } + } + }); + let add_server = add_listener1 + .chain(add_listener2) .map(BaseChannel::with_defaults) - .take(1) - .execute(AddServer.serve()); + .execute(server); tokio::spawn(add_server); - let to_add_server = tarpc::serde_transport::tcp::connect(addr, Json::default).await?; - let add_client = add::AddClient::new(client::Config::default(), to_add_server).spawn(); + let add_client = add::AddClient::from(make_stub([ + tarpc::serde_transport::tcp::connect(addr1, Json::default).await?, + tarpc::serde_transport::tcp::connect(addr2, Json::default).await?, + ])); let double_listener = tarpc::serde_transport::tcp::listen("localhost:0", Json::default) .await? diff --git a/tarpc/src/server.rs b/tarpc/src/server.rs index 24df8dc60..f5c5151ee 100644 --- a/tarpc/src/server.rs +++ b/tarpc/src/server.rs @@ -32,6 +32,7 @@ use std::{ use tracing::{info_span, instrument::Instrument, Span}; mod in_flight_requests; +pub mod request_hook; #[cfg(test)] mod testing; @@ -46,6 +47,10 @@ pub mod incoming; #[cfg_attr(docsrs, doc(cfg(feature = "tokio1")))] pub mod tokio; +use request_hook::{ + AfterRequest, AfterRequestHook, BeforeAndAfterRequestHook, BeforeRequest, BeforeRequestHook, +}; + /// Settings that control the behavior of [channels](Channel). #[derive(Clone, Debug)] pub struct Config { @@ -74,32 +79,212 @@ impl Config { } /// Equivalent to a `FnOnce(Req) -> impl Future`. -pub trait Serve { +pub trait Serve { + /// Type of request. + type Req; + /// Type of response. type Resp; /// Type of response future. - type Fut: Future; + type Fut: Future>; + + /// Responds to a single request. + fn serve(self, ctx: context::Context, req: Self::Req) -> Self::Fut; /// Extracts a method name from the request. - fn method(&self, _request: &Req) -> Option<&'static str> { + fn method(&self, _request: &Self::Req) -> Option<&'static str> { None } - /// Responds to a single request. - fn serve(self, ctx: context::Context, req: Req) -> Self::Fut; + /// Runs a hook before execution of the request. + /// + /// If the hook returns an error, the request will not be executed and the error will be + /// returned instead. + /// + /// The hook can also modify the request context. This could be used, for example, to enforce a + /// maximum deadline on all requests. + /// + /// Any type that implements [`BeforeRequest`] can be used as the hook. Types that implement + /// `FnMut(&mut Context, &RequestType) -> impl Future>` can + /// also be used. + /// + /// # Example + /// + /// ```rust + /// use futures::{executor::block_on, future}; + /// use tarpc::{context, ServerError, server::{Serve, serve}}; + /// use std::io; + /// + /// let serve = serve(|_ctx, i| async move { Ok(i + 1) }) + /// .before(|_ctx: &mut context::Context, req: &i32| { + /// future::ready( + /// if *req == 1 { + /// Err(ServerError::new( + /// io::ErrorKind::Other, + /// format!("I don't like {req}"))) + /// } else { + /// Ok(()) + /// }) + /// }); + /// let response = serve.serve(context::current(), 1); + /// assert!(block_on(response).is_err()); + /// ``` + fn before(self, hook: Hook) -> BeforeRequestHook + where + Hook: BeforeRequest, + Self: Sized, + { + BeforeRequestHook::new(self, hook) + } + + /// Runs a hook after completion of a request. + /// + /// The hook can modify the request context and the response. + /// + /// Any type that implements [`AfterRequest`] can be used as the hook. Types that implement + /// `FnMut(&mut Context, &mut Result) -> impl Future` + /// can also be used. + /// + /// # Example + /// + /// ```rust + /// use futures::{executor::block_on, future}; + /// use tarpc::{context, ServerError, server::{Serve, serve}}; + /// use std::io; + /// + /// let serve = serve( + /// |_ctx, i| async move { + /// if i == 1 { + /// Err(ServerError::new( + /// io::ErrorKind::Other, + /// format!("{i} is the loneliest number"))) + /// } else { + /// Ok(i + 1) + /// } + /// }) + /// .after(|_ctx: &mut context::Context, resp: &mut Result| { + /// if let Err(e) = resp { + /// eprintln!("server error: {e:?}"); + /// } + /// future::ready(()) + /// }); + /// + /// let response = serve.serve(context::current(), 1); + /// assert!(block_on(response).is_err()); + /// ``` + fn after(self, hook: Hook) -> AfterRequestHook + where + Hook: AfterRequest, + Self: Sized, + { + AfterRequestHook::new(self, hook) + } + + /// Runs a hook before and after execution of the request. + /// + /// If the hook returns an error, the request will not be executed and the error will be + /// returned instead. + /// + /// The hook can also modify the request context and the response. This could be used, for + /// example, to enforce a maximum deadline on all requests. + /// + /// # Example + /// + /// ```rust + /// use futures::{executor::block_on, future}; + /// use tarpc::{ + /// context, ServerError, server::{Serve, serve, request_hook::{BeforeRequest, AfterRequest}} + /// }; + /// use std::{io, time::Instant}; + /// + /// struct PrintLatency(Instant); + /// + /// impl BeforeRequest for PrintLatency { + /// type Fut<'a> = future::Ready> where Self: 'a, Req: 'a; + /// + /// fn before<'a>(&'a mut self, _: &'a mut context::Context, _: &'a Req) -> Self::Fut<'a> { + /// self.0 = Instant::now(); + /// future::ready(Ok(())) + /// } + /// } + /// + /// impl AfterRequest for PrintLatency { + /// type Fut<'a> = future::Ready<()> where Self:'a, Resp:'a; + /// + /// fn after<'a>( + /// &'a mut self, + /// _: &'a mut context::Context, + /// _: &'a mut Result, + /// ) -> Self::Fut<'a> { + /// tracing::info!("Elapsed: {:?}", self.0.elapsed()); + /// future::ready(()) + /// } + /// } + /// + /// let serve = serve(|_ctx, i| async move { + /// Ok(i + 1) + /// }).before_and_after(PrintLatency(Instant::now())); + /// let response = serve.serve(context::current(), 1); + /// assert!(block_on(response).is_ok()); + /// ``` + fn before_and_after( + self, + hook: Hook, + ) -> BeforeAndAfterRequestHook + where + Hook: BeforeRequest + AfterRequest, + Self: Sized, + { + BeforeAndAfterRequestHook::new(self, hook) + } +} + +/// A Serve wrapper around a Fn. +#[derive(Debug)] +pub struct ServeFn { + f: F, + data: PhantomData Resp>, +} + +impl Clone for ServeFn +where + F: Clone, +{ + fn clone(&self) -> Self { + Self { + f: self.f.clone(), + data: PhantomData, + } + } +} + +impl Copy for ServeFn where F: Copy {} + +/// Creates a [`Serve`] wrapper around a `FnOnce(context::Context, Req) -> impl Future>`. +pub fn serve(f: F) -> ServeFn +where + F: FnOnce(context::Context, Req) -> Fut, + Fut: Future>, +{ + ServeFn { + f, + data: PhantomData, + } } -impl Serve for F +impl Serve for ServeFn where F: FnOnce(context::Context, Req) -> Fut, - Fut: Future, + Fut: Future>, { + type Req = Req; type Resp = Resp; type Fut = Fut; fn serve(self, ctx: context::Context, req: Req) -> Self::Fut { - self(ctx, req) + (self.f)(ctx, req) } } @@ -127,7 +312,7 @@ pub struct BaseChannel { /// Holds data necessary to clean up in-flight requests. in_flight_requests: InFlightRequests, /// Types the request and response. - ghost: PhantomData<(Req, Resp)>, + ghost: PhantomData<(fn() -> Req, fn(Resp))>, } impl BaseChannel @@ -313,6 +498,34 @@ where /// This is a terminal operation. After calling `requests`, the channel cannot be retrieved, /// and the only way to complete requests is via [`Requests::execute`] or /// [`InFlightRequest::execute`]. + /// + /// # Example + /// + /// ```rust + /// use tarpc::{ + /// context, + /// client::{self, NewClient}, + /// server::{self, BaseChannel, Channel, serve}, + /// transport, + /// }; + /// use futures::prelude::*; + /// + /// #[tokio::main] + /// async fn main() { + /// let (tx, rx) = transport::channel::unbounded(); + /// let server = BaseChannel::new(server::Config::default(), rx); + /// let NewClient { client, dispatch } = client::new(client::Config::default(), tx); + /// tokio::spawn(dispatch); + /// + /// let mut requests = server.requests(); + /// tokio::spawn(async move { + /// while let Some(Ok(request)) = requests.next().await { + /// tokio::spawn(request.execute(serve(|_, i| async move { Ok(i + 1) }))); + /// } + /// }); + /// assert_eq!(client.call(context::current(), "AddOne", 1).await.unwrap(), 2); + /// } + /// ``` fn requests(self) -> Requests where Self: Sized, @@ -329,12 +542,28 @@ where /// Runs the channel until completion by executing all requests using the given service /// function. Request handlers are run concurrently by [spawning](::tokio::spawn) on tokio's /// default executor. + /// + /// # Example + /// + /// ```rust + /// use tarpc::{context, client, server::{self, BaseChannel, Channel, serve}, transport}; + /// use futures::prelude::*; + /// + /// #[tokio::main] + /// async fn main() { + /// let (tx, rx) = transport::channel::unbounded(); + /// let client = client::new(client::Config::default(), tx).spawn(); + /// let channel = BaseChannel::new(server::Config::default(), rx); + /// tokio::spawn(channel.execute(serve(|_, i| async move { Ok(i + 1) }))); + /// assert_eq!(client.call(context::current(), "AddOne", 1).await.unwrap(), 2); + /// } + /// ``` #[cfg(feature = "tokio1")] #[cfg_attr(docsrs, doc(cfg(feature = "tokio1")))] fn execute(self, serve: S) -> self::tokio::TokioChannelExecutor, S> where Self: Sized, - S: Serve + Send + 'static, + S: Serve + Send + 'static, S::Fut: Send, Self::Req: Send + 'static, Self::Resp: Send + 'static, @@ -705,29 +934,6 @@ impl InFlightRequest { &self.request } - /// Respond without executing a service function. Useful for early aborts (e.g. for throttling). - pub async fn respond(self, response: Result) { - let Self { - response_tx, - response_guard, - request: Request { id: request_id, .. }, - span, - .. - } = self; - let _entered = span.enter(); - tracing::info!("CompleteRequest"); - let response = Response { - request_id, - message: response, - }; - let _ = response_tx.send(response).await; - tracing::info!("BufferResponse"); - // Request processing has completed, meaning either the channel canceled the request or - // a request was sent back to the channel. Either way, the channel will clean up the - // request data, so the request does not need to be canceled. - mem::forget(response_guard); - } - /// Returns a [future](Future) that executes the request using the given [service /// function](Serve). The service function's output is automatically sent back to the [Channel] /// that yielded this request. The request will be executed in the scope of this request's @@ -742,9 +948,39 @@ impl InFlightRequest { /// /// If the returned Future is dropped before completion, a cancellation message will be sent to /// the Channel to clean up associated request state. + /// + /// # Example + /// + /// ```rust + /// use tarpc::{ + /// context, + /// client::{self, NewClient}, + /// server::{self, BaseChannel, Channel, serve}, + /// transport, + /// }; + /// use futures::prelude::*; + /// + /// #[tokio::main] + /// async fn main() { + /// let (tx, rx) = transport::channel::unbounded(); + /// let server = BaseChannel::new(server::Config::default(), rx); + /// let NewClient { client, dispatch } = client::new(client::Config::default(), tx); + /// tokio::spawn(dispatch); + /// + /// tokio::spawn(async move { + /// let mut requests = server.requests(); + /// while let Some(Ok(in_flight_request)) = requests.next().await { + /// in_flight_request.execute(serve(|_, i| async move { Ok(i + 1) })).await; + /// } + /// + /// }); + /// assert_eq!(client.call(context::current(), "AddOne", 1).await.unwrap(), 2); + /// } + /// ``` + /// pub async fn execute(self, serve: S) where - S: Serve, + S: Serve, { let Self { response_tx, @@ -765,11 +1001,11 @@ impl InFlightRequest { span.record("otel.name", &method.unwrap_or("")); let _ = Abortable::new( async move { - let response = serve.serve(context, message).await; + let message = serve.serve(context, message).await; tracing::info!("CompleteRequest"); let response = Response { request_id, - message: Ok(response), + message, }; let _ = response_tx.send(response).await; tracing::info!("BufferResponse"); @@ -813,11 +1049,14 @@ where #[cfg(test)] mod tests { - use super::{in_flight_requests::AlreadyExistsError, BaseChannel, Channel, Config, Requests}; + use super::{ + in_flight_requests::AlreadyExistsError, serve, AfterRequest, BaseChannel, BeforeRequest, + Channel, Config, Requests, Serve, + }; use crate::{ context, trace, transport::channel::{self, UnboundedChannel}, - ClientMessage, Request, Response, + ClientMessage, Request, Response, ServerError, }; use assert_matches::assert_matches; use futures::{ @@ -826,7 +1065,12 @@ mod tests { Future, }; use futures_test::task::noop_context; - use std::{pin::Pin, task::Poll}; + use std::{ + io, + pin::Pin, + task::Poll, + time::{Duration, Instant, SystemTime}, + }; fn test_channel() -> ( Pin, Response>>>>, @@ -887,6 +1131,101 @@ mod tests { Abortable::new(pending(), abort_registration) } + #[tokio::test] + async fn test_serve() { + let serve = serve(|_, i| async move { Ok(i) }); + assert_matches!(serve.serve(context::current(), 7).await, Ok(7)); + } + + #[tokio::test] + async fn serve_before_mutates_context() -> anyhow::Result<()> { + struct SetDeadline(SystemTime); + type SetDeadlineFut<'a, Req: 'a> = impl Future> + 'a; + impl BeforeRequest for SetDeadline { + type Fut<'a> = SetDeadlineFut<'a, Req> where Self: 'a, Req: 'a; + fn before<'a>( + &'a mut self, + ctx: &'a mut context::Context, + _: &'a Req, + ) -> Self::Fut<'a> { + async move { + ctx.deadline = self.0; + Ok(()) + } + } + } + + let some_time = SystemTime::UNIX_EPOCH + Duration::from_secs(37); + let some_other_time = SystemTime::UNIX_EPOCH + Duration::from_secs(83); + + let serve = serve(move |ctx: context::Context, i| async move { + assert_eq!(ctx.deadline, some_time); + Ok(i) + }); + let deadline_hook = serve.before(SetDeadline(some_time)); + let mut ctx = context::current(); + ctx.deadline = some_other_time; + deadline_hook.serve(ctx, 7).await?; + Ok(()) + } + + #[tokio::test] + async fn serve_before_and_after() -> anyhow::Result<()> { + let _ = tracing_subscriber::fmt::try_init(); + + struct PrintLatency { + start: Instant, + } + impl PrintLatency { + fn new() -> Self { + Self { + start: Instant::now(), + } + } + } + type StartFut<'a, Req: 'a> = impl Future> + 'a; + type EndFut<'a, Resp: 'a> = impl Future + 'a; + impl BeforeRequest for PrintLatency { + type Fut<'a> = StartFut<'a, Req> where Self: 'a, Req: 'a; + fn before<'a>(&'a mut self, _: &'a mut context::Context, _: &'a Req) -> Self::Fut<'a> { + async move { + self.start = Instant::now(); + Ok(()) + } + } + } + impl AfterRequest for PrintLatency { + type Fut<'a> = EndFut<'a, Resp> where Self: 'a, Resp: 'a; + fn after<'a>( + &'a mut self, + _: &'a mut context::Context, + _: &'a mut Result, + ) -> Self::Fut<'a> { + async move { + tracing::info!("Elapsed: {:?}", self.start.elapsed()); + } + } + } + + let serve = serve(move |_: context::Context, i| async move { Ok(i) }); + serve + .before_and_after(PrintLatency::new()) + .serve(context::current(), 7) + .await?; + Ok(()) + } + + #[tokio::test] + async fn serve_before_error_aborts_request() -> anyhow::Result<()> { + let serve = serve(|_, _| async { panic!("Shouldn't get here") }); + let deadline_hook = serve.before(|_: &mut context::Context, _: &i32| async { + Err(ServerError::new(io::ErrorKind::Other, "oops".into())) + }); + let resp: Result = deadline_hook.serve(context::current(), 7).await; + assert_matches!(resp, Err(_)); + Ok(()) + } + #[tokio::test] async fn base_channel_start_send_duplicate_request_returns_error() { let (mut channel, _tx) = test_channel::<(), ()>(); @@ -1087,7 +1426,7 @@ mod tests { Poll::Ready(Some(Ok(request))) => request, result => panic!("Unexpected result: {:?}", result), }; - request.execute(|_, _| async {}).await; + request.execute(serve(|_, _| async { Ok(()) })).await; assert!(requests .as_mut() .channel_pin_mut() diff --git a/tarpc/src/server/incoming.rs b/tarpc/src/server/incoming.rs index 445fc3e89..931e87669 100644 --- a/tarpc/src/server/incoming.rs +++ b/tarpc/src/server/incoming.rs @@ -35,7 +35,7 @@ where #[cfg_attr(docsrs, doc(cfg(feature = "tokio1")))] fn execute(self, serve: S) -> TokioServerExecutor where - S: Serve, + S: Serve, { TokioServerExecutor::new(self, serve) } diff --git a/tarpc/src/server/request_hook.rs b/tarpc/src/server/request_hook.rs new file mode 100644 index 000000000..ef23d73b4 --- /dev/null +++ b/tarpc/src/server/request_hook.rs @@ -0,0 +1,22 @@ +// Copyright 2022 Google LLC +// +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file or at +// https://opensource.org/licenses/MIT. + +//! Hooks for horizontal functionality that can run either before or after a request is executed. + +/// A request hook that runs before a request is executed. +mod before; + +/// A request hook that runs after a request is completed. +mod after; + +/// A request hook that runs both before a request is executed and after it is completed. +mod before_and_after; + +pub use { + after::{AfterRequest, AfterRequestHook}, + before::{BeforeRequest, BeforeRequestHook}, + before_and_after::BeforeAndAfterRequestHook, +}; diff --git a/tarpc/src/server/request_hook/after.rs b/tarpc/src/server/request_hook/after.rs new file mode 100644 index 000000000..a3803bade --- /dev/null +++ b/tarpc/src/server/request_hook/after.rs @@ -0,0 +1,89 @@ +// Copyright 2022 Google LLC +// +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file or at +// https://opensource.org/licenses/MIT. + +//! Provides a hook that runs after request execution. + +use crate::{context, server::Serve, ServerError}; +use futures::prelude::*; + +/// A hook that runs after request execution. +pub trait AfterRequest { + /// The type of future returned by the hook. + type Fut<'a>: Future + where + Self: 'a, + Resp: 'a; + + /// The function that is called after request execution. + /// + /// The hook can modify the request context and the response. + fn after<'a>( + &'a mut self, + ctx: &'a mut context::Context, + resp: &'a mut Result, + ) -> Self::Fut<'a>; +} + +impl AfterRequest for F +where + F: FnMut(&mut context::Context, &mut Result) -> Fut, + Fut: Future, +{ + type Fut<'a> = Fut where Self: 'a, Resp: 'a; + + fn after<'a>( + &'a mut self, + ctx: &'a mut context::Context, + resp: &'a mut Result, + ) -> Self::Fut<'a> { + self(ctx, resp) + } +} + +/// A Service function that runs a hook after request execution. +pub struct AfterRequestHook { + serve: Serv, + hook: Hook, +} + +impl AfterRequestHook { + pub(crate) fn new(serve: Serv, hook: Hook) -> Self { + Self { serve, hook } + } +} + +impl Clone for AfterRequestHook { + fn clone(&self) -> Self { + Self { + serve: self.serve.clone(), + hook: self.hook.clone(), + } + } +} + +impl Serve for AfterRequestHook +where + Serv: Serve, + Hook: AfterRequest, +{ + type Req = Serv::Req; + type Resp = Serv::Resp; + type Fut = AfterRequestHookFut; + + fn serve(self, mut ctx: context::Context, req: Serv::Req) -> Self::Fut { + async move { + let AfterRequestHook { + serve, mut hook, .. + } = self; + let mut resp = serve.serve(ctx, req).await; + hook.after(&mut ctx, &mut resp).await; + resp + } + } +} + +type AfterRequestHookFut> = + impl Future>; diff --git a/tarpc/src/server/request_hook/before.rs b/tarpc/src/server/request_hook/before.rs new file mode 100644 index 000000000..38ad54d01 --- /dev/null +++ b/tarpc/src/server/request_hook/before.rs @@ -0,0 +1,84 @@ +// Copyright 2022 Google LLC +// +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file or at +// https://opensource.org/licenses/MIT. + +//! Provides a hook that runs before request execution. + +use crate::{context, server::Serve, ServerError}; +use futures::prelude::*; + +/// A hook that runs before request execution. +pub trait BeforeRequest { + /// The type of future returned by the hook. + type Fut<'a>: Future> + where + Self: 'a, + Req: 'a; + + /// The function that is called before request execution. + /// + /// If this function returns an error, the request will not be executed and the error will be + /// returned instead. + /// + /// This function can also modify the request context. This could be used, for example, to + /// enforce a maximum deadline on all requests. + fn before<'a>(&'a mut self, ctx: &'a mut context::Context, req: &'a Req) -> Self::Fut<'a>; +} + +impl BeforeRequest for F +where + F: FnMut(&mut context::Context, &Req) -> Fut, + Fut: Future>, +{ + type Fut<'a> = Fut where Self: 'a, Req: 'a; + + fn before<'a>(&'a mut self, ctx: &'a mut context::Context, req: &'a Req) -> Self::Fut<'a> { + self(ctx, req) + } +} + +/// A Service function that runs a hook before request execution. +pub struct BeforeRequestHook { + serve: Serv, + hook: Hook, +} + +impl BeforeRequestHook { + pub(crate) fn new(serve: Serv, hook: Hook) -> Self { + Self { serve, hook } + } +} + +impl Clone for BeforeRequestHook { + fn clone(&self) -> Self { + Self { + serve: self.serve.clone(), + hook: self.hook.clone(), + } + } +} + +impl Serve for BeforeRequestHook +where + Serv: Serve, + Hook: BeforeRequest, +{ + type Req = Serv::Req; + type Resp = Serv::Resp; + type Fut = BeforeRequestHookFut; + + fn serve(self, mut ctx: context::Context, req: Self::Req) -> Self::Fut { + let BeforeRequestHook { + serve, mut hook, .. + } = self; + async move { + hook.before(&mut ctx, &req).await?; + serve.serve(ctx, req).await + } + } +} + +type BeforeRequestHookFut> = + impl Future>; diff --git a/tarpc/src/server/request_hook/before_and_after.rs b/tarpc/src/server/request_hook/before_and_after.rs new file mode 100644 index 000000000..ca42460bc --- /dev/null +++ b/tarpc/src/server/request_hook/before_and_after.rs @@ -0,0 +1,70 @@ +// Copyright 2022 Google LLC +// +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file or at +// https://opensource.org/licenses/MIT. + +//! Provides a hook that runs both before and after request execution. + +use super::{after::AfterRequest, before::BeforeRequest}; +use crate::{context, server::Serve, ServerError}; +use futures::prelude::*; +use std::marker::PhantomData; + +/// A Service function that runs a hook both before and after request execution. +pub struct BeforeAndAfterRequestHook { + serve: Serv, + hook: Hook, + fns: PhantomData<(fn(Req), fn(Resp))>, +} + +impl BeforeAndAfterRequestHook { + pub(crate) fn new(serve: Serv, hook: Hook) -> Self { + Self { + serve, + hook, + fns: PhantomData, + } + } +} + +impl Clone + for BeforeAndAfterRequestHook +{ + fn clone(&self) -> Self { + Self { + serve: self.serve.clone(), + hook: self.hook.clone(), + fns: PhantomData, + } + } +} + +impl Serve for BeforeAndAfterRequestHook +where + Serv: Serve, + Hook: BeforeRequest + AfterRequest, +{ + type Req = Req; + type Resp = Resp; + type Fut = BeforeAndAfterRequestHookFut; + + fn serve(self, mut ctx: context::Context, req: Req) -> Self::Fut { + async move { + let BeforeAndAfterRequestHook { + serve, mut hook, .. + } = self; + hook.before(&mut ctx, &req).await?; + let mut resp = serve.serve(ctx, req).await; + hook.after(&mut ctx, &mut resp).await; + resp + } + } +} + +type BeforeAndAfterRequestHookFut< + Req, + Resp, + Serv: Serve, + Hook: BeforeRequest + AfterRequest, +> = impl Future>; diff --git a/tarpc/src/server/tokio.rs b/tarpc/src/server/tokio.rs index a44e8469e..e9ad84221 100644 --- a/tarpc/src/server/tokio.rs +++ b/tarpc/src/server/tokio.rs @@ -55,9 +55,25 @@ where { /// Executes all requests using the given service function. Requests are handled concurrently /// by [spawning](::tokio::spawn) each handler on tokio's default executor. + /// + /// # Example + /// + /// ```rust + /// use tarpc::{context, client, server::{self, BaseChannel, Channel, serve}, transport}; + /// use futures::prelude::*; + /// + /// #[tokio::main] + /// async fn main() { + /// let (tx, rx) = transport::channel::unbounded(); + /// let requests = BaseChannel::new(server::Config::default(), rx).requests(); + /// let client = client::new(client::Config::default(), tx).spawn(); + /// tokio::spawn(requests.execute(serve(|_, i| async move { Ok(i + 1) }))); + /// assert_eq!(client.call(context::current(), "AddOne", 1).await.unwrap(), 2); + /// } + /// ``` pub fn execute(self, serve: S) -> TokioChannelExecutor where - S: Serve + Send + 'static, + S: Serve + Send + 'static, { TokioChannelExecutor { inner: self, serve } } @@ -69,7 +85,7 @@ where C: Channel + Send + 'static, C::Req: Send + 'static, C::Resp: Send + 'static, - Se: Serve + Send + 'static + Clone, + Se: Serve + Send + 'static + Clone, Se::Fut: Send, { type Output = (); @@ -88,7 +104,7 @@ where C: Channel + 'static, C::Req: Send + 'static, C::Resp: Send + 'static, - S: Serve + Send + 'static + Clone, + S: Serve + Send + 'static + Clone, S::Fut: Send, { type Output = (); diff --git a/tarpc/src/transport/channel.rs b/tarpc/src/transport/channel.rs index 529ae8f58..7f3035d14 100644 --- a/tarpc/src/transport/channel.rs +++ b/tarpc/src/transport/channel.rs @@ -150,12 +150,14 @@ impl Sink for Channel { #[cfg(feature = "tokio1")] mod tests { use crate::{ - client, context, - server::{incoming::Incoming, BaseChannel}, + client::{self, RpcError}, + context, + server::{incoming::Incoming, serve, BaseChannel}, transport::{ self, channel::{Channel, UnboundedChannel}, }, + ServerError, }; use assert_matches::assert_matches; use futures::{prelude::*, stream}; @@ -177,25 +179,25 @@ mod tests { tokio::spawn( stream::once(future::ready(server_channel)) .map(BaseChannel::with_defaults) - .execute(|_ctx, request: String| { - future::ready(request.parse::().map_err(|_| { - io::Error::new( + .execute(serve(|_ctx, request: String| async move { + request.parse::().map_err(|_| { + ServerError::new( io::ErrorKind::InvalidInput, format!("{request:?} is not an int"), ) - })) - }), + }) + })), ); let client = client::new(client::Config::default(), client_channel).spawn(); - let response1 = client.call(context::current(), "", "123".into()).await?; - let response2 = client.call(context::current(), "", "abc".into()).await?; + let response1 = client.call(context::current(), "", "123".into()).await; + let response2 = client.call(context::current(), "", "abc".into()).await; trace!("response1: {:?}, response2: {:?}", response1, response2); assert_matches!(response1, Ok(123)); - assert_matches!(response2, Err(ref e) if e.kind() == io::ErrorKind::InvalidInput); + assert_matches!(response2, Err(RpcError::Server(e)) if e.kind == io::ErrorKind::InvalidInput); Ok(()) } From 7f02cf7a871f18aa08a4e556e27cc646260913c3 Mon Sep 17 00:00:00 2001 From: Tim Kuehn Date: Sat, 12 Nov 2022 16:32:27 -0800 Subject: [PATCH 03/30] Use rust nightly for Github workflows. While using unstable feature type_alias_impl_trait. --- .github/workflows/main.yml | 8 ++++---- hooks/pre-push | 10 +++++----- .../must_use_request_dispatch.stderr | 2 +- .../serde_transport/must_use_tcp_connect.stderr | 2 +- .../tarpc_server_missing_async.stderr | 16 ++++++++++------ .../tokio/must_use_channel_executor.stderr | 2 +- .../tokio/must_use_server_executor.stderr | 2 +- 7 files changed, 23 insertions(+), 19 deletions(-) diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index e94ca60ac..1ca595603 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -21,7 +21,7 @@ jobs: - uses: actions-rs/toolchain@v1 with: profile: minimal - toolchain: stable + toolchain: nightly target: mipsel-unknown-linux-gnu override: true - uses: actions-rs/cargo@v1 @@ -45,7 +45,7 @@ jobs: - uses: actions-rs/toolchain@v1 with: profile: minimal - toolchain: stable + toolchain: nightly override: true - uses: actions-rs/cargo@v1 with: @@ -83,7 +83,7 @@ jobs: - uses: actions-rs/toolchain@v1 with: profile: minimal - toolchain: stable + toolchain: nightly override: true - run: rustup component add rustfmt - uses: actions-rs/cargo@v1 @@ -103,7 +103,7 @@ jobs: - uses: actions-rs/toolchain@v1 with: profile: minimal - toolchain: stable + toolchain: nightly override: true - run: rustup component add clippy - uses: actions-rs/cargo@v1 diff --git a/hooks/pre-push b/hooks/pre-push index 7b527e0a8..1e5500d63 100755 --- a/hooks/pre-push +++ b/hooks/pre-push @@ -84,12 +84,12 @@ command -v rustup &>/dev/null if [ "$?" == 0 ]; then printf "${SUCCESS}\n" - try_run "Building ... " cargo +stable build --color=always - try_run "Testing ... " cargo +stable test --color=always - try_run "Testing with all features enabled ... " cargo +stable test --all-features --color=always - for EXAMPLE in $(cargo +stable run --example 2>&1 | grep ' ' | awk '{print $1}') + try_run "Building ... " cargo build --color=always + try_run "Testing ... " cargo test --color=always + try_run "Testing with all features enabled ... " cargo test --all-features --color=always + for EXAMPLE in $(cargo run --example 2>&1 | grep ' ' | awk '{print $1}') do - try_run "Running example \"$EXAMPLE\" ... " cargo +stable run --example $EXAMPLE + try_run "Running example \"$EXAMPLE\" ... " cargo run --example $EXAMPLE done check_toolchain nightly diff --git a/tarpc/tests/compile_fail/must_use_request_dispatch.stderr b/tarpc/tests/compile_fail/must_use_request_dispatch.stderr index 823ac5bfd..f7aa3ea6c 100644 --- a/tarpc/tests/compile_fail/must_use_request_dispatch.stderr +++ b/tarpc/tests/compile_fail/must_use_request_dispatch.stderr @@ -2,7 +2,7 @@ error: unused `RequestDispatch` that must be used --> tests/compile_fail/must_use_request_dispatch.rs:13:9 | 13 | WorldClient::new(client::Config::default(), client_transport).dispatch; - | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ | note: the lint level is defined here --> tests/compile_fail/must_use_request_dispatch.rs:11:12 diff --git a/tarpc/tests/compile_fail/serde_transport/must_use_tcp_connect.stderr b/tarpc/tests/compile_fail/serde_transport/must_use_tcp_connect.stderr index b1be874ab..d3f4eb62a 100644 --- a/tarpc/tests/compile_fail/serde_transport/must_use_tcp_connect.stderr +++ b/tarpc/tests/compile_fail/serde_transport/must_use_tcp_connect.stderr @@ -2,7 +2,7 @@ error: unused `tarpc::serde_transport::tcp::Connect` that must be used --> tests/compile_fail/serde_transport/must_use_tcp_connect.rs:7:9 | 7 | serde_transport::tcp::connect::<_, (), (), _, _>("0.0.0.0:0", Json::default); - | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ | note: the lint level is defined here --> tests/compile_fail/serde_transport/must_use_tcp_connect.rs:5:12 diff --git a/tarpc/tests/compile_fail/tarpc_server_missing_async.stderr b/tarpc/tests/compile_fail/tarpc_server_missing_async.stderr index 28106e63f..d96cda833 100644 --- a/tarpc/tests/compile_fail/tarpc_server_missing_async.stderr +++ b/tarpc/tests/compile_fail/tarpc_server_missing_async.stderr @@ -1,11 +1,15 @@ error: not all trait items implemented, missing: `HelloFut` - --> $DIR/tarpc_server_missing_async.rs:9:1 - | -9 | impl World for HelloServer { - | ^^^^ + --> tests/compile_fail/tarpc_server_missing_async.rs:9:1 + | +9 | / impl World for HelloServer { +10 | | fn hello(name: String) -> String { +11 | | format!("Hello, {name}!", name) +12 | | } +13 | | } + | |_^ error: hint: `#[tarpc::server]` only rewrites async fns, and `fn hello` is not async - --> $DIR/tarpc_server_missing_async.rs:10:5 + --> tests/compile_fail/tarpc_server_missing_async.rs:10:5 | 10 | fn hello(name: String) -> String { - | ^^ + | ^^^^^^^^ diff --git a/tarpc/tests/compile_fail/tokio/must_use_channel_executor.stderr b/tarpc/tests/compile_fail/tokio/must_use_channel_executor.stderr index 5b5adf0c2..446f224f6 100644 --- a/tarpc/tests/compile_fail/tokio/must_use_channel_executor.stderr +++ b/tarpc/tests/compile_fail/tokio/must_use_channel_executor.stderr @@ -2,7 +2,7 @@ error: unused `TokioChannelExecutor` that must be used --> tests/compile_fail/tokio/must_use_channel_executor.rs:27:9 | 27 | server.execute(HelloServer.serve()); - | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ | note: the lint level is defined here --> tests/compile_fail/tokio/must_use_channel_executor.rs:25:12 diff --git a/tarpc/tests/compile_fail/tokio/must_use_server_executor.stderr b/tarpc/tests/compile_fail/tokio/must_use_server_executor.stderr index 57daf9063..07d4b5a9b 100644 --- a/tarpc/tests/compile_fail/tokio/must_use_server_executor.stderr +++ b/tarpc/tests/compile_fail/tokio/must_use_server_executor.stderr @@ -2,7 +2,7 @@ error: unused `TokioServerExecutor` that must be used --> tests/compile_fail/tokio/must_use_server_executor.rs:28:9 | 28 | server.execute(HelloServer.serve()); - | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ | note: the lint level is defined here --> tests/compile_fail/tokio/must_use_server_executor.rs:26:12 From 30710db2eb2d82dd6a88f30aa36b324abcccec83 Mon Sep 17 00:00:00 2001 From: Tim Kuehn Date: Wed, 23 Nov 2022 01:36:51 -0800 Subject: [PATCH 04/30] Use async fn in generated traits!! Add helper fn to server::incoming module for spawning. --- example-service/src/lib.rs | 3 + example-service/src/server.rs | 10 +- plugins/src/lib.rs | 235 ++---------------- plugins/tests/server.rs | 101 +------- plugins/tests/service.rs | 42 ++-- tarpc/Cargo.toml | 3 +- tarpc/examples/compression.rs | 17 +- tarpc/examples/custom_transport.rs | 18 +- tarpc/examples/pubsub.rs | 13 +- tarpc/examples/readme.rs | 20 +- tarpc/examples/tracing.rs | 23 +- tarpc/src/lib.rs | 97 ++------ tarpc/src/server.rs | 145 ++++++++--- tarpc/src/server/incoming.rs | 67 ++++- tarpc/src/server/request_hook/after.rs | 24 +- tarpc/src/server/request_hook/before.rs | 16 +- .../server/request_hook/before_and_after.rs | 27 +- tarpc/src/server/tokio.rs | 129 ---------- tarpc/src/transport/channel.rs | 33 ++- tarpc/tests/compile_fail.rs | 2 - .../compile_fail/must_use_request_dispatch.rs | 3 + .../must_use_request_dispatch.stderr | 8 +- .../tarpc_server_missing_async.rs | 15 -- .../tarpc_server_missing_async.stderr | 15 -- .../tokio/must_use_channel_executor.rs | 29 --- .../tokio/must_use_channel_executor.stderr | 11 - .../tokio/must_use_server_executor.rs | 30 --- .../tokio/must_use_server_executor.stderr | 11 - tarpc/tests/dataservice.rs | 13 +- tarpc/tests/service_functional.rs | 81 +++--- 30 files changed, 412 insertions(+), 829 deletions(-) delete mode 100644 tarpc/src/server/tokio.rs delete mode 100644 tarpc/tests/compile_fail/tarpc_server_missing_async.rs delete mode 100644 tarpc/tests/compile_fail/tarpc_server_missing_async.stderr delete mode 100644 tarpc/tests/compile_fail/tokio/must_use_channel_executor.rs delete mode 100644 tarpc/tests/compile_fail/tokio/must_use_channel_executor.stderr delete mode 100644 tarpc/tests/compile_fail/tokio/must_use_server_executor.rs delete mode 100644 tarpc/tests/compile_fail/tokio/must_use_server_executor.stderr diff --git a/example-service/src/lib.rs b/example-service/src/lib.rs index bc38fe93e..822d8217b 100644 --- a/example-service/src/lib.rs +++ b/example-service/src/lib.rs @@ -4,6 +4,9 @@ // license that can be found in the LICENSE file or at // https://opensource.org/licenses/MIT. +#![allow(incomplete_features)] +#![feature(async_fn_in_trait)] + use std::env; use tracing_subscriber::{fmt::format::FmtSpan, prelude::*}; diff --git a/example-service/src/server.rs b/example-service/src/server.rs index b0281e983..6c78598be 100644 --- a/example-service/src/server.rs +++ b/example-service/src/server.rs @@ -4,6 +4,9 @@ // license that can be found in the LICENSE file or at // https://opensource.org/licenses/MIT. +#![allow(incomplete_features)] +#![feature(async_fn_in_trait)] + use clap::Parser; use futures::{future, prelude::*}; use rand::{ @@ -34,7 +37,6 @@ struct Flags { #[derive(Clone)] struct HelloServer(SocketAddr); -#[tarpc::server] impl World for HelloServer { async fn hello(self, _: context::Context, name: String) -> String { let sleep_time = @@ -44,6 +46,10 @@ impl World for HelloServer { } } +async fn spawn(fut: impl Future + Send + 'static) { + tokio::spawn(fut); +} + #[tokio::main] async fn main() -> anyhow::Result<()> { let flags = Flags::parse(); @@ -66,7 +72,7 @@ async fn main() -> anyhow::Result<()> { // the generated World trait. .map(|channel| { let server = HelloServer(channel.transport().peer_addr().unwrap()); - channel.execute(server.serve()) + channel.execute(server.serve()).for_each(spawn) }) // Max 10 channels. .buffer_unordered(10) diff --git a/plugins/src/lib.rs b/plugins/src/lib.rs index efab161bb..f33cea09e 100644 --- a/plugins/src/lib.rs +++ b/plugins/src/lib.rs @@ -12,18 +12,18 @@ extern crate quote; extern crate syn; use proc_macro::TokenStream; -use proc_macro2::{Span, TokenStream as TokenStream2}; +use proc_macro2::TokenStream as TokenStream2; use quote::{format_ident, quote, ToTokens}; use syn::{ braced, ext::IdentExt, parenthesized, parse::{Parse, ParseStream}, - parse_macro_input, parse_quote, parse_str, + parse_macro_input, parse_quote, spanned::Spanned, token::Comma, - Attribute, FnArg, Ident, ImplItem, ImplItemMethod, ImplItemType, ItemImpl, Lit, LitBool, - MetaNameValue, Pat, PatType, ReturnType, Token, Type, Visibility, + Attribute, FnArg, Ident, Lit, LitBool, MetaNameValue, Pat, PatType, ReturnType, Token, Type, + Visibility, }; /// Accumulates multiple errors into a result. @@ -257,7 +257,6 @@ pub fn service(attr: TokenStream, input: TokenStream) -> TokenStream { .map(|rpc| snake_to_camel(&rpc.ident.unraw().to_string())) .collect(); let args: &[&[PatType]] = &rpcs.iter().map(|rpc| &*rpc.args).collect::>(); - let response_fut_name = &format!("{}ResponseFut", ident.unraw()); let derive_serialize = if derive_serde.0 { Some( quote! {#[derive(tarpc::serde::Serialize, tarpc::serde::Deserialize)] @@ -274,11 +273,9 @@ pub fn service(attr: TokenStream, input: TokenStream) -> TokenStream { .collect::>(); ServiceGenerator { - response_fut_name, service_ident: ident, client_stub_ident: &format_ident!("{}Stub", ident), server_ident: &format_ident!("Serve{}", ident), - response_fut_ident: &Ident::new(response_fut_name, ident.span()), client_ident: &format_ident!("{}Client", ident), request_ident: &format_ident!("{}Request", ident), response_ident: &format_ident!("{}Response", ident), @@ -305,138 +302,18 @@ pub fn service(attr: TokenStream, input: TokenStream) -> TokenStream { .zip(camel_case_fn_names.iter()) .map(|(rpc, name)| Ident::new(name, rpc.ident.span())) .collect::>(), - future_types: &camel_case_fn_names - .iter() - .map(|name| parse_str(&format!("{name}Fut")).unwrap()) - .collect::>(), derive_serialize: derive_serialize.as_ref(), } .into_token_stream() .into() } -/// generate an identifier consisting of the method name to CamelCase with -/// Fut appended to it. -fn associated_type_for_rpc(method: &ImplItemMethod) -> String { - snake_to_camel(&method.sig.ident.unraw().to_string()) + "Fut" -} - -/// Transforms an async function into a sync one, returning a type declaration -/// for the return type (a future). -fn transform_method(method: &mut ImplItemMethod) -> ImplItemType { - method.sig.asyncness = None; - - // get either the return type or (). - let ret = match &method.sig.output { - ReturnType::Default => quote!(()), - ReturnType::Type(_, ret) => quote!(#ret), - }; - - let fut_name = associated_type_for_rpc(method); - let fut_name_ident = Ident::new(&fut_name, method.sig.ident.span()); - - // generate the updated return signature. - method.sig.output = parse_quote! { - -> ::core::pin::Pin + ::core::marker::Send - >> - }; - - // transform the body of the method into Box::pin(async move { body }). - let block = method.block.clone(); - method.block = parse_quote! [{ - Box::pin(async move - #block - ) - }]; - - // generate and return type declaration for return type. - let t: ImplItemType = parse_quote! { - type #fut_name_ident = ::core::pin::Pin + ::core::marker::Send>>; - }; - - t -} - -#[proc_macro_attribute] -pub fn server(_attr: TokenStream, input: TokenStream) -> TokenStream { - let mut item = syn::parse_macro_input!(input as ItemImpl); - let span = item.span(); - - // the generated type declarations - let mut types: Vec = Vec::new(); - let mut expected_non_async_types: Vec<(&ImplItemMethod, String)> = Vec::new(); - let mut found_non_async_types: Vec<&ImplItemType> = Vec::new(); - - for inner in &mut item.items { - match inner { - ImplItem::Method(method) => { - if method.sig.asyncness.is_some() { - // if this function is declared async, transform it into a regular function - let typedecl = transform_method(method); - types.push(typedecl); - } else { - // If it's not async, keep track of all required associated types for better - // error reporting. - expected_non_async_types.push((method, associated_type_for_rpc(method))); - } - } - ImplItem::Type(typedecl) => found_non_async_types.push(typedecl), - _ => {} - } - } - - if let Err(e) = - verify_types_were_provided(span, &expected_non_async_types, &found_non_async_types) - { - return TokenStream::from(e.to_compile_error()); - } - - // add the type declarations into the impl block - for t in types.into_iter() { - item.items.push(syn::ImplItem::Type(t)); - } - - TokenStream::from(quote!(#item)) -} - -fn verify_types_were_provided( - span: Span, - expected: &[(&ImplItemMethod, String)], - provided: &[&ImplItemType], -) -> syn::Result<()> { - let mut result = Ok(()); - for (method, expected) in expected { - if !provided.iter().any(|typedecl| typedecl.ident == expected) { - let mut e = syn::Error::new( - span, - format!("not all trait items implemented, missing: `{expected}`"), - ); - let fn_span = method.sig.fn_token.span(); - e.extend(syn::Error::new( - fn_span.join(method.sig.ident.span()).unwrap_or(fn_span), - format!( - "hint: `#[tarpc::server]` only rewrites async fns, and `fn {}` is not async", - method.sig.ident - ), - )); - match result { - Ok(_) => result = Err(e), - Err(ref mut error) => error.extend(Some(e)), - } - } - } - result -} - // Things needed to generate the service items: trait, serve impl, request/response enums, and // the client stub. struct ServiceGenerator<'a> { service_ident: &'a Ident, client_stub_ident: &'a Ident, server_ident: &'a Ident, - response_fut_ident: &'a Ident, - response_fut_name: &'a str, client_ident: &'a Ident, request_ident: &'a Ident, response_ident: &'a Ident, @@ -444,7 +321,6 @@ struct ServiceGenerator<'a> { attrs: &'a [Attribute], rpcs: &'a [RpcMethod], camel_case_idents: &'a [Ident], - future_types: &'a [Type], method_idents: &'a [&'a Ident], request_names: &'a [String], method_attrs: &'a [&'a [Attribute]], @@ -460,7 +336,6 @@ impl<'a> ServiceGenerator<'a> { attrs, rpcs, vis, - future_types, return_types, service_ident, client_stub_ident, @@ -470,27 +345,19 @@ impl<'a> ServiceGenerator<'a> { .. } = self; - let types_and_fns = rpcs + let rpc_fns = rpcs .iter() - .zip(future_types.iter()) .zip(return_types.iter()) .map( |( - ( - RpcMethod { - attrs, ident, args, .. - }, - future_type, - ), + RpcMethod { + attrs, ident, args, .. + }, output, )| { - let ty_doc = format!("The response future returned by [`{service_ident}::{ident}`]."); quote! { - #[doc = #ty_doc] - type #future_type: std::future::Future; - #( #attrs )* - fn #ident(self, context: tarpc::context::Context, #( #args ),*) -> Self::#future_type; + async fn #ident(self, context: tarpc::context::Context, #( #args ),*) -> #output; } }, ); @@ -499,7 +366,7 @@ impl<'a> ServiceGenerator<'a> { quote! { #( #attrs )* #vis trait #service_ident: Sized { - #( #types_and_fns )* + #( #rpc_fns )* /// Returns a serving function to use with /// [InFlightRequest::execute](tarpc::server::InFlightRequest::execute). @@ -539,7 +406,6 @@ impl<'a> ServiceGenerator<'a> { server_ident, service_ident, response_ident, - response_fut_ident, camel_case_idents, arg_pats, method_idents, @@ -553,7 +419,6 @@ impl<'a> ServiceGenerator<'a> { { type Req = #request_ident; type Resp = #response_ident; - type Fut = #response_fut_ident; fn method(&self, req: &#request_ident) -> Option<&'static str> { Some(match req { @@ -565,15 +430,16 @@ impl<'a> ServiceGenerator<'a> { }) } - fn serve(self, ctx: tarpc::context::Context, req: #request_ident) -> Self::Fut { + async fn serve(self, ctx: tarpc::context::Context, req: #request_ident) + -> Result<#response_ident, tarpc::ServerError> { match req { #( #request_ident::#camel_case_idents{ #( #arg_pats ),* } => { - #response_fut_ident::#camel_case_idents( + Ok(#response_ident::#camel_case_idents( #service_ident::#method_idents( self.service, ctx, #( #arg_pats ),* - ) - ) + ).await + )) } )* } @@ -624,74 +490,6 @@ impl<'a> ServiceGenerator<'a> { } } - fn enum_response_future(&self) -> TokenStream2 { - let &Self { - vis, - service_ident, - response_fut_ident, - camel_case_idents, - future_types, - .. - } = self; - - quote! { - /// A future resolving to a server response. - #[allow(missing_docs)] - #vis enum #response_fut_ident { - #( #camel_case_idents(::#future_types) ),* - } - } - } - - fn impl_debug_for_response_future(&self) -> TokenStream2 { - let &Self { - service_ident, - response_fut_ident, - response_fut_name, - .. - } = self; - - quote! { - impl std::fmt::Debug for #response_fut_ident { - fn fmt(&self, fmt: &mut std::fmt::Formatter) -> std::fmt::Result { - fmt.debug_struct(#response_fut_name).finish() - } - } - } - } - - fn impl_future_for_response_future(&self) -> TokenStream2 { - let &Self { - service_ident, - response_fut_ident, - response_ident, - camel_case_idents, - .. - } = self; - - quote! { - impl std::future::Future for #response_fut_ident { - type Output = Result<#response_ident, tarpc::ServerError>; - - fn poll(self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>) - -> std::task::Poll> - { - unsafe { - match std::pin::Pin::get_unchecked_mut(self) { - #( - #response_fut_ident::#camel_case_idents(resp) => - std::pin::Pin::new_unchecked(resp) - .poll(cx) - .map(#response_ident::#camel_case_idents) - .map(Ok), - )* - } - } - } - } - } - } - fn struct_client(&self) -> TokenStream2 { let &Self { vis, @@ -804,9 +602,6 @@ impl<'a> ToTokens for ServiceGenerator<'a> { self.impl_serve_for_server(), self.enum_request(), self.enum_response(), - self.enum_response_future(), - self.impl_debug_for_response_future(), - self.impl_future_for_response_future(), self.struct_client(), self.impl_client_new(), self.impl_client_rpc_methods(), diff --git a/plugins/tests/server.rs b/plugins/tests/server.rs index f0222ffd3..7fcec793e 100644 --- a/plugins/tests/server.rs +++ b/plugins/tests/server.rs @@ -1,7 +1,5 @@ -use assert_type_eq::assert_type_eq; -use futures::Future; -use std::pin::Pin; -use tarpc::context; +#![allow(incomplete_features)] +#![feature(async_fn_in_trait)] // these need to be out here rather than inside the function so that the // assert_type_eq macro can pick them up. @@ -12,42 +10,6 @@ trait Foo { async fn baz(); } -#[test] -fn type_generation_works() { - #[tarpc::server] - impl Foo for () { - async fn two_part(self, _: context::Context, s: String, i: i32) -> (String, i32) { - (s, i) - } - - async fn bar(self, _: context::Context, s: String) -> String { - s - } - - async fn baz(self, _: context::Context) {} - } - - // the assert_type_eq macro can only be used once per block. - { - assert_type_eq!( - <() as Foo>::TwoPartFut, - Pin + Send>> - ); - } - { - assert_type_eq!( - <() as Foo>::BarFut, - Pin + Send>> - ); - } - { - assert_type_eq!( - <() as Foo>::BazFut, - Pin + Send>> - ); - } -} - #[allow(non_camel_case_types)] #[test] fn raw_idents_work() { @@ -59,24 +21,6 @@ fn raw_idents_work() { async fn r#fn(r#impl: r#yield) -> r#yield; async fn r#async(); } - - #[tarpc::server] - impl r#trait for () { - async fn r#await( - self, - _: context::Context, - r#struct: r#yield, - r#enum: i32, - ) -> (r#yield, i32) { - (r#struct, r#enum) - } - - async fn r#fn(self, _: context::Context, r#impl: r#yield) -> r#yield { - r#impl - } - - async fn r#async(self, _: context::Context) {} - } } #[test] @@ -100,45 +44,4 @@ fn syntax() { #[doc = "attr"] async fn one_arg_implicit_return_error(one: String); } - - #[tarpc::server] - impl Syntax for () { - #[deny(warnings)] - #[allow(non_snake_case)] - async fn TestCamelCaseDoesntConflict(self, _: context::Context) {} - - async fn hello(self, _: context::Context) -> String { - String::new() - } - - async fn attr(self, _: context::Context, _s: String) -> String { - String::new() - } - - async fn no_args_no_return(self, _: context::Context) {} - - async fn no_args(self, _: context::Context) -> () {} - - async fn one_arg(self, _: context::Context, _one: String) -> i32 { - 0 - } - - async fn two_args_no_return(self, _: context::Context, _one: String, _two: u64) {} - - async fn two_args(self, _: context::Context, _one: String, _two: u64) -> String { - String::new() - } - - async fn no_args_ret_error(self, _: context::Context) -> i32 { - 0 - } - - async fn one_arg_ret_error(self, _: context::Context, _one: String) -> String { - String::new() - } - - async fn no_arg_implicit_return_error(self, _: context::Context) {} - - async fn one_arg_implicit_return_error(self, _: context::Context, _one: String) {} - } } diff --git a/plugins/tests/service.rs b/plugins/tests/service.rs index b37cbcead..38bd7f0dc 100644 --- a/plugins/tests/service.rs +++ b/plugins/tests/service.rs @@ -1,9 +1,10 @@ +#![allow(incomplete_features)] +#![feature(async_fn_in_trait)] + use tarpc::context; #[test] fn att_service_trait() { - use futures::future::{ready, Ready}; - #[tarpc::service] trait Foo { async fn two_part(s: String, i: i32) -> (String, i32); @@ -12,19 +13,16 @@ fn att_service_trait() { } impl Foo for () { - type TwoPartFut = Ready<(String, i32)>; - fn two_part(self, _: context::Context, s: String, i: i32) -> Self::TwoPartFut { - ready((s, i)) + async fn two_part(self, _: context::Context, s: String, i: i32) -> (String, i32) { + (s, i) } - type BarFut = Ready; - fn bar(self, _: context::Context, s: String) -> Self::BarFut { - ready(s) + async fn bar(self, _: context::Context, s: String) -> String { + s } - type BazFut = Ready<()>; - fn baz(self, _: context::Context) -> Self::BazFut { - ready(()) + async fn baz(self, _: context::Context) { + () } } } @@ -32,8 +30,6 @@ fn att_service_trait() { #[allow(non_camel_case_types)] #[test] fn raw_idents() { - use futures::future::{ready, Ready}; - type r#yield = String; #[tarpc::service] @@ -44,19 +40,21 @@ fn raw_idents() { } impl r#trait for () { - type AwaitFut = Ready<(r#yield, i32)>; - fn r#await(self, _: context::Context, r#struct: r#yield, r#enum: i32) -> Self::AwaitFut { - ready((r#struct, r#enum)) + async fn r#await( + self, + _: context::Context, + r#struct: r#yield, + r#enum: i32, + ) -> (r#yield, i32) { + (r#struct, r#enum) } - type FnFut = Ready; - fn r#fn(self, _: context::Context, r#impl: r#yield) -> Self::FnFut { - ready(r#impl) + async fn r#fn(self, _: context::Context, r#impl: r#yield) -> r#yield { + r#impl } - type AsyncFut = Ready<()>; - fn r#async(self, _: context::Context) -> Self::AsyncFut { - ready(()) + async fn r#async(self, _: context::Context) { + () } } } diff --git a/tarpc/Cargo.toml b/tarpc/Cargo.toml index 24da20eeb..6442726a2 100644 --- a/tarpc/Cargo.toml +++ b/tarpc/Cargo.toml @@ -75,7 +75,8 @@ opentelemetry-jaeger = { version = "0.16.0", features = ["rt-tokio"] } pin-utils = "0.1.0-alpha" serde_bytes = "0.11" tracing-subscriber = { version = "0.3", features = ["env-filter"] } -tokio = { version = "1", features = ["full", "test-util"] } +tokio = { version = "1", features = ["full", "test-util", "tracing"] } +console-subscriber = "0.1" tokio-serde = { version = "0.8", features = ["json", "bincode"] } trybuild = "1.0" diff --git a/tarpc/examples/compression.rs b/tarpc/examples/compression.rs index 942fdc8af..cc993f0af 100644 --- a/tarpc/examples/compression.rs +++ b/tarpc/examples/compression.rs @@ -1,5 +1,14 @@ +// Copyright 2022 Google LLC +// +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file or at +// https://opensource.org/licenses/MIT. + +#![allow(incomplete_features)] +#![feature(async_fn_in_trait)] + use flate2::{read::DeflateDecoder, write::DeflateEncoder, Compression}; -use futures::{Sink, SinkExt, Stream, StreamExt, TryStreamExt}; +use futures::{prelude::*, Sink, SinkExt, Stream, StreamExt, TryStreamExt}; use serde::{Deserialize, Serialize}; use serde_bytes::ByteBuf; use std::{io, io::Read, io::Write}; @@ -99,13 +108,16 @@ pub trait World { #[derive(Clone, Debug)] struct HelloServer; -#[tarpc::server] impl World for HelloServer { async fn hello(self, _: context::Context, name: String) -> String { format!("Hey, {name}!") } } +async fn spawn(fut: impl Future + Send + 'static) { + tokio::spawn(fut); +} + #[tokio::main] async fn main() -> anyhow::Result<()> { let mut incoming = tcp::listen("localhost:0", Bincode::default).await?; @@ -114,6 +126,7 @@ async fn main() -> anyhow::Result<()> { let transport = incoming.next().await.unwrap().unwrap(); BaseChannel::with_defaults(add_compression(transport)) .execute(HelloServer.serve()) + .for_each(spawn) .await; }); diff --git a/tarpc/examples/custom_transport.rs b/tarpc/examples/custom_transport.rs index e7e2ce3d5..2c5fd4dc4 100644 --- a/tarpc/examples/custom_transport.rs +++ b/tarpc/examples/custom_transport.rs @@ -1,3 +1,13 @@ +// Copyright 2022 Google LLC +// +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file or at +// https://opensource.org/licenses/MIT. + +#![allow(incomplete_features)] +#![feature(async_fn_in_trait)] + +use futures::prelude::*; use tarpc::context::Context; use tarpc::serde_transport as transport; use tarpc::server::{BaseChannel, Channel}; @@ -13,7 +23,6 @@ pub trait PingService { #[derive(Clone)] struct Service; -#[tarpc::server] impl PingService for Service { async fn ping(self, _: Context) {} } @@ -26,13 +35,18 @@ async fn main() -> anyhow::Result<()> { let listener = UnixListener::bind(bind_addr).unwrap(); let codec_builder = LengthDelimitedCodec::builder(); + async fn spawn(fut: impl Future + Send + 'static) { + tokio::spawn(fut); + } tokio::spawn(async move { loop { let (conn, _addr) = listener.accept().await.unwrap(); let framed = codec_builder.new_framed(conn); let transport = transport::new(framed, Bincode::default()); - let fut = BaseChannel::with_defaults(transport).execute(Service.serve()); + let fut = BaseChannel::with_defaults(transport) + .execute(Service.serve()) + .for_each(spawn); tokio::spawn(fut); } }); diff --git a/tarpc/examples/pubsub.rs b/tarpc/examples/pubsub.rs index 8140cb36c..e254b294f 100644 --- a/tarpc/examples/pubsub.rs +++ b/tarpc/examples/pubsub.rs @@ -4,6 +4,9 @@ // license that can be found in the LICENSE file or at // https://opensource.org/licenses/MIT. +#![allow(incomplete_features)] +#![feature(async_fn_in_trait)] + /// - The PubSub server sets up TCP listeners on 2 ports, the "subscriber" port and the "publisher" /// port. Because both publishers and subscribers initiate their connections to the PubSub /// server, the server requires no prior knowledge of either publishers or subscribers. @@ -79,7 +82,6 @@ struct Subscriber { topics: Vec, } -#[tarpc::server] impl subscriber::Subscriber for Subscriber { async fn topics(self, _: context::Context) -> Vec { self.topics.clone() @@ -117,7 +119,8 @@ impl Subscriber { )) } }; - let (handler, abort_handle) = future::abortable(handler.execute(subscriber.serve())); + let (handler, abort_handle) = + future::abortable(handler.execute(subscriber.serve()).for_each(spawn)); tokio::spawn(async move { match handler.await { Ok(()) | Err(future::Aborted) => info!(?local_addr, "subscriber shutdown."), @@ -143,6 +146,10 @@ struct PublisherAddrs { subscriptions: SocketAddr, } +async fn spawn(fut: impl Future + Send + 'static) { + tokio::spawn(fut); +} + impl Publisher { async fn start(self) -> io::Result { let mut connecting_publishers = tcp::listen("localhost:0", Json::default).await?; @@ -162,6 +169,7 @@ impl Publisher { server::BaseChannel::with_defaults(publisher) .execute(self.serve()) + .for_each(spawn) .await }); @@ -257,7 +265,6 @@ impl Publisher { } } -#[tarpc::server] impl publisher::Publisher for Publisher { async fn publish(self, _: context::Context, topic: String, message: String) { info!("received message to publish."); diff --git a/tarpc/examples/readme.rs b/tarpc/examples/readme.rs index 80792314f..c6ef61eb4 100644 --- a/tarpc/examples/readme.rs +++ b/tarpc/examples/readme.rs @@ -4,7 +4,10 @@ // license that can be found in the LICENSE file or at // https://opensource.org/licenses/MIT. -use futures::future::{self, Ready}; +#![allow(incomplete_features)] +#![feature(async_fn_in_trait)] + +use futures::prelude::*; use tarpc::{ client, context, server::{self, Channel}, @@ -23,22 +26,21 @@ pub trait World { struct HelloServer; impl World for HelloServer { - // Each defined rpc generates two items in the trait, a fn that serves the RPC, and - // an associated type representing the future output by the fn. - - type HelloFut = Ready; - - fn hello(self, _: context::Context, name: String) -> Self::HelloFut { - future::ready(format!("Hello, {name}!")) + async fn hello(self, _: context::Context, name: String) -> String { + format!("Hello, {name}!") } } +async fn spawn(fut: impl Future + Send + 'static) { + tokio::spawn(fut); +} + #[tokio::main] async fn main() -> anyhow::Result<()> { let (client_transport, server_transport) = tarpc::transport::channel::unbounded(); let server = server::BaseChannel::with_defaults(server_transport); - tokio::spawn(server.execute(HelloServer.serve())); + tokio::spawn(server.execute(HelloServer.serve()).for_each(spawn)); // WorldClient is generated by the #[tarpc::service] attribute. It has a constructor `new` // that takes a config and any Transport as input. diff --git a/tarpc/examples/tracing.rs b/tarpc/examples/tracing.rs index c2afb3bf1..c0a1f8d33 100644 --- a/tarpc/examples/tracing.rs +++ b/tarpc/examples/tracing.rs @@ -4,7 +4,8 @@ // license that can be found in the LICENSE file or at // https://opensource.org/licenses/MIT. -#![feature(type_alias_impl_trait)] +#![allow(incomplete_features)] +#![feature(async_fn_in_trait, type_alias_impl_trait)] use crate::{ add::{Add as AddService, AddStub}, @@ -25,7 +26,10 @@ use tarpc::{ RpcError, }, context, serde_transport, - server::{incoming::Incoming, BaseChannel, Serve}, + server::{ + incoming::{spawn_incoming, Incoming}, + BaseChannel, Serve, + }, tokio_serde::formats::Json, ClientMessage, Response, ServerError, Transport, }; @@ -51,7 +55,6 @@ pub mod double { #[derive(Clone)] struct AddServer; -#[tarpc::server] impl AddService for AddServer { async fn add(self, _: context::Context, x: i32, y: i32) -> i32 { x + y @@ -63,7 +66,6 @@ struct DoubleServer { add_client: add::AddClient, } -#[tarpc::server] impl DoubleService for DoubleServer where Stub: AddStub + Clone + Send + Sync + 'static, @@ -158,9 +160,8 @@ async fn main() -> anyhow::Result<()> { }); let add_server = add_listener1 .chain(add_listener2) - .map(BaseChannel::with_defaults) - .execute(server); - tokio::spawn(add_server); + .map(BaseChannel::with_defaults); + tokio::spawn(spawn_incoming(add_server.execute(server))); let add_client = add::AddClient::from(make_stub([ tarpc::serde_transport::tcp::connect(addr1, Json::default).await?, @@ -171,11 +172,9 @@ async fn main() -> anyhow::Result<()> { .await? .filter_map(|r| future::ready(r.ok())); let addr = double_listener.get_ref().local_addr(); - let double_server = double_listener - .map(BaseChannel::with_defaults) - .take(1) - .execute(DoubleServer { add_client }.serve()); - tokio::spawn(double_server); + let double_server = double_listener.map(BaseChannel::with_defaults).take(1); + let server = DoubleServer { add_client }.serve(); + tokio::spawn(spawn_incoming(double_server.execute(server))); let to_double_server = tarpc::serde_transport::tcp::connect(addr, Json::default).await?; let double_client = diff --git a/tarpc/src/lib.rs b/tarpc/src/lib.rs index 280da694e..391c6fcaf 100644 --- a/tarpc/src/lib.rs +++ b/tarpc/src/lib.rs @@ -80,6 +80,8 @@ //! First, let's set up the dependencies and service definition. //! //! ```rust +//! #![allow(incomplete_features)] +//! #![feature(async_fn_in_trait)] //! # extern crate futures; //! //! use futures::{ @@ -104,6 +106,8 @@ //! implement it for our Server struct. //! //! ```rust +//! # #![allow(incomplete_features)] +//! # #![feature(async_fn_in_trait)] //! # extern crate futures; //! # use futures::{ //! # future::{self, Ready}, @@ -126,13 +130,9 @@ //! struct HelloServer; //! //! impl World for HelloServer { -//! // Each defined rpc generates two items in the trait, a fn that serves the RPC, and -//! // an associated type representing the future output by the fn. -//! -//! type HelloFut = Ready; -//! -//! fn hello(self, _: context::Context, name: String) -> Self::HelloFut { -//! future::ready(format!("Hello, {name}!")) +//! // Each defined rpc generates an async fn that serves the RPC +//! async fn hello(self, _: context::Context, name: String) -> String { +//! format!("Hello, {name}!") //! } //! } //! ``` @@ -143,6 +143,8 @@ //! available behind the `tcp` feature. //! //! ```rust +//! # #![allow(incomplete_features)] +//! # #![feature(async_fn_in_trait)] //! # extern crate futures; //! # use futures::{ //! # future::{self, Ready}, @@ -164,11 +166,9 @@ //! # #[derive(Clone)] //! # struct HelloServer; //! # impl World for HelloServer { -//! # // Each defined rpc generates two items in the trait, a fn that serves the RPC, and -//! # // an associated type representing the future output by the fn. -//! # type HelloFut = Ready; -//! # fn hello(self, _: context::Context, name: String) -> Self::HelloFut { -//! # future::ready(format!("Hello, {name}!")) +//! // Each defined rpc generates an async fn that serves the RPC +//! # async fn hello(self, _: context::Context, name: String) -> String { +//! # format!("Hello, {name}!") //! # } //! # } //! # #[cfg(not(feature = "tokio1"))] @@ -179,7 +179,12 @@ //! let (client_transport, server_transport) = tarpc::transport::channel::unbounded(); //! //! let server = server::BaseChannel::with_defaults(server_transport); -//! tokio::spawn(server.execute(HelloServer.serve())); +//! tokio::spawn( +//! server.execute(HelloServer.serve()) +//! // Handle all requests concurrently. +//! .for_each(|response| async move { +//! tokio::spawn(response); +//! })); //! //! // WorldClient is generated by the #[tarpc::service] attribute. It has a constructor `new` //! // that takes a config and any Transport as input. @@ -200,7 +205,14 @@ //! //! Use `cargo doc` as you normally would to see the documentation created for all //! items expanded by a `service!` invocation. -#![feature(type_alias_impl_trait)] +// For async_fn_in_trait +#![allow(incomplete_features)] +#![feature( + iter_intersperse, + type_alias_impl_trait, + async_fn_in_trait, + return_position_impl_trait_in_trait +)] #![deny(missing_docs)] #![allow(clippy::type_complexity)] #![cfg_attr(docsrs, feature(doc_cfg))] @@ -226,6 +238,7 @@ pub use tarpc_plugins::derive_serde; /// Rpc methods are specified, mirroring trait syntax: /// /// ``` +/// #![feature(async_fn_in_trait)] /// #[tarpc::service] /// trait Service { /// /// Say hello @@ -245,62 +258,6 @@ pub use tarpc_plugins::derive_serde; /// * `fn new_stub` -- creates a new Client stub. pub use tarpc_plugins::service; -/// A utility macro that can be used for RPC server implementations. -/// -/// Syntactic sugar to make using async functions in the server implementation -/// easier. It does this by rewriting code like this, which would normally not -/// compile because async functions are disallowed in trait implementations: -/// -/// ```rust -/// # use tarpc::context; -/// # use std::net::SocketAddr; -/// #[tarpc::service] -/// trait World { -/// async fn hello(name: String) -> String; -/// } -/// -/// #[derive(Clone)] -/// struct HelloServer(SocketAddr); -/// -/// #[tarpc::server] -/// impl World for HelloServer { -/// async fn hello(self, _: context::Context, name: String) -> String { -/// format!("Hello, {name}! You are connected from {:?}.", self.0) -/// } -/// } -/// ``` -/// -/// Into code like this, which matches the service trait definition: -/// -/// ```rust -/// # use tarpc::context; -/// # use std::pin::Pin; -/// # use futures::Future; -/// # use std::net::SocketAddr; -/// #[derive(Clone)] -/// struct HelloServer(SocketAddr); -/// -/// #[tarpc::service] -/// trait World { -/// async fn hello(name: String) -> String; -/// } -/// -/// impl World for HelloServer { -/// type HelloFut = Pin + Send>>; -/// -/// fn hello(self, _: context::Context, name: String) -> Pin -/// + Send>> { -/// Box::pin(async move { -/// format!("Hello, {name}! You are connected from {:?}.", self.0) -/// }) -/// } -/// } -/// ``` -/// -/// Note that this won't touch functions unless they have been annotated with -/// `async`, meaning that this should not break existing code. -pub use tarpc_plugins::server; - pub(crate) mod cancellations; pub mod client; pub mod context; diff --git a/tarpc/src/server.rs b/tarpc/src/server.rs index f5c5151ee..b9c95f068 100644 --- a/tarpc/src/server.rs +++ b/tarpc/src/server.rs @@ -42,11 +42,6 @@ pub mod limits; /// Provides helper methods for streams of Channels. pub mod incoming; -/// Provides convenience functionality for tokio-enabled applications. -#[cfg(feature = "tokio1")] -#[cfg_attr(docsrs, doc(cfg(feature = "tokio1")))] -pub mod tokio; - use request_hook::{ AfterRequest, AfterRequestHook, BeforeAndAfterRequestHook, BeforeRequest, BeforeRequestHook, }; @@ -86,11 +81,8 @@ pub trait Serve { /// Type of response. type Resp; - /// Type of response future. - type Fut: Future>; - /// Responds to a single request. - fn serve(self, ctx: context::Context, req: Self::Req) -> Self::Fut; + async fn serve(self, ctx: context::Context, req: Self::Req) -> Result; /// Extracts a method name from the request. fn method(&self, _request: &Self::Req) -> Option<&'static str> { @@ -281,10 +273,9 @@ where { type Req = Req; type Resp = Resp; - type Fut = Fut; - fn serve(self, ctx: context::Context, req: Req) -> Self::Fut { - (self.f)(ctx, req) + async fn serve(self, ctx: context::Context, req: Req) -> Result { + (self.f)(ctx, req).await } } @@ -539,34 +530,42 @@ where } } - /// Runs the channel until completion by executing all requests using the given service - /// function. Request handlers are run concurrently by [spawning](::tokio::spawn) on tokio's - /// default executor. + /// Returns a stream of request execution futures. Each future represents an in-flight request + /// being responded to by the server. The futures must be awaited or spawned to complete their + /// requests. /// /// # Example /// /// ```rust /// use tarpc::{context, client, server::{self, BaseChannel, Channel, serve}, transport}; /// use futures::prelude::*; + /// use tracing_subscriber::prelude::*; + /// + /// #[derive(PartialEq, Eq, Debug)] + /// struct MyInt(i32); /// + /// # #[cfg(not(feature = "tokio1"))] + /// # fn main() {} + /// # #[cfg(feature = "tokio1")] /// #[tokio::main] /// async fn main() { /// let (tx, rx) = transport::channel::unbounded(); /// let client = client::new(client::Config::default(), tx).spawn(); - /// let channel = BaseChannel::new(server::Config::default(), rx); - /// tokio::spawn(channel.execute(serve(|_, i| async move { Ok(i + 1) }))); - /// assert_eq!(client.call(context::current(), "AddOne", 1).await.unwrap(), 2); + /// let channel = BaseChannel::with_defaults(rx); + /// tokio::spawn( + /// channel.execute(serve(|_, MyInt(i)| async move { Ok(MyInt(i + 1)) })) + /// .for_each(|response| async move { + /// tokio::spawn(response); + /// })); + /// assert_eq!( + /// client.call(context::current(), "AddOne", MyInt(1)).await.unwrap(), + /// MyInt(2)); /// } /// ``` - #[cfg(feature = "tokio1")] - #[cfg_attr(docsrs, doc(cfg(feature = "tokio1")))] - fn execute(self, serve: S) -> self::tokio::TokioChannelExecutor, S> + fn execute(self, serve: S) -> impl Stream> where Self: Sized, - S: Serve + Send + 'static, - S::Fut: Send, - Self::Req: Send + 'static, - Self::Resp: Send + 'static, + S: Serve + Clone, { self.requests().execute(serve) } @@ -579,10 +578,10 @@ where E: Error + Send + Sync + 'static, { /// An error occurred reading from, or writing to, the transport. - #[error("an error occurred in the transport: {0}")] + #[error("an error occurred in the transport")] Transport(#[source] E), /// An error occurred while polling expired requests. - #[error("an error occurred while polling expired requests: {0}")] + #[error("an error occurred while polling expired requests")] Timer(#[source] ::tokio::time::error::Error), } @@ -674,15 +673,17 @@ where Poll::Pending => Pending, }; + let status = cancellation_status + .combine(expiration_status) + .combine(request_status); + tracing::trace!( - "Expired requests: {:?}, Inbound: {:?}", - expiration_status, - request_status + "Cancellations: {cancellation_status:?}, \ + Expired requests: {expiration_status:?}, \ + Inbound: {request_status:?}, \ + Overall: {status:?}", ); - match cancellation_status - .combine(expiration_status) - .combine(request_status) - { + match status { Ready => continue, Closed => return Poll::Ready(None), Pending => return Poll::Pending, @@ -890,6 +891,51 @@ where } Poll::Ready(Some(Ok(()))) } + + /// Returns a stream of request execution futures. Each future represents an in-flight request + /// being responded to by the server. The futures must be awaited or spawned to complete their + /// requests. + /// + /// If the channel encounters an error, the stream is terminated and the error is logged. + /// + /// # Example + /// + /// ```rust + /// use tarpc::{context, client, server::{self, BaseChannel, Channel, serve}, transport}; + /// use futures::prelude::*; + /// + /// # #[cfg(not(feature = "tokio1"))] + /// # fn main() {} + /// # #[cfg(feature = "tokio1")] + /// #[tokio::main] + /// async fn main() { + /// let (tx, rx) = transport::channel::unbounded(); + /// let requests = BaseChannel::new(server::Config::default(), rx).requests(); + /// let client = client::new(client::Config::default(), tx).spawn(); + /// tokio::spawn( + /// requests.execute(serve(|_, i| async move { Ok(i + 1) })) + /// .for_each(|response| async move { + /// tokio::spawn(response); + /// })); + /// assert_eq!(client.call(context::current(), "AddOne", 1).await.unwrap(), 2); + /// } + /// ``` + pub fn execute(self, serve: S) -> impl Stream> + where + S: Serve + Clone, + { + self.take_while(|result| { + if let Err(e) = result { + tracing::warn!("Requests stream errored out: {}", e); + } + futures::future::ready(result.is_ok()) + }) + .filter_map(|result| async move { result.ok() }) + .map(move |request| { + let serve = serve.clone(); + request.execute(serve) + }) + } } impl fmt::Debug for Requests @@ -1021,6 +1067,13 @@ impl InFlightRequest { } } +fn print_err(e: &(dyn Error + 'static)) -> String { + anyhow::Chain::new(e) + .map(|e| e.to_string()) + .intersperse(": ".into()) + .collect::() +} + impl Stream for Requests where C: Channel, @@ -1029,17 +1082,33 @@ where fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { loop { - let read = self.as_mut().pump_read(cx)?; + let read = self.as_mut().pump_read(cx).map_err(|e| { + tracing::trace!("read: {}", print_err(&e)); + e + })?; let read_closed = matches!(read, Poll::Ready(None)); - match (read, self.as_mut().pump_write(cx, read_closed)?) { + let write = self.as_mut().pump_write(cx, read_closed).map_err(|e| { + tracing::trace!("write: {}", print_err(&e)); + e + })?; + match (read, write) { (Poll::Ready(None), Poll::Ready(None)) => { + tracing::trace!("read: Poll::Ready(None), write: Poll::Ready(None)"); return Poll::Ready(None); } (Poll::Ready(Some(request_handler)), _) => { + tracing::trace!("read: Poll::Ready(Some), write: _"); return Poll::Ready(Some(Ok(request_handler))); } - (_, Poll::Ready(Some(()))) => {} - _ => { + (_, Poll::Ready(Some(()))) => { + tracing::trace!("read: _, write: Poll::Ready(Some)"); + } + (read @ Poll::Pending, write) | (read, write @ Poll::Pending) => { + tracing::trace!( + "read pending: {}, write pending: {}", + read.is_pending(), + write.is_pending() + ); return Poll::Pending; } } diff --git a/tarpc/src/server/incoming.rs b/tarpc/src/server/incoming.rs index 931e87669..9195ee301 100644 --- a/tarpc/src/server/incoming.rs +++ b/tarpc/src/server/incoming.rs @@ -1,13 +1,10 @@ use super::{ limits::{channels_per_key::MaxChannelsPerKey, requests_per_channel::MaxRequestsPerChannel}, - Channel, + Channel, Serve, }; use futures::prelude::*; use std::{fmt, hash::Hash}; -#[cfg(feature = "tokio1")] -use super::{tokio::TokioServerExecutor, Serve}; - /// An extension trait for [streams](futures::prelude::Stream) of [`Channels`](Channel). pub trait Incoming where @@ -28,16 +25,62 @@ where MaxRequestsPerChannel::new(self, n) } - /// [Executes](Channel::execute) each incoming channel. Each channel will be handled - /// concurrently by spawning on tokio's default executor, and each request will be also - /// be spawned on tokio's default executor. - #[cfg(feature = "tokio1")] - #[cfg_attr(docsrs, doc(cfg(feature = "tokio1")))] - fn execute(self, serve: S) -> TokioServerExecutor + /// Returns a stream of channels in execution. Each channel in execution is a stream of + /// futures, where each future is an in-flight request being rsponded to. + fn execute( + self, + serve: S, + ) -> impl Stream>> where - S: Serve, + S: Serve + Clone, { - TokioServerExecutor::new(self, serve) + self.map(move |channel| channel.execute(serve.clone())) + } +} + +#[cfg(feature = "tokio1")] +/// Spawns all channels-in-execution, delegating to the tokio runtime to manage their completion. +/// Each channel is spawned, and each request from each channel is spawned. +/// Note that this function is generic over any stream-of-streams-of-futures, but it is intended +/// for spawning streams of channels. +/// +/// # Example +/// ```rust +/// use tarpc::{ +/// context, +/// client::{self, NewClient}, +/// server::{self, BaseChannel, Channel, incoming::{Incoming, spawn_incoming}, serve}, +/// transport, +/// }; +/// use futures::prelude::*; +/// +/// #[tokio::main] +/// async fn main() { +/// let (tx, rx) = transport::channel::unbounded(); +/// let NewClient { client, dispatch } = client::new(client::Config::default(), tx); +/// tokio::spawn(dispatch); +/// +/// let incoming = stream::once(async move { +/// BaseChannel::new(server::Config::default(), rx) +/// }).execute(serve(|_, i| async move { Ok(i + 1) })); +/// tokio::spawn(spawn_incoming(incoming)); +/// assert_eq!(client.call(context::current(), "AddOne", 1).await.unwrap(), 2); +/// } +/// ``` +pub async fn spawn_incoming( + incoming: impl Stream< + Item = impl Stream + Send + 'static> + Send + 'static, + >, +) { + use futures::pin_mut; + pin_mut!(incoming); + while let Some(channel) = incoming.next().await { + tokio::spawn(async move { + pin_mut!(channel); + while let Some(request) = channel.next().await { + tokio::spawn(request); + } + }); } } diff --git a/tarpc/src/server/request_hook/after.rs b/tarpc/src/server/request_hook/after.rs index a3803bade..4fd48dd4b 100644 --- a/tarpc/src/server/request_hook/after.rs +++ b/tarpc/src/server/request_hook/after.rs @@ -71,19 +71,17 @@ where { type Req = Serv::Req; type Resp = Serv::Resp; - type Fut = AfterRequestHookFut; - fn serve(self, mut ctx: context::Context, req: Serv::Req) -> Self::Fut { - async move { - let AfterRequestHook { - serve, mut hook, .. - } = self; - let mut resp = serve.serve(ctx, req).await; - hook.after(&mut ctx, &mut resp).await; - resp - } + async fn serve( + self, + mut ctx: context::Context, + req: Serv::Req, + ) -> Result { + let AfterRequestHook { + serve, mut hook, .. + } = self; + let mut resp = serve.serve(ctx, req).await; + hook.after(&mut ctx, &mut resp).await; + resp } } - -type AfterRequestHookFut> = - impl Future>; diff --git a/tarpc/src/server/request_hook/before.rs b/tarpc/src/server/request_hook/before.rs index 38ad54d01..2c478dbb1 100644 --- a/tarpc/src/server/request_hook/before.rs +++ b/tarpc/src/server/request_hook/before.rs @@ -67,18 +67,16 @@ where { type Req = Serv::Req; type Resp = Serv::Resp; - type Fut = BeforeRequestHookFut; - fn serve(self, mut ctx: context::Context, req: Self::Req) -> Self::Fut { + async fn serve( + self, + mut ctx: context::Context, + req: Self::Req, + ) -> Result { let BeforeRequestHook { serve, mut hook, .. } = self; - async move { - hook.before(&mut ctx, &req).await?; - serve.serve(ctx, req).await - } + hook.before(&mut ctx, &req).await?; + serve.serve(ctx, req).await } } - -type BeforeRequestHookFut> = - impl Future>; diff --git a/tarpc/src/server/request_hook/before_and_after.rs b/tarpc/src/server/request_hook/before_and_after.rs index ca42460bc..ff61a53ea 100644 --- a/tarpc/src/server/request_hook/before_and_after.rs +++ b/tarpc/src/server/request_hook/before_and_after.rs @@ -8,7 +8,6 @@ use super::{after::AfterRequest, before::BeforeRequest}; use crate::{context, server::Serve, ServerError}; -use futures::prelude::*; use std::marker::PhantomData; /// A Service function that runs a hook both before and after request execution. @@ -47,24 +46,14 @@ where { type Req = Req; type Resp = Resp; - type Fut = BeforeAndAfterRequestHookFut; - fn serve(self, mut ctx: context::Context, req: Req) -> Self::Fut { - async move { - let BeforeAndAfterRequestHook { - serve, mut hook, .. - } = self; - hook.before(&mut ctx, &req).await?; - let mut resp = serve.serve(ctx, req).await; - hook.after(&mut ctx, &mut resp).await; - resp - } + async fn serve(self, mut ctx: context::Context, req: Req) -> Result { + let BeforeAndAfterRequestHook { + serve, mut hook, .. + } = self; + hook.before(&mut ctx, &req).await?; + let mut resp = serve.serve(ctx, req).await; + hook.after(&mut ctx, &mut resp).await; + resp } } - -type BeforeAndAfterRequestHookFut< - Req, - Resp, - Serv: Serve, - Hook: BeforeRequest + AfterRequest, -> = impl Future>; diff --git a/tarpc/src/server/tokio.rs b/tarpc/src/server/tokio.rs deleted file mode 100644 index e9ad84221..000000000 --- a/tarpc/src/server/tokio.rs +++ /dev/null @@ -1,129 +0,0 @@ -use super::{Channel, Requests, Serve}; -use futures::{prelude::*, ready, task::*}; -use pin_project::pin_project; -use std::pin::Pin; - -/// A future that drives the server by [spawning](tokio::spawn) a [`TokioChannelExecutor`](TokioChannelExecutor) -/// for each new channel. Returned by -/// [`Incoming::execute`](crate::server::incoming::Incoming::execute). -#[must_use] -#[pin_project] -#[derive(Debug)] -pub struct TokioServerExecutor { - #[pin] - inner: T, - serve: S, -} - -impl TokioServerExecutor { - pub(crate) fn new(inner: T, serve: S) -> Self { - Self { inner, serve } - } -} - -/// A future that drives the server by [spawning](tokio::spawn) each [response -/// handler](super::InFlightRequest::execute) on tokio's default executor. Returned by -/// [`Channel::execute`](crate::server::Channel::execute). -#[must_use] -#[pin_project] -#[derive(Debug)] -pub struct TokioChannelExecutor { - #[pin] - inner: T, - serve: S, -} - -impl TokioServerExecutor { - fn inner_pin_mut<'a>(self: &'a mut Pin<&mut Self>) -> Pin<&'a mut T> { - self.as_mut().project().inner - } -} - -impl TokioChannelExecutor { - fn inner_pin_mut<'a>(self: &'a mut Pin<&mut Self>) -> Pin<&'a mut T> { - self.as_mut().project().inner - } -} - -// Send + 'static execution helper methods. - -impl Requests -where - C: Channel, - C::Req: Send + 'static, - C::Resp: Send + 'static, -{ - /// Executes all requests using the given service function. Requests are handled concurrently - /// by [spawning](::tokio::spawn) each handler on tokio's default executor. - /// - /// # Example - /// - /// ```rust - /// use tarpc::{context, client, server::{self, BaseChannel, Channel, serve}, transport}; - /// use futures::prelude::*; - /// - /// #[tokio::main] - /// async fn main() { - /// let (tx, rx) = transport::channel::unbounded(); - /// let requests = BaseChannel::new(server::Config::default(), rx).requests(); - /// let client = client::new(client::Config::default(), tx).spawn(); - /// tokio::spawn(requests.execute(serve(|_, i| async move { Ok(i + 1) }))); - /// assert_eq!(client.call(context::current(), "AddOne", 1).await.unwrap(), 2); - /// } - /// ``` - pub fn execute(self, serve: S) -> TokioChannelExecutor - where - S: Serve + Send + 'static, - { - TokioChannelExecutor { inner: self, serve } - } -} - -impl Future for TokioServerExecutor -where - St: Sized + Stream, - C: Channel + Send + 'static, - C::Req: Send + 'static, - C::Resp: Send + 'static, - Se: Serve + Send + 'static + Clone, - Se::Fut: Send, -{ - type Output = (); - - fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> { - while let Some(channel) = ready!(self.inner_pin_mut().poll_next(cx)) { - tokio::spawn(channel.execute(self.serve.clone())); - } - tracing::info!("Server shutting down."); - Poll::Ready(()) - } -} - -impl Future for TokioChannelExecutor, S> -where - C: Channel + 'static, - C::Req: Send + 'static, - C::Resp: Send + 'static, - S: Serve + Send + 'static + Clone, - S::Fut: Send, -{ - type Output = (); - - fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - while let Some(response_handler) = ready!(self.inner_pin_mut().poll_next(cx)) { - match response_handler { - Ok(resp) => { - let server = self.serve.clone(); - tokio::spawn(async move { - resp.execute(server).await; - }); - } - Err(e) => { - tracing::warn!("Requests stream errored out: {}", e); - break; - } - } - } - Poll::Ready(()) - } -} diff --git a/tarpc/src/transport/channel.rs b/tarpc/src/transport/channel.rs index 7f3035d14..98ea0aac7 100644 --- a/tarpc/src/transport/channel.rs +++ b/tarpc/src/transport/channel.rs @@ -14,9 +14,15 @@ use tokio::sync::mpsc; /// Errors that occur in the sending or receiving of messages over a channel. #[derive(thiserror::Error, Debug)] pub enum ChannelError { - /// An error occurred sending over the channel. - #[error("an error occurred sending over the channel")] + /// An error occurred readying to send into the channel. + #[error("an error occurred readying to send into the channel")] + Ready(#[source] Box), + /// An error occurred sending into the channel. + #[error("an error occurred sending into the channel")] Send(#[source] Box), + /// An error occurred receiving from the channel. + #[error("an error occurred receiving from the channel")] + Receive(#[source] Box), } /// Returns two unbounded channel peers. Each [`Stream`] yields items sent through the other's @@ -48,7 +54,10 @@ impl Stream for UnboundedChannel { mut self: Pin<&mut Self>, cx: &mut Context<'_>, ) -> Poll>> { - self.rx.poll_recv(cx).map(|option| option.map(Ok)) + self.rx + .poll_recv(cx) + .map(|option| option.map(Ok)) + .map_err(ChannelError::Receive) } } @@ -59,7 +68,7 @@ impl Sink for UnboundedChannel { fn poll_ready(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll> { Poll::Ready(if self.tx.is_closed() { - Err(ChannelError::Send(CLOSED_MESSAGE.into())) + Err(ChannelError::Ready(CLOSED_MESSAGE.into())) } else { Ok(()) }) @@ -110,7 +119,11 @@ impl Stream for Channel { self: Pin<&mut Self>, cx: &mut Context<'_>, ) -> Poll>> { - self.project().rx.poll_next(cx).map(|option| option.map(Ok)) + self.project() + .rx + .poll_next(cx) + .map(|option| option.map(Ok)) + .map_err(ChannelError::Receive) } } @@ -121,7 +134,7 @@ impl Sink for Channel { self.project() .tx .poll_ready(cx) - .map_err(|e| ChannelError::Send(Box::new(e))) + .map_err(|e| ChannelError::Ready(Box::new(e))) } fn start_send(self: Pin<&mut Self>, item: SinkItem) -> Result<(), Self::Error> { @@ -146,8 +159,7 @@ impl Sink for Channel { } } -#[cfg(test)] -#[cfg(feature = "tokio1")] +#[cfg(all(test, feature = "tokio1"))] mod tests { use crate::{ client::{self, RpcError}, @@ -186,7 +198,10 @@ mod tests { format!("{request:?} is not an int"), ) }) - })), + })) + .for_each(|channel| async move { + tokio::spawn(channel.for_each(|response| response)); + }), ); let client = client::new(client::Config::default(), client_channel).spawn(); diff --git a/tarpc/tests/compile_fail.rs b/tarpc/tests/compile_fail.rs index 4c5a28ec9..c28fe2fa1 100644 --- a/tarpc/tests/compile_fail.rs +++ b/tarpc/tests/compile_fail.rs @@ -2,8 +2,6 @@ fn ui() { let t = trybuild::TestCases::new(); t.compile_fail("tests/compile_fail/*.rs"); - #[cfg(feature = "tokio1")] - t.compile_fail("tests/compile_fail/tokio/*.rs"); #[cfg(all(feature = "serde-transport", feature = "tcp"))] t.compile_fail("tests/compile_fail/serde_transport/*.rs"); } diff --git a/tarpc/tests/compile_fail/must_use_request_dispatch.rs b/tarpc/tests/compile_fail/must_use_request_dispatch.rs index 2915d3237..18cda0d90 100644 --- a/tarpc/tests/compile_fail/must_use_request_dispatch.rs +++ b/tarpc/tests/compile_fail/must_use_request_dispatch.rs @@ -1,3 +1,6 @@ +#![allow(incomplete_features)] +#![feature(async_fn_in_trait)] + use tarpc::client; #[tarpc::service] diff --git a/tarpc/tests/compile_fail/must_use_request_dispatch.stderr b/tarpc/tests/compile_fail/must_use_request_dispatch.stderr index f7aa3ea6c..d12912a86 100644 --- a/tarpc/tests/compile_fail/must_use_request_dispatch.stderr +++ b/tarpc/tests/compile_fail/must_use_request_dispatch.stderr @@ -1,11 +1,11 @@ error: unused `RequestDispatch` that must be used - --> tests/compile_fail/must_use_request_dispatch.rs:13:9 + --> tests/compile_fail/must_use_request_dispatch.rs:16:9 | -13 | WorldClient::new(client::Config::default(), client_transport).dispatch; +16 | WorldClient::new(client::Config::default(), client_transport).dispatch; | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ | note: the lint level is defined here - --> tests/compile_fail/must_use_request_dispatch.rs:11:12 + --> tests/compile_fail/must_use_request_dispatch.rs:14:12 | -11 | #[deny(unused_must_use)] +14 | #[deny(unused_must_use)] | ^^^^^^^^^^^^^^^ diff --git a/tarpc/tests/compile_fail/tarpc_server_missing_async.rs b/tarpc/tests/compile_fail/tarpc_server_missing_async.rs deleted file mode 100644 index 99d858b6d..000000000 --- a/tarpc/tests/compile_fail/tarpc_server_missing_async.rs +++ /dev/null @@ -1,15 +0,0 @@ -#[tarpc::service(derive_serde = false)] -trait World { - async fn hello(name: String) -> String; -} - -struct HelloServer; - -#[tarpc::server] -impl World for HelloServer { - fn hello(name: String) -> String { - format!("Hello, {name}!", name) - } -} - -fn main() {} diff --git a/tarpc/tests/compile_fail/tarpc_server_missing_async.stderr b/tarpc/tests/compile_fail/tarpc_server_missing_async.stderr deleted file mode 100644 index d96cda833..000000000 --- a/tarpc/tests/compile_fail/tarpc_server_missing_async.stderr +++ /dev/null @@ -1,15 +0,0 @@ -error: not all trait items implemented, missing: `HelloFut` - --> tests/compile_fail/tarpc_server_missing_async.rs:9:1 - | -9 | / impl World for HelloServer { -10 | | fn hello(name: String) -> String { -11 | | format!("Hello, {name}!", name) -12 | | } -13 | | } - | |_^ - -error: hint: `#[tarpc::server]` only rewrites async fns, and `fn hello` is not async - --> tests/compile_fail/tarpc_server_missing_async.rs:10:5 - | -10 | fn hello(name: String) -> String { - | ^^^^^^^^ diff --git a/tarpc/tests/compile_fail/tokio/must_use_channel_executor.rs b/tarpc/tests/compile_fail/tokio/must_use_channel_executor.rs deleted file mode 100644 index 6fc2f2bf3..000000000 --- a/tarpc/tests/compile_fail/tokio/must_use_channel_executor.rs +++ /dev/null @@ -1,29 +0,0 @@ -use tarpc::{ - context, - server::{self, Channel}, -}; - -#[tarpc::service] -trait World { - async fn hello(name: String) -> String; -} - -#[derive(Clone)] -struct HelloServer; - -#[tarpc::server] -impl World for HelloServer { - async fn hello(self, _: context::Context, name: String) -> String { - format!("Hello, {name}!") - } -} - -fn main() { - let (_, server_transport) = tarpc::transport::channel::unbounded(); - let server = server::BaseChannel::with_defaults(server_transport); - - #[deny(unused_must_use)] - { - server.execute(HelloServer.serve()); - } -} diff --git a/tarpc/tests/compile_fail/tokio/must_use_channel_executor.stderr b/tarpc/tests/compile_fail/tokio/must_use_channel_executor.stderr deleted file mode 100644 index 446f224f6..000000000 --- a/tarpc/tests/compile_fail/tokio/must_use_channel_executor.stderr +++ /dev/null @@ -1,11 +0,0 @@ -error: unused `TokioChannelExecutor` that must be used - --> tests/compile_fail/tokio/must_use_channel_executor.rs:27:9 - | -27 | server.execute(HelloServer.serve()); - | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - | -note: the lint level is defined here - --> tests/compile_fail/tokio/must_use_channel_executor.rs:25:12 - | -25 | #[deny(unused_must_use)] - | ^^^^^^^^^^^^^^^ diff --git a/tarpc/tests/compile_fail/tokio/must_use_server_executor.rs b/tarpc/tests/compile_fail/tokio/must_use_server_executor.rs deleted file mode 100644 index 950cf74e6..000000000 --- a/tarpc/tests/compile_fail/tokio/must_use_server_executor.rs +++ /dev/null @@ -1,30 +0,0 @@ -use futures::stream::once; -use tarpc::{ - context, - server::{self, incoming::Incoming}, -}; - -#[tarpc::service] -trait World { - async fn hello(name: String) -> String; -} - -#[derive(Clone)] -struct HelloServer; - -#[tarpc::server] -impl World for HelloServer { - async fn hello(self, _: context::Context, name: String) -> String { - format!("Hello, {name}!") - } -} - -fn main() { - let (_, server_transport) = tarpc::transport::channel::unbounded(); - let server = once(async move { server::BaseChannel::with_defaults(server_transport) }); - - #[deny(unused_must_use)] - { - server.execute(HelloServer.serve()); - } -} diff --git a/tarpc/tests/compile_fail/tokio/must_use_server_executor.stderr b/tarpc/tests/compile_fail/tokio/must_use_server_executor.stderr deleted file mode 100644 index 07d4b5a9b..000000000 --- a/tarpc/tests/compile_fail/tokio/must_use_server_executor.stderr +++ /dev/null @@ -1,11 +0,0 @@ -error: unused `TokioServerExecutor` that must be used - --> tests/compile_fail/tokio/must_use_server_executor.rs:28:9 - | -28 | server.execute(HelloServer.serve()); - | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - | -note: the lint level is defined here - --> tests/compile_fail/tokio/must_use_server_executor.rs:26:12 - | -26 | #[deny(unused_must_use)] - | ^^^^^^^^^^^^^^^ diff --git a/tarpc/tests/dataservice.rs b/tarpc/tests/dataservice.rs index 365594bd4..7cd3cb8c7 100644 --- a/tarpc/tests/dataservice.rs +++ b/tarpc/tests/dataservice.rs @@ -1,3 +1,6 @@ +#![allow(incomplete_features)] +#![feature(async_fn_in_trait)] + use futures::prelude::*; use tarpc::serde_transport; use tarpc::{ @@ -21,7 +24,6 @@ pub trait ColorProtocol { #[derive(Clone)] struct ColorServer; -#[tarpc::server] impl ColorProtocol for ColorServer { async fn get_opposite_color(self, _: context::Context, color: TestData) -> TestData { match color { @@ -31,6 +33,11 @@ impl ColorProtocol for ColorServer { } } +#[cfg(test)] +async fn spawn(fut: impl Future + Send + 'static) { + tokio::spawn(fut); +} + #[tokio::test] async fn test_call() -> anyhow::Result<()> { let transport = tarpc::serde_transport::tcp::listen("localhost:56797", Json::default).await?; @@ -40,7 +47,9 @@ async fn test_call() -> anyhow::Result<()> { .take(1) .filter_map(|r| async { r.ok() }) .map(BaseChannel::with_defaults) - .execute(ColorServer.serve()), + .execute(ColorServer.serve()) + .map(|channel| channel.for_each(spawn)) + .for_each(spawn), ); let transport = serde_transport::tcp::connect(addr, Json::default).await?; diff --git a/tarpc/tests/service_functional.rs b/tarpc/tests/service_functional.rs index 50d19b0e9..9041aae73 100644 --- a/tarpc/tests/service_functional.rs +++ b/tarpc/tests/service_functional.rs @@ -1,13 +1,16 @@ +#![allow(incomplete_features)] +#![feature(async_fn_in_trait)] + use assert_matches::assert_matches; use futures::{ - future::{join_all, ready, Ready}, + future::{join_all, ready}, prelude::*, }; use std::time::{Duration, SystemTime}; use tarpc::{ client::{self}, context, - server::{self, incoming::Incoming, BaseChannel, Channel}, + server::{incoming::Incoming, BaseChannel, Channel}, transport::channel, }; use tokio::join; @@ -22,39 +25,29 @@ trait Service { struct Server; impl Service for Server { - type AddFut = Ready; - - fn add(self, _: context::Context, x: i32, y: i32) -> Self::AddFut { - ready(x + y) + async fn add(self, _: context::Context, x: i32, y: i32) -> i32 { + x + y } - type HeyFut = Ready; - - fn hey(self, _: context::Context, name: String) -> Self::HeyFut { - ready(format!("Hey, {name}.")) + async fn hey(self, _: context::Context, name: String) -> String { + format!("Hey, {name}.") } } #[tokio::test] -async fn sequential() -> anyhow::Result<()> { - let _ = tracing_subscriber::fmt::try_init(); - - let (tx, rx) = channel::unbounded(); - +async fn sequential() { + let (tx, rx) = tarpc::transport::channel::unbounded(); + let client = client::new(client::Config::default(), tx).spawn(); + let channel = BaseChannel::with_defaults(rx); tokio::spawn( - BaseChannel::new(server::Config::default(), rx) - .requests() - .execute(Server.serve()), + channel + .execute(tarpc::server::serve(|_, i| async move { Ok(i + 1) })) + .for_each(|response| response), + ); + assert_eq!( + client.call(context::current(), "AddOne", 1).await.unwrap(), + 2 ); - - let client = ServiceClient::new(client::Config::default(), tx).spawn(); - - assert_matches!(client.add(context::current(), 1, 2).await, Ok(3)); - assert_matches!( - client.hey(context::current(), "Tim".into()).await, - Ok(ref s) if s == "Hey, Tim."); - - Ok(()) } #[tokio::test] @@ -70,7 +63,6 @@ async fn dropped_channel_aborts_in_flight_requests() -> anyhow::Result<()> { #[derive(Debug)] struct AllHandlersComplete; - #[tarpc::server] impl Loop for LoopServer { async fn r#loop(self, _: context::Context) { loop { @@ -121,7 +113,9 @@ async fn serde_tcp() -> anyhow::Result<()> { .take(1) .filter_map(|r| async { r.ok() }) .map(BaseChannel::with_defaults) - .execute(Server.serve()), + .execute(Server.serve()) + .map(|channel| channel.for_each(spawn)) + .for_each(spawn), ); let transport = serde_transport::tcp::connect(addr, Json::default).await?; @@ -151,7 +145,9 @@ async fn serde_uds() -> anyhow::Result<()> { .take(1) .filter_map(|r| async { r.ok() }) .map(BaseChannel::with_defaults) - .execute(Server.serve()), + .execute(Server.serve()) + .map(|channel| channel.for_each(spawn)) + .for_each(spawn), ); let transport = serde_transport::unix::connect(&sock, Json::default).await?; @@ -175,7 +171,9 @@ async fn concurrent() -> anyhow::Result<()> { tokio::spawn( stream::once(ready(rx)) .map(BaseChannel::with_defaults) - .execute(Server.serve()), + .execute(Server.serve()) + .map(|channel| channel.for_each(spawn)) + .for_each(spawn), ); let client = ServiceClient::new(client::Config::default(), tx).spawn(); @@ -199,7 +197,9 @@ async fn concurrent_join() -> anyhow::Result<()> { tokio::spawn( stream::once(ready(rx)) .map(BaseChannel::with_defaults) - .execute(Server.serve()), + .execute(Server.serve()) + .map(|channel| channel.for_each(spawn)) + .for_each(spawn), ); let client = ServiceClient::new(client::Config::default(), tx).spawn(); @@ -216,15 +216,20 @@ async fn concurrent_join() -> anyhow::Result<()> { Ok(()) } +#[cfg(test)] +async fn spawn(fut: impl Future + Send + 'static) { + tokio::spawn(fut); +} + #[tokio::test] async fn concurrent_join_all() -> anyhow::Result<()> { let _ = tracing_subscriber::fmt::try_init(); let (tx, rx) = channel::unbounded(); tokio::spawn( - stream::once(ready(rx)) - .map(BaseChannel::with_defaults) - .execute(Server.serve()), + BaseChannel::with_defaults(rx) + .execute(Server.serve()) + .for_each(spawn), ); let client = ServiceClient::new(client::Config::default(), tx).spawn(); @@ -249,11 +254,9 @@ async fn counter() -> anyhow::Result<()> { struct CountService(u32); impl Counter for &mut CountService { - type CountFut = futures::future::Ready; - - fn count(self, _: context::Context) -> Self::CountFut { + async fn count(self, _: context::Context) -> u32 { self.0 += 1; - futures::future::ready(self.0) + self.0 } } From e23bba714aafda8fbf87ae09fadb850f32318d75 Mon Sep 17 00:00:00 2001 From: Tim Kuehn Date: Wed, 23 Nov 2022 16:39:29 -0800 Subject: [PATCH 05/30] Replace actions-rs --- .github/workflows/main.yml | 76 +++++++++----------------------------- 1 file changed, 17 insertions(+), 59 deletions(-) diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index 1ca595603..198475a23 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -18,20 +18,11 @@ jobs: with: access_token: ${{ github.token }} - uses: actions/checkout@v3 - - uses: actions-rs/toolchain@v1 + - uses: dtolnay/rust-toolchain@nightly with: - profile: minimal - toolchain: nightly - target: mipsel-unknown-linux-gnu - override: true - - uses: actions-rs/cargo@v1 - with: - command: check - args: --all-features - - uses: actions-rs/cargo@v1 - with: - command: check - args: --all-features --target mipsel-unknown-linux-gnu + targets: mipsel-unknown-linux-gnu + - run: cargo check --all-features + - run: cargo check --all-features --target mipsel-unknown-linux-gnu test: name: Test Suite @@ -42,34 +33,13 @@ jobs: with: access_token: ${{ github.token }} - uses: actions/checkout@v3 - - uses: actions-rs/toolchain@v1 - with: - profile: minimal - toolchain: nightly - override: true - - uses: actions-rs/cargo@v1 - with: - command: test - - uses: actions-rs/cargo@v1 - with: - command: test - args: --manifest-path tarpc/Cargo.toml --features serde1 - - uses: actions-rs/cargo@v1 - with: - command: test - args: --manifest-path tarpc/Cargo.toml --features tokio1 - - uses: actions-rs/cargo@v1 - with: - command: test - args: --manifest-path tarpc/Cargo.toml --features serde-transport - - uses: actions-rs/cargo@v1 - with: - command: test - args: --manifest-path tarpc/Cargo.toml --features tcp - - uses: actions-rs/cargo@v1 - with: - command: test - args: --all-features + - uses: dtolnay/rust-toolchain@nightly + - run: cargo test + - run: cargo test --manifest-path tarpc/Cargo.toml --features serde1 + - run: cargo test --manifest-path tarpc/Cargo.toml --features tokio1 + - run: cargo test --manifest-path tarpc/Cargo.toml --features serde-transport + - run: cargo test --manifest-path tarpc/Cargo.toml --features tcp + - run: cargo test --all-features fmt: name: Rustfmt @@ -80,16 +50,10 @@ jobs: with: access_token: ${{ github.token }} - uses: actions/checkout@v3 - - uses: actions-rs/toolchain@v1 + - uses: dtolnay/rust-toolchain@nightly with: - profile: minimal - toolchain: nightly - override: true - - run: rustup component add rustfmt - - uses: actions-rs/cargo@v1 - with: - command: fmt - args: --all -- --check + components: rustfmt + - run: cargo fmt --all -- --check clippy: name: Clippy @@ -100,13 +64,7 @@ jobs: with: access_token: ${{ github.token }} - uses: actions/checkout@v3 - - uses: actions-rs/toolchain@v1 - with: - profile: minimal - toolchain: nightly - override: true - - run: rustup component add clippy - - uses: actions-rs/cargo@v1 + - uses: dtolnay/rust-toolchain@nightly with: - command: clippy - args: --all-features -- -D warnings + components: clippy + - run: cargo clippy --all-features -- -D warnings From ff26d02925a4dbe61fcf645ea67a58235610dbf1 Mon Sep 17 00:00:00 2001 From: Tim Kuehn Date: Wed, 23 Nov 2022 17:45:43 -0800 Subject: [PATCH 06/30] Remove bad mem::forget usage. mem::forget is a dangerous tool, and it was being used carelessly for things that have safer alternatives. There was at least one bug where a cloned tokio::sync::mpsc::UnboundedSender used for request cancellation was being leaked on every successful server response, so its refcounts were never decremented. Because these are atomic refcounts, they'll wrap around rather than overflow when reaching the maximum value, so I don't believe this could lead to panics or unsoundness. --- tarpc/src/client.rs | 13 ++++++++++--- tarpc/src/server.rs | 31 +++++++++++++++---------------- tarpc/src/server/testing.rs | 7 ++++--- 3 files changed, 29 insertions(+), 22 deletions(-) diff --git a/tarpc/src/client.rs b/tarpc/src/client.rs index 3d4b1ed6a..24e678952 100644 --- a/tarpc/src/client.rs +++ b/tarpc/src/client.rs @@ -19,7 +19,7 @@ use pin_project::pin_project; use std::{ convert::TryFrom, error::Error, - fmt, mem, + fmt, pin::Pin, sync::{ atomic::{AtomicUsize, Ordering}, @@ -148,6 +148,7 @@ impl Channel { response: &mut response, request_id, cancellation: &self.cancellation, + cancel: true, }; self.to_dispatch .send(DispatchRequest { @@ -169,6 +170,7 @@ struct ResponseGuard<'a, Resp> { response: &'a mut oneshot::Receiver, DeadlineExceededError>>, cancellation: &'a RequestCancellation, request_id: u64, + cancel: bool, } /// An error that can occur in the processing of an RPC. This is not request-specific errors but @@ -197,7 +199,7 @@ impl ResponseGuard<'_, Resp> { async fn response(mut self) -> Result { let response = (&mut self.response).await; // Cancel drop logic once a response has been received. - mem::forget(self); + self.cancel = false; match response { Ok(resp) => Ok(resp?.message?), Err(oneshot::error::RecvError { .. }) => { @@ -224,7 +226,9 @@ impl Drop for ResponseGuard<'_, Resp> { // dispatch task misses an early-arriving cancellation message, then it will see the // receiver as closed. self.response.close(); - self.cancellation.cancel(self.request_id); + if self.cancel { + self.cancellation.cancel(self.request_id); + } } } @@ -655,6 +659,7 @@ mod tests { response: &mut response, cancellation: &cancellation, request_id: 3, + cancel: true, }); // resp's drop() is run, which should send a cancel message. let cx = &mut Context::from_waker(noop_waker_ref()); @@ -675,6 +680,7 @@ mod tests { response: &mut response, cancellation: &cancellation, request_id: 3, + cancel: true, } .response() .await @@ -831,6 +837,7 @@ mod tests { response, cancellation: &channel.cancellation, request_id, + cancel: true, }; channel.to_dispatch.send(request).await.unwrap(); response_guard diff --git a/tarpc/src/server.rs b/tarpc/src/server.rs index b9c95f068..d13fb279b 100644 --- a/tarpc/src/server.rs +++ b/tarpc/src/server.rs @@ -21,14 +21,7 @@ use futures::{ }; use in_flight_requests::{AlreadyExistsError, InFlightRequests}; use pin_project::pin_project; -use std::{ - convert::TryFrom, - error::Error, - fmt, - marker::PhantomData, - mem::{self, ManuallyDrop}, - pin::Pin, -}; +use std::{convert::TryFrom, error::Error, fmt, marker::PhantomData, pin::Pin}; use tracing::{info_span, instrument::Instrument, Span}; mod in_flight_requests; @@ -384,10 +377,11 @@ where Ok(TrackedRequest { abort_registration, span, - response_guard: ManuallyDrop::new(ResponseGuard { + response_guard: ResponseGuard { request_id: request.id, request_cancellation: self.request_cancellation.clone(), - }), + cancel: false, + }, request, }) } @@ -416,7 +410,7 @@ pub struct TrackedRequest { /// A span representing the server processing of this request. pub span: Span, /// An inert response guard. Becomes active in an InFlightRequest. - pub response_guard: ManuallyDrop, + pub response_guard: ResponseGuard, } /// The server end of an open connection with a client, receiving requests from, and sending @@ -811,17 +805,19 @@ where request, abort_registration, span, - response_guard, + mut response_guard, }| { { let _entered = span.enter(); tracing::info!("BeginRequest"); } + // The response guard becomes active once in an InFlightRequest. + response_guard.cancel = true; InFlightRequest { request, abort_registration, span, - response_guard: ManuallyDrop::into_inner(response_guard), + response_guard, response_tx: self.responses_tx.clone(), } }, @@ -953,11 +949,14 @@ where pub struct ResponseGuard { request_cancellation: RequestCancellation, request_id: u64, + cancel: bool, } impl Drop for ResponseGuard { fn drop(&mut self) { - self.request_cancellation.cancel(self.request_id); + if self.cancel { + self.request_cancellation.cancel(self.request_id); + } } } @@ -1030,7 +1029,7 @@ impl InFlightRequest { { let Self { response_tx, - response_guard, + mut response_guard, abort_registration, span, request: @@ -1063,7 +1062,7 @@ impl InFlightRequest { // Request processing has completed, meaning either the channel canceled the request or // a request was sent back to the channel. Either way, the channel will clean up the // request data, so the request does not need to be canceled. - mem::forget(response_guard); + response_guard.cancel = false; } } diff --git a/tarpc/src/server/testing.rs b/tarpc/src/server/testing.rs index 15e187b3c..938865c0f 100644 --- a/tarpc/src/server/testing.rs +++ b/tarpc/src/server/testing.rs @@ -12,7 +12,7 @@ use crate::{ }; use futures::{task::*, Sink, Stream}; use pin_project::pin_project; -use std::{collections::VecDeque, io, mem::ManuallyDrop, pin::Pin, time::SystemTime}; +use std::{collections::VecDeque, io, pin::Pin, time::SystemTime}; use tracing::Span; #[pin_project] @@ -101,10 +101,11 @@ impl FakeChannel>, Response> { }, abort_registration, span: Span::none(), - response_guard: ManuallyDrop::new(ResponseGuard { + response_guard: ResponseGuard { request_cancellation, request_id: id, - }), + cancel: false, + }, })); } } From 3f47cce3b9adb63d02da2f93872525b6fe372fd9 Mon Sep 17 00:00:00 2001 From: Akos Vandra-Meyer Date: Fri, 10 Feb 2023 18:31:33 +0100 Subject: [PATCH 07/30] Add ability to create a BaseChannel with a Transport dealing with RequestContexts. To do this, create a Transport with a Sink/Stream of (C, Item/SinkItem). C created in the stream will be opaqualy sent back when sinking the response on the server side. --- tarpc/src/server.rs | 120 +++++++++++++------------ tarpc/src/server/in_flight_requests.rs | 30 +++++-- 2 files changed, 85 insertions(+), 65 deletions(-) diff --git a/tarpc/src/server.rs b/tarpc/src/server.rs index d13fb279b..3f71d4e24 100644 --- a/tarpc/src/server.rs +++ b/tarpc/src/server.rs @@ -58,9 +58,9 @@ impl Default for Config { impl Config { /// Returns a channel backed by `transport` and configured with `self`. - pub fn channel(self, transport: T) -> BaseChannel + pub fn channel(self, transport: T) -> BaseChannel where - T: Transport, ClientMessage>, + T: Transport<(C, Response), (C, ClientMessage)>, { BaseChannel::new(self, transport) } @@ -283,7 +283,7 @@ where /// messages. Instead, it internally handles them by cancelling corresponding requests (removing /// the corresponding in-flight requests and aborting their handlers). #[pin_project] -pub struct BaseChannel { +pub struct BaseChannel { config: Config, /// Writes responses to the wire and reads requests off the wire. #[pin] @@ -294,14 +294,14 @@ pub struct BaseChannel { /// Notifies `canceled_requests` when a request is canceled. request_cancellation: RequestCancellation, /// Holds data necessary to clean up in-flight requests. - in_flight_requests: InFlightRequests, + in_flight_requests: InFlightRequests, /// Types the request and response. - ghost: PhantomData<(fn() -> Req, fn(Resp))>, + ghost: PhantomData<(fn() -> Req, fn(Resp), C)>, } -impl BaseChannel +impl BaseChannel where - T: Transport, ClientMessage>, + T: Transport<(C, Response), (C, ClientMessage)>, { /// Creates a new channel backed by `transport` and configured with `config`. pub fn new(config: Config, transport: T) -> Self { @@ -331,7 +331,7 @@ where self.project().transport.get_pin_mut() } - fn in_flight_requests_mut<'a>(self: &'a mut Pin<&mut Self>) -> &'a mut InFlightRequests { + fn in_flight_requests_mut<'a>(self: &'a mut Pin<&mut Self>) -> &'a mut InFlightRequests { self.as_mut().project().in_flight_requests } @@ -347,8 +347,9 @@ where fn start_request( mut self: Pin<&mut Self>, - mut request: Request, + request: (C, Request), ) -> Result, AlreadyExistsError> { + let (context, mut request) = request; let span = info_span!( "RPC", rpc.trace_id = %request.context.trace_id(), @@ -369,6 +370,7 @@ where let start = self.in_flight_requests_mut().start_request( request.id, request.context.deadline, + context, span.clone(), ); match start { @@ -393,7 +395,7 @@ where } } -impl fmt::Debug for BaseChannel { +impl fmt::Debug for BaseChannel { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { write!(f, "BaseChannel") } @@ -413,6 +415,20 @@ pub struct TrackedRequest { pub response_guard: ResponseGuard, } +/// Critical errors that result in a Channel disconnecting. +#[derive(thiserror::Error, Debug)] +pub enum ChannelError +where + E: Error + Send + Sync + 'static, +{ + /// An error occurred reading from, or writing to, the transport. + #[error("an error occurred in the transport")] + Transport(#[source] E), + /// An error occurred while polling expired requests. + #[error("an error occurred while polling expired requests")] + Timer(#[source] ::tokio::time::error::Error), +} + /// The server end of an open connection with a client, receiving requests from, and sending /// responses to, the client. `Channel` is a [`Transport`] with request lifecycle management. /// @@ -439,8 +455,8 @@ pub struct TrackedRequest { /// `TrackedRequest` is to get one from another `Channel`. Ultimately, all `TrackedRequests` are /// created by [`BaseChannel`]. pub trait Channel -where - Self: Transport::Resp>, TrackedRequest<::Req>>, + where + Self: Transport::Resp>, TrackedRequest<::Req>>, { /// Type of request item. type Req; @@ -471,8 +487,8 @@ where self, limit: usize, ) -> limits::requests_per_channel::MaxRequests - where - Self: Sized, + where + Self: Sized, { limits::requests_per_channel::MaxRequests::new(self, limit) } @@ -512,8 +528,8 @@ where /// } /// ``` fn requests(self) -> Requests - where - Self: Sized, + where + Self: Sized, { let (responses_tx, responses) = mpsc::channel(self.config().pending_response_buffer); @@ -557,31 +573,17 @@ where /// } /// ``` fn execute(self, serve: S) -> impl Stream> - where - Self: Sized, - S: Serve + Clone, + where + Self: Sized, + S: Serve + Clone, { self.requests().execute(serve) } } -/// Critical errors that result in a Channel disconnecting. -#[derive(thiserror::Error, Debug)] -pub enum ChannelError -where - E: Error + Send + Sync + 'static, -{ - /// An error occurred reading from, or writing to, the transport. - #[error("an error occurred in the transport")] - Transport(#[source] E), - /// An error occurred while polling expired requests. - #[error("an error occurred while polling expired requests")] - Timer(#[source] ::tokio::time::error::Error), -} - -impl Stream for BaseChannel +impl Stream for BaseChannel where - T: Transport, ClientMessage>, + T: Transport<(C, Response), (C, ClientMessage)>, { type Item = Result, ChannelError>; @@ -609,7 +611,7 @@ where loop { let cancellation_status = match self.canceled_requests_pin_mut().poll_recv(cx) { Poll::Ready(Some(request_id)) => { - if let Some(span) = self.in_flight_requests_mut().remove_request(request_id) { + if let Some((ctx, span)) = self.in_flight_requests_mut().remove_request(request_id) { let _entered = span.enter(); tracing::info!("ResponseCancelled"); } @@ -638,8 +640,8 @@ where .map_err(ChannelError::Transport)? { Poll::Ready(Some(message)) => match message { - ClientMessage::Request(request) => { - match self.as_mut().start_request(request) { + (ctx, ClientMessage::Request(request)) => { + match self.as_mut().start_request((ctx,request)) { Ok(request) => return Poll::Ready(Some(Ok(request))), Err(AlreadyExistsError) => { // Instead of closing the channel if a duplicate request is sent, @@ -650,10 +652,10 @@ where } } } - ClientMessage::Cancel { + (_ctx, ClientMessage::Cancel { trace_context, request_id, - } => { + }) => { if !self.in_flight_requests_mut().cancel_request(request_id) { tracing::trace!( rpc.trace_id = %trace_context.trace_id, @@ -686,9 +688,9 @@ where } } -impl Sink> for BaseChannel +impl Sink> for BaseChannel where - T: Transport, ClientMessage>, + T: Transport<(C, Response), (C, ClientMessage)>, T::Error: Error, { type Error = ChannelError; @@ -701,7 +703,7 @@ where } fn start_send(mut self: Pin<&mut Self>, response: Response) -> Result<(), Self::Error> { - if let Some(span) = self + if let Some((ctx, span)) = self .in_flight_requests_mut() .remove_request(response.request_id) { @@ -709,7 +711,7 @@ where tracing::info!("SendResponse"); self.project() .transport - .start_send(response) + .start_send((ctx, response)) .map_err(ChannelError::Transport) } else { // If the request isn't tracked anymore, there's no need to send the response. @@ -733,20 +735,21 @@ where } } -impl AsRef for BaseChannel { +impl AsRef for BaseChannel { fn as_ref(&self) -> &T { self.transport.get_ref() } } -impl Channel for BaseChannel +implChannel for BaseChannel where - T: Transport, ClientMessage>, + T: Transport<(C, Response), (C, ClientMessage)>, { type Req = Req; type Resp = Resp; type Transport = T; + fn config(&self) -> &Config { &self.config } @@ -770,9 +773,9 @@ where #[pin] channel: C, /// Responses waiting to be written to the wire. - pending_responses: mpsc::Receiver>, + pending_responses: mpsc::Receiver<((), Response)>, /// Handed out to request handlers to fan in responses. - responses_tx: mpsc::Sender>, + responses_tx: mpsc::Sender<((), Response)>, } impl Requests @@ -792,14 +795,14 @@ where /// Returns the inner channel over which messages are sent and received. pub fn pending_responses_mut<'a>( self: &'a mut Pin<&mut Self>, - ) -> &'a mut mpsc::Receiver> { + ) -> &'a mut mpsc::Receiver<((), Response)> { self.as_mut().project().pending_responses } fn pump_read( mut self: Pin<&mut Self>, cx: &mut Context<'_>, - ) -> Poll, C::Error>>> { + ) -> Poll, C::Error>>> { self.channel_pin_mut().poll_next(cx).map_ok( |TrackedRequest { request, @@ -815,6 +818,7 @@ where response_guard.cancel = true; InFlightRequest { request, + transport_ctx: (), abort_registration, span, response_guard, @@ -868,7 +872,7 @@ where ready!(self.ensure_writeable(cx)?); match ready!(self.pending_responses_mut().poll_recv(cx)) { - Some(response) => Poll::Ready(Some(Ok(response))), + Some((c, response)) => Poll::Ready(Some(Ok(response))), None => { // This branch likely won't happen, since the Requests stream is holding a Sender. Poll::Ready(None) @@ -965,15 +969,16 @@ impl Drop for ResponseGuard { /// If dropped without calling [`execute`](InFlightRequest::execute), a cancellation message will /// be sent to the Channel to clean up associated request state. #[derive(Debug)] -pub struct InFlightRequest { +pub struct InFlightRequest { request: Request, + transport_ctx: C, abort_registration: AbortRegistration, response_guard: ResponseGuard, span: Span, - response_tx: mpsc::Sender>, + response_tx: mpsc::Sender<(C, Response)>, } -impl InFlightRequest { +impl InFlightRequest { /// Returns a reference to the request. pub fn get(&self) -> &Request { &self.request @@ -1029,6 +1034,7 @@ impl InFlightRequest { { let Self { response_tx, + transport_ctx, mut response_guard, abort_registration, span, @@ -1052,7 +1058,7 @@ impl InFlightRequest { request_id, message, }; - let _ = response_tx.send(response).await; + let _ = response_tx.send((transport_ctx, response)).await; tracing::info!("BufferResponse"); }, abort_registration, @@ -1077,7 +1083,7 @@ impl Stream for Requests where C: Channel, { - type Item = Result, C::Error>; + type Item = Result, C::Error>; fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { loop { diff --git a/tarpc/src/server/in_flight_requests.rs b/tarpc/src/server/in_flight_requests.rs index 1f8815f40..c101e1456 100644 --- a/tarpc/src/server/in_flight_requests.rs +++ b/tarpc/src/server/in_flight_requests.rs @@ -11,21 +11,32 @@ use tracing::Span; /// A data structure that tracks in-flight requests. It aborts requests, /// either on demand or when a request deadline expires. -#[derive(Debug, Default)] -pub struct InFlightRequests { - request_data: FnvHashMap, +#[derive(Debug)] +pub struct InFlightRequests { + request_data: FnvHashMap>, deadlines: DelayQueue, } +impl Default for InFlightRequests { + fn default() -> Self { + InFlightRequests { + request_data: FnvHashMap::with_hasher(Default::default()), + deadlines: Default::default(), + } + } +} + /// Data needed to clean up a single in-flight request. #[derive(Debug)] -struct RequestData { +struct RequestData { /// Aborts the response handler for the associated request. abort_handle: AbortHandle, /// The key to remove the timer for the request's deadline. deadline_key: delay_queue::Key, /// The client span. span: Span, + + context: C } /// An error returned when a request attempted to start with the same ID as a request already @@ -33,7 +44,7 @@ struct RequestData { #[derive(Debug)] pub struct AlreadyExistsError; -impl InFlightRequests { +impl InFlightRequests { /// Returns the number of in-flight requests. pub fn len(&self) -> usize { self.request_data.len() @@ -44,6 +55,7 @@ impl InFlightRequests { &mut self, request_id: u64, deadline: SystemTime, + context: C, span: Span, ) -> Result { match self.request_data.entry(request_id) { @@ -55,6 +67,7 @@ impl InFlightRequests { abort_handle, deadline_key, span, + context, }); Ok(abort_registration) } @@ -66,6 +79,7 @@ impl InFlightRequests { pub fn cancel_request(&mut self, request_id: u64) -> bool { if let Some(RequestData { span, + context, abort_handle, deadline_key, }) = self.request_data.remove(&request_id) @@ -83,11 +97,11 @@ impl InFlightRequests { /// Removes a request without aborting. Returns true iff the request was found. /// This method should be used when a response is being sent. - pub fn remove_request(&mut self, request_id: u64) -> Option { + pub fn remove_request(&mut self, request_id: u64) -> Option<(C, Span)> { if let Some(request_data) = self.request_data.remove(&request_id) { self.request_data.compact(0.1); self.deadlines.remove(&request_data.deadline_key); - Some(request_data.span) + Some((request_data.context, request_data.span)) } else { None } @@ -117,7 +131,7 @@ impl InFlightRequests { } /// When InFlightRequests is dropped, any outstanding requests are aborted. -impl Drop for InFlightRequests { +impl Drop for InFlightRequests { fn drop(&mut self) { self.request_data .values() From 60edf4f32476bfaca335c05d67a4177d67cb51c0 Mon Sep 17 00:00:00 2001 From: Akos Vandra-Meyer Date: Sun, 12 Feb 2023 20:48:28 +0100 Subject: [PATCH 08/30] add request sequencer --- tarpc/src/client.rs | 53 +++++++++++++++++++++++++++++++++++---------- 1 file changed, 41 insertions(+), 12 deletions(-) diff --git a/tarpc/src/client.rs b/tarpc/src/client.rs index 24e678952..26e45acbf 100644 --- a/tarpc/src/client.rs +++ b/tarpc/src/client.rs @@ -22,10 +22,11 @@ use std::{ fmt, pin::Pin, sync::{ - atomic::{AtomicUsize, Ordering}, + atomic::{AtomicU64, Ordering}, Arc, }, }; +use std::fmt::Debug; use tokio::sync::{mpsc, oneshot}; use tracing::Span; @@ -41,13 +42,26 @@ pub struct Config { /// `pending_requests_buffer` controls the size of the channel clients use /// to communicate with the request dispatch task. pub pending_request_buffer: usize, + + /// An implementation of RequestSequencer, to provide a unique series of request ids. + /// The default implementation generates 0,1,2,3,4,5,..., but this option can be leveraged + /// to generate less predictable results, using a block cipher for example. + pub request_sequencer: Arc } impl Default for Config { fn default() -> Self { + Self::with_sequencer(DefaultSequencer::default()) + } +} + +impl Config { + /// Create a default config with a specific sequencer + pub fn with_sequencer(s: S) -> Self { Config { max_in_flight_requests: 1_000, pending_request_buffer: 100, + request_sequencer: Arc::new(s) } } } @@ -90,6 +104,24 @@ const _CHECK_USIZE: () = assert!( "usize is too big to fit in u64" ); +/// Provides a stream of unique u64 numbers +pub trait RequestSequencer: Debug + Send + Sync + 'static { + + /// Generates the next number. + fn next_id(&self) -> u64; +} + +/// Default sequencer producing the numbers 0,1,2,3,4... +#[derive(Clone, Default, Debug)] +pub struct DefaultSequencer(Arc); + +impl RequestSequencer for DefaultSequencer { + fn next_id(&self) -> u64 { + println!("DEFSEQ {:?}", &self.0); + self.0.fetch_add(1, Ordering::Relaxed) + } +} + /// Handles communication from the client to request dispatch. #[derive(Debug)] pub struct Channel { @@ -97,7 +129,7 @@ pub struct Channel { /// Channel to send a cancel message to the dispatcher. cancellation: RequestCancellation, /// The ID to use for the next request to stage. - next_request_id: Arc, + request_sequencer: Arc, } impl Clone for Channel { @@ -105,7 +137,7 @@ impl Clone for Channel { Self { to_dispatch: self.to_dispatch.clone(), cancellation: self.cancellation.clone(), - next_request_id: self.next_request_id.clone(), + request_sequencer: self.request_sequencer.clone(), } } } @@ -126,7 +158,7 @@ impl Channel { &self, mut ctx: context::Context, request_name: &'static str, - request: Req, + request: Req ) -> Result { let span = Span::current(); ctx.trace_context = trace::Context::try_from(&span).unwrap_or_else(|_| { @@ -137,8 +169,7 @@ impl Channel { }); span.record("rpc.trace_id", &tracing::field::display(ctx.trace_id())); let (response_completion, mut response) = oneshot::channel(); - let request_id = - u64::try_from(self.next_request_id.fetch_add(1, Ordering::Relaxed)).unwrap(); + let request_id = self.request_sequencer.next_id(); // ResponseGuard impls Drop to cancel in-flight requests. It should be created before // sending out the request; otherwise, the response future could be dropped after the @@ -249,7 +280,7 @@ where client: Channel { to_dispatch, cancellation, - next_request_id: Arc::new(AtomicUsize::new(0)), + request_sequencer: config.request_sequencer.clone(), }, dispatch: RequestDispatch { config, @@ -622,13 +653,12 @@ mod tests { use assert_matches::assert_matches; use futures::{prelude::*, task::*}; use std::{ - convert::TryFrom, pin::Pin, - sync::atomic::{AtomicUsize, Ordering}, sync::Arc, }; use tokio::sync::{mpsc, oneshot}; use tracing::Span; + use crate::client::DefaultSequencer; #[tokio::test] async fn response_completes_request_future() { @@ -812,7 +842,7 @@ mod tests { let channel = Channel { to_dispatch, cancellation, - next_request_id: Arc::new(AtomicUsize::new(0)), + request_sequencer: Arc::new(DefaultSequencer::default()) }; (Box::pin(dispatch), channel, server_channel) @@ -824,8 +854,7 @@ mod tests { response_completion: oneshot::Sender, DeadlineExceededError>>, response: &'a mut oneshot::Receiver, DeadlineExceededError>>, ) -> ResponseGuard<'a, String> { - let request_id = - u64::try_from(channel.next_request_id.fetch_add(1, Ordering::Relaxed)).unwrap(); + let request_id = channel.request_sequencer.next_id(); let request = DispatchRequest { ctx: context::current(), span: Span::current(), From f52cd0fd0def08590acf579471e91f74a8ec24bb Mon Sep 17 00:00:00 2001 From: Akos Vandra-Meyer Date: Mon, 13 Feb 2023 00:41:09 +0100 Subject: [PATCH 09/30] revert base channel to having unit context, and move the contextual channel to a separate struct. --- tarpc/src/server.rs | 83 +++-- tarpc/src/server/contextual_channel.rs | 326 ++++++++++++++++++ tarpc/src/server/in_flight_requests.rs | 9 +- .../src/server/limits/requests_per_channel.rs | 2 + tarpc/src/server/testing.rs | 2 +- 5 files changed, 381 insertions(+), 41 deletions(-) create mode 100644 tarpc/src/server/contextual_channel.rs diff --git a/tarpc/src/server.rs b/tarpc/src/server.rs index 3f71d4e24..9d04fa805 100644 --- a/tarpc/src/server.rs +++ b/tarpc/src/server.rs @@ -31,6 +31,9 @@ mod testing; /// Provides functionality to apply server limits. pub mod limits; +mod contextual_channel; + +pub use contextual_channel::*; /// Provides helper methods for streams of Channels. pub mod incoming; @@ -58,12 +61,21 @@ impl Default for Config { impl Config { /// Returns a channel backed by `transport` and configured with `self`. - pub fn channel(self, transport: T) -> BaseChannel + pub fn channel(self, transport: T) -> BaseChannel where - T: Transport<(C, Response), (C, ClientMessage)>, + T: Transport, ClientMessage>, { BaseChannel::new(self, transport) } + + /// Returns a contextual channel backed by `transport` and configured with `self`. + pub fn contextual_channel(self, transport: T) -> ContextualChannel + where + T: Transport<(C, Response), (C, ClientMessage)>, + { + ContextualChannel::new(self, transport) + } + } /// Equivalent to a `FnOnce(Req) -> impl Future`. @@ -283,7 +295,7 @@ where /// messages. Instead, it internally handles them by cancelling corresponding requests (removing /// the corresponding in-flight requests and aborting their handlers). #[pin_project] -pub struct BaseChannel { +pub struct BaseChannel { config: Config, /// Writes responses to the wire and reads requests off the wire. #[pin] @@ -294,14 +306,14 @@ pub struct BaseChannel { /// Notifies `canceled_requests` when a request is canceled. request_cancellation: RequestCancellation, /// Holds data necessary to clean up in-flight requests. - in_flight_requests: InFlightRequests, + in_flight_requests: InFlightRequests<()>, /// Types the request and response. - ghost: PhantomData<(fn() -> Req, fn(Resp), C)>, + ghost: PhantomData<(fn() -> Req, fn(Resp))>, } -impl BaseChannel -where - T: Transport<(C, Response), (C, ClientMessage)>, +impl BaseChannel + where + T: Transport, ClientMessage>, { /// Creates a new channel backed by `transport` and configured with `config`. pub fn new(config: Config, transport: T) -> Self { @@ -331,7 +343,7 @@ where self.project().transport.get_pin_mut() } - fn in_flight_requests_mut<'a>(self: &'a mut Pin<&mut Self>) -> &'a mut InFlightRequests { + fn in_flight_requests_mut<'a>(self: &'a mut Pin<&mut Self>) -> &'a mut InFlightRequests<()> { self.as_mut().project().in_flight_requests } @@ -347,9 +359,8 @@ where fn start_request( mut self: Pin<&mut Self>, - request: (C, Request), + mut request: Request, ) -> Result, AlreadyExistsError> { - let (context, mut request) = request; let span = info_span!( "RPC", rpc.trace_id = %request.context.trace_id(), @@ -370,7 +381,7 @@ where let start = self.in_flight_requests_mut().start_request( request.id, request.context.deadline, - context, + (), span.clone(), ); match start { @@ -395,7 +406,7 @@ where } } -impl fmt::Debug for BaseChannel { +impl fmt::Debug for BaseChannel { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { write!(f, "BaseChannel") } @@ -581,9 +592,9 @@ pub trait Channel } } -impl Stream for BaseChannel -where - T: Transport<(C, Response), (C, ClientMessage)>, +impl Stream for BaseChannel + where + T: Transport, ClientMessage>, { type Item = Result, ChannelError>; @@ -611,7 +622,7 @@ where loop { let cancellation_status = match self.canceled_requests_pin_mut().poll_recv(cx) { Poll::Ready(Some(request_id)) => { - if let Some((ctx, span)) = self.in_flight_requests_mut().remove_request(request_id) { + if let Some(((), span)) = self.in_flight_requests_mut().remove_request(request_id) { let _entered = span.enter(); tracing::info!("ResponseCancelled"); } @@ -640,8 +651,8 @@ where .map_err(ChannelError::Transport)? { Poll::Ready(Some(message)) => match message { - (ctx, ClientMessage::Request(request)) => { - match self.as_mut().start_request((ctx,request)) { + ClientMessage::Request(request) => { + match self.as_mut().start_request(request) { Ok(request) => return Poll::Ready(Some(Ok(request))), Err(AlreadyExistsError) => { // Instead of closing the channel if a duplicate request is sent, @@ -652,10 +663,10 @@ where } } } - (_ctx, ClientMessage::Cancel { + ClientMessage::Cancel { trace_context, request_id, - }) => { + } => { if !self.in_flight_requests_mut().cancel_request(request_id) { tracing::trace!( rpc.trace_id = %trace_context.trace_id, @@ -688,10 +699,10 @@ where } } -impl Sink> for BaseChannel -where - T: Transport<(C, Response), (C, ClientMessage)>, - T::Error: Error, +impl Sink> for BaseChannel + where + T: Transport, ClientMessage>, + T::Error: Error, { type Error = ChannelError; @@ -703,7 +714,7 @@ where } fn start_send(mut self: Pin<&mut Self>, response: Response) -> Result<(), Self::Error> { - if let Some((ctx, span)) = self + if let Some(((), span)) = self .in_flight_requests_mut() .remove_request(response.request_id) { @@ -711,7 +722,7 @@ where tracing::info!("SendResponse"); self.project() .transport - .start_send((ctx, response)) + .start_send(response) .map_err(ChannelError::Transport) } else { // If the request isn't tracked anymore, there's no need to send the response. @@ -735,15 +746,15 @@ where } } -impl AsRef for BaseChannel { +impl AsRef for BaseChannel { fn as_ref(&self) -> &T { self.transport.get_ref() } } -implChannel for BaseChannel -where - T: Transport<(C, Response), (C, ClientMessage)>, +implChannel for BaseChannel + where + T: Transport, ClientMessage>, { type Req = Req; type Resp = Resp; @@ -872,7 +883,7 @@ where ready!(self.ensure_writeable(cx)?); match ready!(self.pending_responses_mut().poll_recv(cx)) { - Some((c, response)) => Poll::Ready(Some(Ok(response))), + Some(((), response)) => Poll::Ready(Some(Ok(response))), None => { // This branch likely won't happen, since the Requests stream is holding a Sender. Poll::Ready(None) @@ -1537,10 +1548,10 @@ mod tests { .as_mut() .project() .responses_tx - .send(Response { + .send(((), Response { request_id: 1, message: Ok(()), - }) + })) .await .unwrap(); @@ -1597,10 +1608,10 @@ mod tests { .as_mut() .project() .responses_tx - .send(Response { + .send(((), Response { request_id: 1, message: Ok(()), - }) + })) .await .unwrap(); diff --git a/tarpc/src/server/contextual_channel.rs b/tarpc/src/server/contextual_channel.rs new file mode 100644 index 000000000..a3ae37be0 --- /dev/null +++ b/tarpc/src/server/contextual_channel.rs @@ -0,0 +1,326 @@ +use crate::{ + cancellations::{cancellations, CanceledRequests, RequestCancellation}, + context::{SpanExt}, + trace, ClientMessage, Request, Response, Transport, +}; +use futures::{ + prelude::*, + stream::Fuse, + task::*, +}; +use super::in_flight_requests::{AlreadyExistsError, InFlightRequests}; +use pin_project::pin_project; +use std::{convert::TryFrom, error::Error, fmt, marker::PhantomData, pin::Pin}; +use tracing::{info_span}; +use crate::server::{Channel, ChannelError, Config, ResponseGuard, TrackedRequest}; + +/// BaseChannel is the standard implementation of a [`Channel`]. +/// +/// BaseChannel manages a [`Transport`](Transport) of client [`messages`](ClientMessage) and +/// implements a [`Stream`] of [requests](TrackedRequest). See the [`Channel`] documentation for +/// how to use channels. +/// +/// Besides requests, the other type of client message handled by `BaseChannel` is [cancellation +/// messages](ClientMessage::Cancel). `BaseChannel` does not allow direct access to cancellation +/// messages. Instead, it internally handles them by cancelling corresponding requests (removing +/// the corresponding in-flight requests and aborting their handlers). +#[pin_project] +pub struct ContextualChannel { + config: Config, + /// Writes responses to the wire and reads requests off the wire. + #[pin] + transport: Fuse, + /// In-flight requests that were dropped by the server before completion. + #[pin] + canceled_requests: CanceledRequests, + /// Notifies `canceled_requests` when a request is canceled. + request_cancellation: RequestCancellation, + /// Holds data necessary to clean up in-flight requests. + in_flight_requests: InFlightRequests, + /// Types the request and response. + ghost: PhantomData<(fn() -> Req, fn(Resp))>, +} + +impl ContextualChannel + where + T: Transport<(C, Response), (C, ClientMessage)>, +{ + /// Creates a new channel backed by `transport` and configured with `config`. + pub fn new(config: Config, transport: T) -> Self { + let (request_cancellation, canceled_requests) = cancellations(); + ContextualChannel { + config, + transport: transport.fuse(), + canceled_requests, + request_cancellation, + in_flight_requests: InFlightRequests::default(), + ghost: PhantomData, + } + } + + /// Creates a new channel backed by `transport` and configured with the defaults. + pub fn with_defaults(transport: T) -> Self { + Self::new(Config::default(), transport) + } + + /// Returns the inner transport over which messages are sent and received. + pub fn get_ref(&self) -> &T { + self.transport.get_ref() + } + + /// Returns the inner transport over which messages are sent and received. + pub fn get_pin_ref(self: Pin<&mut Self>) -> Pin<&mut T> { + self.project().transport.get_pin_mut() + } + + fn in_flight_requests_mut<'a>(self: &'a mut Pin<&mut Self>) -> &'a mut InFlightRequests { + self.as_mut().project().in_flight_requests + } + + fn canceled_requests_pin_mut<'a>( + self: &'a mut Pin<&mut Self>, + ) -> Pin<&'a mut CanceledRequests> { + self.as_mut().project().canceled_requests + } + + fn transport_pin_mut<'a>(self: &'a mut Pin<&mut Self>) -> Pin<&'a mut Fuse> { + self.as_mut().project().transport + } + + pub(super) fn start_request( + mut self: Pin<&mut Self>, + request: (C, Request), + ) -> Result, AlreadyExistsError> { + let (context, mut request) = request; + let span = info_span!( + "RPC", + rpc.trace_id = %request.context.trace_id(), + rpc.deadline = %humantime::format_rfc3339(request.context.deadline), + otel.kind = "server", + otel.name = tracing::field::Empty, + ); + span.set_context(&request.context); + request.context.trace_context = trace::Context::try_from(&span).unwrap_or_else(|_| { + tracing::trace!( + "OpenTelemetry subscriber not installed; making unsampled \ + child context." + ); + request.context.trace_context.new_child() + }); + let entered = span.enter(); + tracing::info!("ReceiveRequest"); + let start = self.in_flight_requests_mut().start_request( + request.id, + request.context.deadline, + context, + span.clone(), + ); + match start { + Ok(abort_registration) => { + drop(entered); + Ok(TrackedRequest { + abort_registration, + span, + response_guard: ResponseGuard { + request_id: request.id, + request_cancellation: self.request_cancellation.clone(), + cancel: false, + }, + request, + }) + } + Err(AlreadyExistsError) => { + tracing::trace!("DuplicateRequest"); + Err(AlreadyExistsError) + } + } + } +} + +impl fmt::Debug for ContextualChannel { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "BaseChannel") + } +} + +impl Stream for ContextualChannel + where + T: Transport<(C, Response), (C, ClientMessage)>, +{ + type Item = Result, ChannelError>; + + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll> { + #[derive(Clone, Copy, Debug)] + enum ReceiverStatus { + Ready, + Pending, + Closed, + } + + impl ReceiverStatus { + fn combine(self, other: Self) -> Self { + use ReceiverStatus::*; + match (self, other) { + (Ready, _) | (_, Ready) => Ready, + (Closed, Closed) => Closed, + (Pending, Closed) | (Closed, Pending) | (Pending, Pending) => Pending, + } + } + } + + use ReceiverStatus::*; + + loop { + let cancellation_status = match self.canceled_requests_pin_mut().poll_recv(cx) { + Poll::Ready(Some(request_id)) => { + if let Some((_ctx, span)) = self.in_flight_requests_mut().remove_request(request_id) { + let _entered = span.enter(); + tracing::info!("ResponseCancelled"); + } + Ready + } + // Pending cancellations don't block Channel closure, because all they do is ensure + // the Channel's internal state is cleaned up. But Channel closure also cleans up + // the Channel state, so there's no reason to wait on a cancellation before + // closing. + // + // Ready(None) can't happen, since `self` holds a Cancellation. + Poll::Pending | Poll::Ready(None) => Closed, + }; + + let expiration_status = match self.in_flight_requests_mut().poll_expired(cx) { + // No need to send a response, since the client wouldn't be waiting for one + // anymore. + Poll::Ready(Some(_)) => Ready, + Poll::Ready(None) => Closed, + Poll::Pending => Pending, + }; + + let request_status = match self + .transport_pin_mut() + .poll_next(cx) + .map_err(ChannelError::Transport)? + { + Poll::Ready(Some(message)) => match message { + (ctx, ClientMessage::Request(request)) => { + match self.as_mut().start_request((ctx,request)) { + Ok(request) => return Poll::Ready(Some(Ok(request))), + Err(AlreadyExistsError) => { + // Instead of closing the channel if a duplicate request is sent, + // just ignore it, since it's already being processed. Note that we + // cannot return Poll::Pending here, since nothing has scheduled a + // wakeup yet. + continue; + } + } + } + (_ctx, ClientMessage::Cancel { + trace_context, + request_id, + }) => { + if !self.in_flight_requests_mut().cancel_request(request_id) { + tracing::trace!( + rpc.trace_id = %trace_context.trace_id, + "Received cancellation, but response handler is already complete.", + ); + } + Ready + } + }, + Poll::Ready(None) => Closed, + Poll::Pending => Pending, + }; + + let status = cancellation_status + .combine(expiration_status) + .combine(request_status); + + tracing::trace!( + "Cancellations: {cancellation_status:?}, \ + Expired requests: {expiration_status:?}, \ + Inbound: {request_status:?}, \ + Overall: {status:?}", + ); + match status { + Ready => continue, + Closed => return Poll::Ready(None), + Pending => return Poll::Pending, + } + } + } +} + +impl Sink> for ContextualChannel + where + T: Transport<(C, Response), (C, ClientMessage)>, + T::Error: Error, +{ + type Error = ChannelError; + + fn poll_ready(self: Pin<&mut Self>, cx: &mut Context) -> Poll> { + self.project() + .transport + .poll_ready(cx) + .map_err(ChannelError::Transport) + } + + fn start_send(mut self: Pin<&mut Self>, response: Response) -> Result<(), Self::Error> { + if let Some((ctx, span)) = self + .in_flight_requests_mut() + .remove_request(response.request_id) + { + let _entered = span.enter(); + tracing::info!("SendResponse"); + self.project() + .transport + .start_send((ctx, response)) + .map_err(ChannelError::Transport) + } else { + // If the request isn't tracked anymore, there's no need to send the response. + Ok(()) + } + } + + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll> { + tracing::trace!("poll_flush"); + self.project() + .transport + .poll_flush(cx) + .map_err(ChannelError::Transport) + } + + fn poll_close(self: Pin<&mut Self>, cx: &mut Context) -> Poll> { + self.project() + .transport + .poll_close(cx) + .map_err(ChannelError::Transport) + } +} + +impl AsRef for ContextualChannel { + fn as_ref(&self) -> &T { + self.transport.get_ref() + } +} + +implChannel for ContextualChannel + where + T: Transport<(C, Response), (C, ClientMessage)>, +{ + type Req = Req; + type Resp = Resp; + type Transport = T; + + + fn config(&self) -> &Config { + &self.config + } + + fn in_flight_requests(&self) -> usize { + self.in_flight_requests.len() + } + + fn transport(&self) -> &Self::Transport { + self.get_ref() + } +} diff --git a/tarpc/src/server/in_flight_requests.rs b/tarpc/src/server/in_flight_requests.rs index c101e1456..1b0b12f0a 100644 --- a/tarpc/src/server/in_flight_requests.rs +++ b/tarpc/src/server/in_flight_requests.rs @@ -79,9 +79,9 @@ impl InFlightRequests { pub fn cancel_request(&mut self, request_id: u64) -> bool { if let Some(RequestData { span, - context, abort_handle, deadline_key, + .. }) = self.request_data.remove(&request_id) { let _entered = span.enter(); @@ -155,7 +155,7 @@ mod tests { let mut in_flight_requests = InFlightRequests::default(); assert_eq!(in_flight_requests.len(), 0); in_flight_requests - .start_request(0, SystemTime::now(), Span::current()) + .start_request(0, SystemTime::now(), (), Span::current()) .unwrap(); assert_eq!(in_flight_requests.len(), 1); } @@ -164,7 +164,7 @@ mod tests { async fn polling_expired_aborts() { let mut in_flight_requests = InFlightRequests::default(); let abort_registration = in_flight_requests - .start_request(0, SystemTime::now(), Span::current()) + .start_request(0, SystemTime::now(), (), Span::current()) .unwrap(); let mut abortable_future = Box::new(Abortable::new(pending::<()>(), abort_registration)); @@ -186,7 +186,7 @@ mod tests { async fn cancel_request_aborts() { let mut in_flight_requests = InFlightRequests::default(); let abort_registration = in_flight_requests - .start_request(0, SystemTime::now(), Span::current()) + .start_request(0, SystemTime::now(), (), Span::current()) .unwrap(); let mut abortable_future = Box::new(Abortable::new(pending::<()>(), abort_registration)); @@ -207,6 +207,7 @@ mod tests { .start_request( 0, SystemTime::now() + std::time::Duration::from_secs(10), + (), Span::current(), ) .unwrap(); diff --git a/tarpc/src/server/limits/requests_per_channel.rs b/tarpc/src/server/limits/requests_per_channel.rs index 3c29836ab..3f668878c 100644 --- a/tarpc/src/server/limits/requests_per_channel.rs +++ b/tarpc/src/server/limits/requests_per_channel.rs @@ -204,6 +204,7 @@ mod tests { .start_request( i, SystemTime::now() + Duration::from_secs(1), + (), Span::current(), ) .unwrap(); @@ -327,6 +328,7 @@ mod tests { .start_request( 0, SystemTime::now() + Duration::from_secs(1), + (), Span::current(), ) .unwrap(); diff --git a/tarpc/src/server/testing.rs b/tarpc/src/server/testing.rs index 938865c0f..4c91d0730 100644 --- a/tarpc/src/server/testing.rs +++ b/tarpc/src/server/testing.rs @@ -22,7 +22,7 @@ pub(crate) struct FakeChannel { #[pin] pub sink: VecDeque, pub config: Config, - pub in_flight_requests: super::in_flight_requests::InFlightRequests, + pub in_flight_requests: super::in_flight_requests::InFlightRequests<()>, pub request_cancellation: RequestCancellation, pub canceled_requests: CanceledRequests, } From 416090f66d3b9f494ea19b9d059290989f0a4799 Mon Sep 17 00:00:00 2001 From: Akos Vandra-Meyer Date: Mon, 13 Feb 2023 00:48:10 +0100 Subject: [PATCH 10/30] fix merge differences --- tarpc/src/server.rs | 67 +++++++++++++------------- tarpc/src/server/in_flight_requests.rs | 2 +- 2 files changed, 35 insertions(+), 34 deletions(-) diff --git a/tarpc/src/server.rs b/tarpc/src/server.rs index 9d04fa805..73aa04ff1 100644 --- a/tarpc/src/server.rs +++ b/tarpc/src/server.rs @@ -31,8 +31,10 @@ mod testing; /// Provides functionality to apply server limits. pub mod limits; +// mod base_channel; mod contextual_channel; +// pub use base_channel::*; pub use contextual_channel::*; /// Provides helper methods for streams of Channels. @@ -312,8 +314,8 @@ pub struct BaseChannel { } impl BaseChannel - where - T: Transport, ClientMessage>, +where + T: Transport, ClientMessage>, { /// Creates a new channel backed by `transport` and configured with `config`. pub fn new(config: Config, transport: T) -> Self { @@ -426,20 +428,6 @@ pub struct TrackedRequest { pub response_guard: ResponseGuard, } -/// Critical errors that result in a Channel disconnecting. -#[derive(thiserror::Error, Debug)] -pub enum ChannelError -where - E: Error + Send + Sync + 'static, -{ - /// An error occurred reading from, or writing to, the transport. - #[error("an error occurred in the transport")] - Transport(#[source] E), - /// An error occurred while polling expired requests. - #[error("an error occurred while polling expired requests")] - Timer(#[source] ::tokio::time::error::Error), -} - /// The server end of an open connection with a client, receiving requests from, and sending /// responses to, the client. `Channel` is a [`Transport`] with request lifecycle management. /// @@ -466,8 +454,8 @@ where /// `TrackedRequest` is to get one from another `Channel`. Ultimately, all `TrackedRequests` are /// created by [`BaseChannel`]. pub trait Channel - where - Self: Transport::Resp>, TrackedRequest<::Req>>, +where + Self: Transport::Resp>, TrackedRequest<::Req>>, { /// Type of request item. type Req; @@ -498,8 +486,8 @@ pub trait Channel self, limit: usize, ) -> limits::requests_per_channel::MaxRequests - where - Self: Sized, + where + Self: Sized, { limits::requests_per_channel::MaxRequests::new(self, limit) } @@ -539,8 +527,8 @@ pub trait Channel /// } /// ``` fn requests(self) -> Requests - where - Self: Sized, + where + Self: Sized, { let (responses_tx, responses) = mpsc::channel(self.config().pending_response_buffer); @@ -584,17 +572,31 @@ pub trait Channel /// } /// ``` fn execute(self, serve: S) -> impl Stream> - where - Self: Sized, - S: Serve + Clone, + where + Self: Sized, + S: Serve + Clone, { self.requests().execute(serve) } } -impl Stream for BaseChannel +/// Critical errors that result in a Channel disconnecting. +#[derive(thiserror::Error, Debug)] +pub enum ChannelError where - T: Transport, ClientMessage>, + E: Error + Send + Sync + 'static, +{ + /// An error occurred reading from, or writing to, the transport. + #[error("an error occurred in the transport")] + Transport(#[source] E), + /// An error occurred while polling expired requests. + #[error("an error occurred while polling expired requests")] + Timer(#[source] ::tokio::time::error::Error), +} + +impl Stream for BaseChannel +where + T: Transport, ClientMessage>, { type Item = Result, ChannelError>; @@ -700,9 +702,9 @@ impl Stream for BaseChannel } impl Sink> for BaseChannel - where - T: Transport, ClientMessage>, - T::Error: Error, +where + T: Transport, ClientMessage>, + T::Error: Error, { type Error = ChannelError; @@ -753,14 +755,13 @@ impl AsRef for BaseChannel { } implChannel for BaseChannel - where - T: Transport, ClientMessage>, +where + T: Transport, ClientMessage>, { type Req = Req; type Resp = Resp; type Transport = T; - fn config(&self) -> &Config { &self.config } diff --git a/tarpc/src/server/in_flight_requests.rs b/tarpc/src/server/in_flight_requests.rs index 1b0b12f0a..ef535fd7f 100644 --- a/tarpc/src/server/in_flight_requests.rs +++ b/tarpc/src/server/in_flight_requests.rs @@ -35,7 +35,7 @@ struct RequestData { deadline_key: delay_queue::Key, /// The client span. span: Span, - + /// Optional server side context of kept for the lifecycle of the request context: C } From d358a83ea6d805b4a83ca9f96662b6807ae4e403 Mon Sep 17 00:00:00 2001 From: Akos Vandra-Meyer Date: Mon, 13 Feb 2023 08:19:32 +0100 Subject: [PATCH 11/30] revert atomics change to u64 --- tarpc/src/client.rs | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/tarpc/src/client.rs b/tarpc/src/client.rs index 26e45acbf..cdd252bf1 100644 --- a/tarpc/src/client.rs +++ b/tarpc/src/client.rs @@ -22,7 +22,7 @@ use std::{ fmt, pin::Pin, sync::{ - atomic::{AtomicU64, Ordering}, + atomic::{AtomicUsize, Ordering}, Arc, }, }; @@ -99,11 +99,6 @@ impl fmt::Debug for NewClient { } } -const _CHECK_USIZE: () = assert!( - std::mem::size_of::() <= std::mem::size_of::(), - "usize is too big to fit in u64" -); - /// Provides a stream of unique u64 numbers pub trait RequestSequencer: Debug + Send + Sync + 'static { @@ -111,14 +106,19 @@ pub trait RequestSequencer: Debug + Send + Sync + 'static { fn next_id(&self) -> u64; } +const _CHECK_USIZE: () = assert!( + std::mem::size_of::() <= std::mem::size_of::(), + "usize is too big to fit in u64" +); /// Default sequencer producing the numbers 0,1,2,3,4... #[derive(Clone, Default, Debug)] -pub struct DefaultSequencer(Arc); +pub struct DefaultSequencer(Arc); impl RequestSequencer for DefaultSequencer { fn next_id(&self) -> u64 { - println!("DEFSEQ {:?}", &self.0); - self.0.fetch_add(1, Ordering::Relaxed) + //_CHECK_USIZE verifies that usize fits into an u64, and usize atomics are more likely(?) be present + // than u64 on smaller architectures. + self.0.fetch_add(1, Ordering::Relaxed) as u64 } } From c17652824054e63c295ecbc8af909e2b440d308a Mon Sep 17 00:00:00 2001 From: Akos Vandra-Meyer Date: Mon, 13 Feb 2023 08:23:21 +0100 Subject: [PATCH 12/30] fix formatting --- tarpc/src/client.rs | 5 ++--- tarpc/src/server.rs | 6 +++--- 2 files changed, 5 insertions(+), 6 deletions(-) diff --git a/tarpc/src/client.rs b/tarpc/src/client.rs index cdd252bf1..02ef341d3 100644 --- a/tarpc/src/client.rs +++ b/tarpc/src/client.rs @@ -26,7 +26,6 @@ use std::{ Arc, }, }; -use std::fmt::Debug; use tokio::sync::{mpsc, oneshot}; use tracing::Span; @@ -100,7 +99,7 @@ impl fmt::Debug for NewClient { } /// Provides a stream of unique u64 numbers -pub trait RequestSequencer: Debug + Send + Sync + 'static { +pub trait RequestSequencer: fmt::Debug + Send + Sync + 'static { /// Generates the next number. fn next_id(&self) -> u64; @@ -887,7 +886,7 @@ mod tests { impl PollTest for Poll>> where - E: ::std::fmt::Display, + E: fmt::Display, { type T = Option; diff --git a/tarpc/src/server.rs b/tarpc/src/server.rs index 73aa04ff1..74f00f250 100644 --- a/tarpc/src/server.rs +++ b/tarpc/src/server.rs @@ -583,8 +583,8 @@ where /// Critical errors that result in a Channel disconnecting. #[derive(thiserror::Error, Debug)] pub enum ChannelError - where - E: Error + Send + Sync + 'static, +where + E: Error + Send + Sync + 'static, { /// An error occurred reading from, or writing to, the transport. #[error("an error occurred in the transport")] @@ -754,7 +754,7 @@ impl AsRef for BaseChannel { } } -implChannel for BaseChannel +impl Channel for BaseChannel where T: Transport, ClientMessage>, { From 33ce197eb1bb1a3876030182cda82bf32c5be6df Mon Sep 17 00:00:00 2001 From: Akos Vandra-Meyer Date: Thu, 16 Feb 2023 11:12:58 +0100 Subject: [PATCH 13/30] wip --- tarpc/src/server/base_channel.rs | 325 +++++++++++++++++++++++++++++++ 1 file changed, 325 insertions(+) create mode 100644 tarpc/src/server/base_channel.rs diff --git a/tarpc/src/server/base_channel.rs b/tarpc/src/server/base_channel.rs new file mode 100644 index 000000000..157576480 --- /dev/null +++ b/tarpc/src/server/base_channel.rs @@ -0,0 +1,325 @@ +use crate::{ + cancellations::{cancellations, CanceledRequests, RequestCancellation}, + context::{SpanExt}, + trace, ClientMessage, Request, Response, Transport, +}; +use futures::{ + prelude::*, + stream::Fuse, + task::*, +}; +use super::in_flight_requests::{AlreadyExistsError, InFlightRequests}; +use pin_project::pin_project; +use std::{convert::TryFrom, error::Error, fmt, marker::PhantomData, pin::Pin}; +use tracing::{info_span}; +use crate::server::{Channel, ChannelError, Config, ResponseGuard, TrackedRequest}; + +/// BaseChannel is the standard implementation of a [`Channel`]. +/// +/// BaseChannel manages a [`Transport`](Transport) of client [`messages`](ClientMessage) and +/// implements a [`Stream`] of [requests](TrackedRequest). See the [`Channel`] documentation for +/// how to use channels. +/// +/// Besides requests, the other type of client message handled by `BaseChannel` is [cancellation +/// mssages](ClientMessage::Cancel). `BaseChannel` does not allow direct access to cancellation +/// messages. Instead, it internally handles them by cancelling corresponding requests (removing +/// the corresponding in-flight requests and aborting their handlers). +#[pin_project] +pub struct BaseChannel { + config: Config, + /// Writes responses to the wire and reads requests off the wire. + #[pin] + transport: Fuse, + /// In-flight requests that were dropped by the server before completion. + #[pin] + pub(super) canceled_requests: CanceledRequests, + /// Notifies `canceled_requests` when a request is canceled. + request_cancellation: RequestCancellation, + /// Holds data necessary to clean up in-flight requests. + in_flight_requests: InFlightRequests<()>, + /// Types the request and response. + ghost: PhantomData<(fn() -> Req, fn(Resp))>, +} + +impl BaseChannel + where + T: Transport, ClientMessage>, +{ + /// Creates a new channel backed by `transport` and configured with `config`. + pub fn new(config: Config, transport: T) -> Self { + let (request_cancellation, canceled_requests) = cancellations(); + BaseChannel { + config, + transport: transport.fuse(), + canceled_requests, + request_cancellation, + in_flight_requests: InFlightRequests::default(), + ghost: PhantomData, + } + } + + /// Creates a new channel backed by `transport` and configured with the defaults. + pub fn with_defaults(transport: T) -> Self { + Self::new(Config::default(), transport) + } + + /// Returns the inner transport over which messages are sent and received. + pub fn get_ref(&self) -> &T { + self.transport.get_ref() + } + + /// Returns the inner transport over which messages are sent and received. + pub fn get_pin_ref(self: Pin<&mut Self>) -> Pin<&mut T> { + self.project().transport.get_pin_mut() + } + + fn in_flight_requests_mut<'a>(self: &'a mut Pin<&mut Self>) -> &'a mut InFlightRequests<()> { + self.as_mut().project().in_flight_requests + } + + fn canceled_requests_pin_mut<'a>( + self: &'a mut Pin<&mut Self>, + ) -> Pin<&'a mut CanceledRequests> { + self.as_mut().project().canceled_requests + } + + fn transport_pin_mut<'a>(self: &'a mut Pin<&mut Self>) -> Pin<&'a mut Fuse> { + self.as_mut().project().transport + } + + pub(super) fn start_request( + mut self: Pin<&mut Self>, + mut request: Request, + ) -> Result, AlreadyExistsError> { + let span = info_span!( + "RPC", + rpc.trace_id = %request.context.trace_id(), + rpc.deadline = %humantime::format_rfc3339(request.context.deadline), + otel.kind = "server", + otel.name = tracing::field::Empty, + ); + span.set_context(&request.context); + request.context.trace_context = trace::Context::try_from(&span).unwrap_or_else(|_| { + tracing::trace!( + "OpenTelemetry subscriber not installed; making unsampled \ + child context." + ); + request.context.trace_context.new_child() + }); + let entered = span.enter(); + tracing::info!("ReceiveRequest"); + let start = self.in_flight_requests_mut().start_request( + request.id, + request.context.deadline, + (), + span.clone(), + ); + match start { + Ok(abort_registration) => { + drop(entered); + Ok(TrackedRequest { + abort_registration, + span, + response_guard: ResponseGuard { + request_id: request.id, + request_cancellation: self.request_cancellation.clone(), + cancel: false, + }, + request, + }) + } + Err(AlreadyExistsError) => { + tracing::trace!("DuplicateRequest"); + Err(AlreadyExistsError) + } + } + } +} + +impl fmt::Debug for BaseChannel { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "BaseChannel") + } +} + +impl Stream for BaseChannel + where + T: Transport, ClientMessage>, +{ + type Item = Result, ChannelError>; + + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll> { + #[derive(Clone, Copy, Debug)] + enum ReceiverStatus { + Ready, + Pending, + Closed, + } + + impl ReceiverStatus { + fn combine(self, other: Self) -> Self { + use ReceiverStatus::*; + match (self, other) { + (Ready, _) | (_, Ready) => Ready, + (Closed, Closed) => Closed, + (Pending, Closed) | (Closed, Pending) | (Pending, Pending) => Pending, + } + } + } + + use ReceiverStatus::*; + + loop { + let cancellation_status = match self.canceled_requests_pin_mut().poll_recv(cx) { + Poll::Ready(Some(request_id)) => { + if let Some(((), span)) = self.in_flight_requests_mut().remove_request(request_id) { + let _entered = span.enter(); + tracing::info!("ResponseCancelled"); + } + Ready + } + // Pending cancellations don't block Channel closure, because all they do is ensure + // the Channel's internal state is cleaned up. But Channel closure also cleans up + // the Channel state, so there's no reason to wait on a cancellation before + // closing. + // + // Ready(None) can't happen, since `self` holds a Cancellation. + Poll::Pending | Poll::Ready(None) => Closed, + }; + + let expiration_status = match self.in_flight_requests_mut().poll_expired(cx) { + // No need to send a response, since the client wouldn't be waiting for one + // anymore. + Poll::Ready(Some(_)) => Ready, + Poll::Ready(None) => Closed, + Poll::Pending => Pending, + }; + + let request_status = match self + .transport_pin_mut() + .poll_next(cx) + .map_err(ChannelError::Transport)? + { + Poll::Ready(Some(message)) => match message { + ClientMessage::Request(request) => { + match self.as_mut().start_request(request) { + Ok(request) => return Poll::Ready(Some(Ok(request))), + Err(AlreadyExistsError) => { + // Instead of closing the channel if a duplicate request is sent, + // just ignore it, since it's already being processed. Note that we + // cannot return Poll::Pending here, since nothing has scheduled a + // wakeup yet. + continue; + } + } + } + ClientMessage::Cancel { + trace_context, + request_id, + } => { + if !self.in_flight_requests_mut().cancel_request(request_id) { + tracing::trace!( + rpc.trace_id = %trace_context.trace_id, + "Received cancellation, but response handler is already complete.", + ); + } + Ready + } + }, + Poll::Ready(None) => Closed, + Poll::Pending => Pending, + }; + + let status = cancellation_status + .combine(expiration_status) + .combine(request_status); + + tracing::trace!( + "Cancellations: {cancellation_status:?}, \ + Expired requests: {expiration_status:?}, \ + Inbound: {request_status:?}, \ + Overall: {status:?}", + ); + match status { + Ready => continue, + Closed => return Poll::Ready(None), + Pending => return Poll::Pending, + } + } + } +} + +impl Sink> for BaseChannel + where + T: Transport, ClientMessage>, + T::Error: Error, +{ + type Error = ChannelError; + + fn poll_ready(self: Pin<&mut Self>, cx: &mut Context) -> Poll> { + self.project() + .transport + .poll_ready(cx) + .map_err(ChannelError::Transport) + } + + fn start_send(mut self: Pin<&mut Self>, response: Response) -> Result<(), Self::Error> { + if let Some(((), span)) = self + .in_flight_requests_mut() + .remove_request(response.request_id) + { + let _entered = span.enter(); + tracing::info!("SendResponse"); + self.project() + .transport + .start_send(response) + .map_err(ChannelError::Transport) + } else { + // If the request isn't tracked anymore, there's no need to send the response. + Ok(()) + } + } + + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll> { + tracing::trace!("poll_flush"); + self.project() + .transport + .poll_flush(cx) + .map_err(ChannelError::Transport) + } + + fn poll_close(self: Pin<&mut Self>, cx: &mut Context) -> Poll> { + self.project() + .transport + .poll_close(cx) + .map_err(ChannelError::Transport) + } +} + +impl AsRef for BaseChannel { + fn as_ref(&self) -> &T { + self.transport.get_ref() + } +} + +implChannel for BaseChannel + where + T: Transport, ClientMessage>, +{ + type Req = Req; + type Resp = Resp; + type Transport = T; + + + fn config(&self) -> &Config { + &self.config + } + + fn in_flight_requests(&self) -> usize { + self.in_flight_requests.len() + } + + fn transport(&self) -> &Self::Transport { + self.get_ref() + } +} From 61aa139e1b35e0859b694cdeb76b5e76d46efe25 Mon Sep 17 00:00:00 2001 From: Tim Kuehn Date: Sat, 12 Nov 2022 17:24:40 -0800 Subject: [PATCH 14/30] Add back the Client trait, renamed Stub. Also adds a Client stub trait alias for each generated service. Now that generic associated types are stable, it's almost possible to define a trait for Channel that works with async fns on stable. `impl trait in type aliases` is still necessary (and unstable), but we're getting closer. As a proof of concept, three more implementations of Stub are implemented; 1. A load balancer that round-robins requests between different stubs. 2. A load balancer that selects a stub based on a request hash, so that the same requests go to the same stubs. 3. A stub that retries requests based on a configurable policy. The "serde/rc" feature is added to the "full" feature because the Retry stub wraps the request in an Arc, so that the request is reusable for multiple calls. Server implementors commonly need to operate generically across all services or request types. For example, a server throttler may want to return errors telling clients to back off, which is not specific to any one service. --- plugins/src/lib.rs | 38 +++- tarpc/Cargo.toml | 2 +- tarpc/src/client.rs | 1 + tarpc/src/client/stub.rs | 56 +++++ tarpc/src/client/stub/load_balance.rs | 305 ++++++++++++++++++++++++++ tarpc/src/client/stub/mock.rs | 54 +++++ tarpc/src/client/stub/retry.rs | 75 +++++++ tarpc/src/lib.rs | 28 +-- tarpc/src/server.rs | 43 +++- 9 files changed, 573 insertions(+), 29 deletions(-) create mode 100644 tarpc/src/client/stub.rs create mode 100644 tarpc/src/client/stub/load_balance.rs create mode 100644 tarpc/src/client/stub/mock.rs create mode 100644 tarpc/src/client/stub/retry.rs diff --git a/plugins/src/lib.rs b/plugins/src/lib.rs index 1b83c3247..003efc97b 100644 --- a/plugins/src/lib.rs +++ b/plugins/src/lib.rs @@ -276,6 +276,7 @@ pub fn service(attr: TokenStream, input: TokenStream) -> TokenStream { ServiceGenerator { response_fut_name, service_ident: ident, + client_stub_ident: &format_ident!("{}Stub", ident), server_ident: &format_ident!("Serve{}", ident), response_fut_ident: &Ident::new(response_fut_name, ident.span()), client_ident: &format_ident!("{}Client", ident), @@ -432,6 +433,7 @@ fn verify_types_were_provided( // the client stub. struct ServiceGenerator<'a> { service_ident: &'a Ident, + client_stub_ident: &'a Ident, server_ident: &'a Ident, response_fut_ident: &'a Ident, response_fut_name: &'a str, @@ -461,6 +463,9 @@ impl<'a> ServiceGenerator<'a> { future_types, return_types, service_ident, + client_stub_ident, + request_ident, + response_ident, server_ident, .. } = self; @@ -490,6 +495,7 @@ impl<'a> ServiceGenerator<'a> { }, ); + let stub_doc = format!("The stub trait for service [`{service_ident}`]."); quote! { #( #attrs )* #vis trait #service_ident: Sized { @@ -501,6 +507,15 @@ impl<'a> ServiceGenerator<'a> { #server_ident { service: self } } } + + #[doc = #stub_doc] + #vis trait #client_stub_ident: tarpc::client::stub::Stub { + } + + impl #client_stub_ident for S + where S: tarpc::client::stub::Stub + { + } } } @@ -666,7 +681,7 @@ impl<'a> ServiceGenerator<'a> { #response_fut_ident::#camel_case_idents(resp) => std::pin::Pin::new_unchecked(resp) .poll(cx) - .map(#response_ident::#camel_case_idents), + .map(#response_ident::#camel_case_idents) )* } } @@ -689,7 +704,9 @@ impl<'a> ServiceGenerator<'a> { #[derive(Clone, Debug)] /// The client stub that makes RPC calls to the server. All request methods return /// [Futures](std::future::Future). - #vis struct #client_ident(tarpc::client::Channel<#request_ident, #response_ident>); + #vis struct #client_ident< + Stub = tarpc::client::Channel<#request_ident, #response_ident> + >(Stub); } } @@ -719,6 +736,17 @@ impl<'a> ServiceGenerator<'a> { dispatch: new_client.dispatch, } } + } + + impl From for #client_ident + where Stub: tarpc::client::stub::Stub< + Req = #request_ident, + Resp = #response_ident> + { + /// Returns a new client stub that sends requests over the given transport. + fn from(stub: Stub) -> Self { + #client_ident(stub) + } } } @@ -741,7 +769,11 @@ impl<'a> ServiceGenerator<'a> { } = self; quote! { - impl #client_ident { + impl #client_ident + where Stub: tarpc::client::stub::Stub< + Req = #request_ident, + Resp = #response_ident> + { #( #[allow(unused)] #( #method_attrs )* diff --git a/tarpc/Cargo.toml b/tarpc/Cargo.toml index 97ac95232..c6f80644e 100644 --- a/tarpc/Cargo.toml +++ b/tarpc/Cargo.toml @@ -19,7 +19,7 @@ description = "An RPC framework for Rust with a focus on ease of use." [features] default = [] -serde1 = ["tarpc-plugins/serde1", "serde", "serde/derive"] +serde1 = ["tarpc-plugins/serde1", "serde", "serde/derive", "serde/rc"] tokio1 = ["tokio/rt"] serde-transport = ["serde1", "tokio1", "tokio-serde", "tokio-util/codec"] serde-transport-json = ["tokio-serde/json"] diff --git a/tarpc/src/client.rs b/tarpc/src/client.rs index 109ee8ff2..fc376e4d6 100644 --- a/tarpc/src/client.rs +++ b/tarpc/src/client.rs @@ -7,6 +7,7 @@ //! Provides a client that connects to a server and sends multiplexed requests. mod in_flight_requests; +pub mod stub; use crate::{ cancellations::{cancellations, CanceledRequests, RequestCancellation}, diff --git a/tarpc/src/client/stub.rs b/tarpc/src/client/stub.rs new file mode 100644 index 000000000..a8b72a20f --- /dev/null +++ b/tarpc/src/client/stub.rs @@ -0,0 +1,56 @@ +//! Provides a Stub trait, implemented by types that can call remote services. + +use crate::{ + client::{Channel, RpcError}, + context, +}; +use futures::prelude::*; + +pub mod load_balance; +pub mod retry; + +#[cfg(test)] +mod mock; + +/// A connection to a remote service. +/// Calls the service with requests of type `Req` and receives responses of type `Resp`. +pub trait Stub { + /// The service request type. + type Req; + + /// The service response type. + type Resp; + + /// The type of the future returned by `Stub::call`. + type RespFut<'a>: Future> + where + Self: 'a, + Self::Req: 'a, + Self::Resp: 'a; + + /// Calls a remote service. + fn call<'a>( + &'a self, + ctx: context::Context, + request_name: &'static str, + request: Self::Req, + ) -> Self::RespFut<'a>; +} + +impl Stub for Channel { + type Req = Req; + type Resp = Resp; + type RespFut<'a> = RespFut<'a, Req, Resp> + where Self: 'a; + + fn call<'a>( + &'a self, + ctx: context::Context, + request_name: &'static str, + request: Req, + ) -> Self::RespFut<'a> { + Self::call(self, ctx, request_name, request) + } +} + +type RespFut<'a, Req: 'a, Resp: 'a> = impl Future> + 'a; diff --git a/tarpc/src/client/stub/load_balance.rs b/tarpc/src/client/stub/load_balance.rs new file mode 100644 index 000000000..c9005a423 --- /dev/null +++ b/tarpc/src/client/stub/load_balance.rs @@ -0,0 +1,305 @@ +//! Provides load-balancing [Stubs](crate::client::stub::Stub). + +pub use consistent_hash::ConsistentHash; +pub use round_robin::RoundRobin; + +/// Provides a stub that load-balances with a simple round-robin strategy. +mod round_robin { + use crate::{ + client::{stub, RpcError}, + context, + }; + use cycle::AtomicCycle; + use futures::prelude::*; + + impl stub::Stub for RoundRobin + where + Stub: stub::Stub, + { + type Req = Stub::Req; + type Resp = Stub::Resp; + type RespFut<'a> = RespFut<'a, Stub> + where Self: 'a; + + fn call<'a>( + &'a self, + ctx: context::Context, + request_name: &'static str, + request: Self::Req, + ) -> Self::RespFut<'a> { + Self::call(self, ctx, request_name, request) + } + } + + type RespFut<'a, Stub: stub::Stub + 'a> = + impl Future> + 'a; + + /// A Stub that load-balances across backing stubs by round robin. + #[derive(Clone, Debug)] + pub struct RoundRobin { + stubs: AtomicCycle, + } + + impl RoundRobin + where + Stub: stub::Stub, + { + /// Returns a new RoundRobin stub. + pub fn new(stubs: Vec) -> Self { + Self { + stubs: AtomicCycle::new(stubs), + } + } + + async fn call( + &self, + ctx: context::Context, + request_name: &'static str, + request: Stub::Req, + ) -> Result { + let next = self.stubs.next(); + next.call(ctx, request_name, request).await + } + } + + mod cycle { + use std::sync::{ + atomic::{AtomicUsize, Ordering}, + Arc, + }; + + /// Cycles endlessly and atomically over a collection of elements of type T. + #[derive(Clone, Debug)] + pub struct AtomicCycle(Arc>); + + #[derive(Debug)] + struct State { + elements: Vec, + next: AtomicUsize, + } + + impl AtomicCycle { + pub fn new(elements: Vec) -> Self { + Self(Arc::new(State { + elements, + next: Default::default(), + })) + } + + pub fn next(&self) -> &T { + self.0.next() + } + } + + impl State { + pub fn next(&self) -> &T { + let next = self.next.fetch_add(1, Ordering::Relaxed); + &self.elements[next % self.elements.len()] + } + } + + #[test] + fn test_cycle() { + let cycle = AtomicCycle::new(vec![1, 2, 3]); + assert_eq!(cycle.next(), &1); + assert_eq!(cycle.next(), &2); + assert_eq!(cycle.next(), &3); + assert_eq!(cycle.next(), &1); + } + } +} + +/// Provides a stub that load-balances with a consistent hashing strategy. +/// +/// Each request is hashed, then mapped to a stub based on the hash. Equivalent requests will use +/// the same stub. +mod consistent_hash { + use crate::{ + client::{stub, RpcError}, + context, + }; + use futures::prelude::*; + use std::{ + collections::hash_map::RandomState, + hash::{BuildHasher, Hash, Hasher}, + num::TryFromIntError, + }; + + impl stub::Stub for ConsistentHash + where + Stub: stub::Stub, + Stub::Req: Hash, + { + type Req = Stub::Req; + type Resp = Stub::Resp; + type RespFut<'a> = RespFut<'a, Stub> + where Self: 'a; + + fn call<'a>( + &'a self, + ctx: context::Context, + request_name: &'static str, + request: Self::Req, + ) -> Self::RespFut<'a> { + Self::call(self, ctx, request_name, request) + } + } + + type RespFut<'a, Stub: stub::Stub + 'a> = + impl Future> + 'a; + + /// A Stub that load-balances across backing stubs by round robin. + #[derive(Clone, Debug)] + pub struct ConsistentHash { + stubs: Vec, + stubs_len: u64, + hasher: S, + } + + impl ConsistentHash + where + Stub: stub::Stub, + Stub::Req: Hash, + { + /// Returns a new RoundRobin stub. + /// Returns an err if the length of `stubs` overflows a u64. + pub fn new(stubs: Vec) -> Result { + Ok(Self { + stubs_len: stubs.len().try_into()?, + stubs, + hasher: RandomState::new(), + }) + } + } + + impl ConsistentHash + where + Stub: stub::Stub, + Stub::Req: Hash, + S: BuildHasher, + { + /// Returns a new RoundRobin stub. + /// Returns an err if the length of `stubs` overflows a u64. + pub fn with_hasher(stubs: Vec, hasher: S) -> Result { + Ok(Self { + stubs_len: stubs.len().try_into()?, + stubs, + hasher, + }) + } + + async fn call( + &self, + ctx: context::Context, + request_name: &'static str, + request: Stub::Req, + ) -> Result { + let index = usize::try_from(self.hash_request(&request) % self.stubs_len).expect( + "invariant broken: stubs_len is not larger than a usize, \ + so the hash modulo stubs_len should always fit in a usize", + ); + let next = &self.stubs[index]; + next.call(ctx, request_name, request).await + } + + fn hash_request(&self, req: &Stub::Req) -> u64 { + let mut hasher = self.hasher.build_hasher(); + req.hash(&mut hasher); + hasher.finish() + } + } + + #[cfg(test)] + mod tests { + use super::ConsistentHash; + use crate::{client::stub::mock::Mock, context}; + use std::{ + collections::HashMap, + hash::{BuildHasher, Hash, Hasher}, + rc::Rc, + }; + + #[tokio::test] + async fn test() -> anyhow::Result<()> { + let stub = ConsistentHash::with_hasher( + vec![ + // For easier reading of the assertions made in this test, each Mock's response + // value is equal to a hash value that should map to its index: 3 % 3 = 0, 1 % + // 3 = 1, etc. + Mock::new([('a', 3), ('b', 3), ('c', 3)]), + Mock::new([('a', 1), ('b', 1), ('c', 1)]), + Mock::new([('a', 2), ('b', 2), ('c', 2)]), + ], + FakeHasherBuilder::new([('a', 1), ('b', 2), ('c', 3)]), + )?; + + for _ in 0..2 { + let resp = stub.call(context::current(), "", 'a').await?; + assert_eq!(resp, 1); + + let resp = stub.call(context::current(), "", 'b').await?; + assert_eq!(resp, 2); + + let resp = stub.call(context::current(), "", 'c').await?; + assert_eq!(resp, 3); + } + + Ok(()) + } + + struct HashRecorder(Vec); + impl Hasher for HashRecorder { + fn write(&mut self, bytes: &[u8]) { + self.0 = Vec::from(bytes); + } + fn finish(&self) -> u64 { + 0 + } + } + + struct FakeHasherBuilder { + recorded_hashes: Rc, u64>>, + } + + struct FakeHasher { + recorded_hashes: Rc, u64>>, + output: u64, + } + + impl BuildHasher for FakeHasherBuilder { + type Hasher = FakeHasher; + + fn build_hasher(&self) -> Self::Hasher { + FakeHasher { + recorded_hashes: self.recorded_hashes.clone(), + output: 0, + } + } + } + + impl FakeHasherBuilder { + fn new(fake_hashes: [(T, u64); N]) -> Self { + let mut recorded_hashes = HashMap::new(); + for (to_hash, fake_hash) in fake_hashes { + let mut recorder = HashRecorder(vec![]); + to_hash.hash(&mut recorder); + recorded_hashes.insert(recorder.0, fake_hash); + } + Self { + recorded_hashes: Rc::new(recorded_hashes), + } + } + } + + impl Hasher for FakeHasher { + fn write(&mut self, bytes: &[u8]) { + if let Some(hash) = self.recorded_hashes.get(bytes) { + self.output = *hash; + } + } + fn finish(&self) -> u64 { + self.output + } + } + } +} diff --git a/tarpc/src/client/stub/mock.rs b/tarpc/src/client/stub/mock.rs new file mode 100644 index 000000000..99a54422f --- /dev/null +++ b/tarpc/src/client/stub/mock.rs @@ -0,0 +1,54 @@ +use crate::{ + client::{stub::Stub, RpcError}, + context, ServerError, +}; +use futures::future; +use std::{collections::HashMap, hash::Hash, io}; + +/// A mock stub that returns user-specified responses. +pub struct Mock { + responses: HashMap, +} + +impl Mock +where + Req: Eq + Hash, +{ + /// Returns a new mock, mocking the specified (request, response) pairs. + pub fn new(responses: [(Req, Resp); N]) -> Self { + Self { + responses: HashMap::from(responses), + } + } +} + +impl Stub for Mock +where + Req: Eq + Hash, + Resp: Clone, +{ + type Req = Req; + type Resp = Resp; + type RespFut<'a> = future::Ready> + where Self: 'a; + + fn call<'a>( + &'a self, + _: context::Context, + _: &'static str, + request: Self::Req, + ) -> Self::RespFut<'a> { + future::ready( + self.responses + .get(&request) + .cloned() + .map(Ok) + .unwrap_or_else(|| { + Err(RpcError::Server(ServerError { + kind: io::ErrorKind::NotFound, + detail: "mock (request, response) entry not found".into(), + })) + }), + ) + } +} diff --git a/tarpc/src/client/stub/retry.rs b/tarpc/src/client/stub/retry.rs new file mode 100644 index 000000000..46ad09685 --- /dev/null +++ b/tarpc/src/client/stub/retry.rs @@ -0,0 +1,75 @@ +//! Provides a stub that retries requests based on response contents.. + +use crate::{ + client::{stub, RpcError}, + context, +}; +use futures::prelude::*; +use std::sync::Arc; + +impl stub::Stub for Retry +where + Stub: stub::Stub>, + F: Fn(&Result, u32) -> bool, +{ + type Req = Req; + type Resp = Stub::Resp; + type RespFut<'a> = RespFut<'a, Stub, Self::Req, F> + where Self: 'a, + Self::Req: 'a; + + fn call<'a>( + &'a self, + ctx: context::Context, + request_name: &'static str, + request: Self::Req, + ) -> Self::RespFut<'a> { + Self::call(self, ctx, request_name, request) + } +} + +type RespFut<'a, Stub: stub::Stub + 'a, Req: 'a, F: 'a> = + impl Future> + 'a; + +/// A Stub that retries requests based on response contents. +/// Note: to use this stub with Serde serialization, the "rc" feature of Serde needs to be enabled. +#[derive(Clone, Debug)] +pub struct Retry { + should_retry: F, + stub: Stub, +} + +impl Retry +where + Stub: stub::Stub>, + F: Fn(&Result, u32) -> bool, +{ + /// Creates a new Retry stub that delegates calls to the underlying `stub`. + pub fn new(stub: Stub, should_retry: F) -> Self { + Self { stub, should_retry } + } + + async fn call<'a, 'b>( + &'a self, + ctx: context::Context, + request_name: &'static str, + request: Req, + ) -> Result + where + Req: 'b, + { + let request = Arc::new(request); + for i in 1.. { + let result = self + .stub + .call(ctx, request_name, Arc::clone(&request)) + .await; + if (self.should_retry)(&result, i) { + tracing::trace!("Retrying on attempt {i}"); + continue; + } + return result; + } + unreachable!("Wow, that was a lot of attempts!"); + } +} diff --git a/tarpc/src/lib.rs b/tarpc/src/lib.rs index 418cedd82..280da694e 100644 --- a/tarpc/src/lib.rs +++ b/tarpc/src/lib.rs @@ -200,6 +200,7 @@ //! //! Use `cargo doc` as you normally would to see the documentation created for all //! items expanded by a `service!` invocation. +#![feature(type_alias_impl_trait)] #![deny(missing_docs)] #![allow(clippy::type_complexity)] #![cfg_attr(docsrs, feature(doc_cfg))] @@ -311,7 +312,6 @@ pub use crate::transport::sealed::Transport; use anyhow::Context as _; use futures::task::*; -use std::sync::Arc; use std::{error::Error, fmt::Display, io, time::SystemTime}; /// A message from a client to a server. @@ -384,27 +384,11 @@ pub struct ServerError { pub detail: String, } -/// Critical errors that result in a Channel disconnecting. -#[derive(thiserror::Error, Debug, PartialEq, Eq)] -pub enum ChannelError -where - E: Error + Send + Sync + 'static, -{ - /// Could not read from the transport. - #[error("could not read from the transport")] - Read(#[source] Arc), - /// Could not ready the transport for writes. - #[error("could not ready the transport for writes")] - Ready(#[source] E), - /// Could not write to the transport. - #[error("could not write to the transport")] - Write(#[source] E), - /// Could not flush the transport. - #[error("could not flush the transport")] - Flush(#[source] E), - /// Could not close the write end of the transport. - #[error("could not close the write end of the transport")] - Close(#[source] E), +impl ServerError { + /// Returns a new server error with `kind` and `detail`. + pub fn new(kind: io::ErrorKind, detail: String) -> ServerError { + Self { kind, detail } + } } impl Request { diff --git a/tarpc/src/server.rs b/tarpc/src/server.rs index a06e8f8a4..7b1e49848 100644 --- a/tarpc/src/server.rs +++ b/tarpc/src/server.rs @@ -9,7 +9,7 @@ use crate::{ cancellations::{cancellations, CanceledRequests, RequestCancellation}, context::{self, SpanExt}, - trace, ChannelError, ClientMessage, Request, Response, Transport, + trace, ClientMessage, Request, Response, Transport, }; use ::tokio::sync::mpsc; use futures::{ @@ -337,6 +337,20 @@ where } } +/// Critical errors that result in a Channel disconnecting. +#[derive(thiserror::Error, Debug)] +pub enum ChannelError +where + E: Error + Send + Sync + 'static, +{ + /// An error occurred reading from, or writing to, the transport. + #[error("an error occurred in the transport: {0}")] + Transport(#[source] E), + /// An error occurred while polling expired requests. + #[error("an error occurred while polling expired requests: {0}")] + Timer(#[source] ::tokio::time::error::Error), +} + impl Stream for BaseChannel where T: Transport, ClientMessage>, @@ -393,7 +407,7 @@ where let request_status = match self .transport_pin_mut() .poll_next(cx) - .map_err(|e| ChannelError::Read(Arc::new(e)))? + .map_err(ChannelError::Transport)? { Poll::Ready(Some(message)) => match message { ClientMessage::Request(request) => { @@ -485,7 +499,7 @@ where self.project() .transport .poll_close(cx) - .map_err(ChannelError::Close) + .map_err(ChannelError::Transport) } } @@ -686,6 +700,29 @@ impl InFlightRequest { &self.request } + /// Respond without executing a service function. Useful for early aborts (e.g. for throttling). + pub async fn respond(self, response: Result) { + let Self { + response_tx, + response_guard, + request: Request { id: request_id, .. }, + span, + .. + } = self; + let _entered = span.enter(); + tracing::info!("CompleteRequest"); + let response = Response { + request_id, + message: response, + }; + let _ = response_tx.send(response).await; + tracing::info!("BufferResponse"); + // Request processing has completed, meaning either the channel canceled the request or + // a request was sent back to the channel. Either way, the channel will clean up the + // request data, so the request does not need to be canceled. + mem::forget(response_guard); + } + /// Returns a [future](Future) that executes the request using the given [service /// function](Serve). The service function's output is automatically sent back to the [Channel] /// that yielded this request. The request will be executed in the scope of this request's From 2fb75d92a7a22ea3d14319920a6ef82d3028e6f0 Mon Sep 17 00:00:00 2001 From: Tim Kuehn Date: Mon, 7 Nov 2022 18:37:58 -0800 Subject: [PATCH 15/30] Add request hooks to the Serve trait. This allows plugging in horizontal functionality, such as authorization, throttling, or latency recording, that should run before and/or after execution of every request, regardless of the request type. The tracing example is updated to show off both client stubs as well as server hooks. As part of this change, there were some changes to the Serve trait: 1. Serve's output type is now a Result.. Serve previously did not allow returning ServerErrors, which prevented using Serve for horizontal functionality like throttling or auth. Now, Serve's output type is Result, making Serve a more natural integration point for horizontal capabilities. 2. Serve's generic Request type changed to an associated type. The primary benefit of the generic type is that it allows one type to impl a trait multiple times (for example, u64 impls TryFrom, TryFrom ServiceGenerator<'a> { } = self; quote! { - impl tarpc::server::Serve<#request_ident> for #server_ident + impl tarpc::server::Serve for #server_ident where S: #service_ident { + type Req = #request_ident; type Resp = #response_ident; type Fut = #response_fut_ident; @@ -670,10 +671,10 @@ impl<'a> ServiceGenerator<'a> { quote! { impl std::future::Future for #response_fut_ident { - type Output = #response_ident; + type Output = Result<#response_ident, tarpc::ServerError>; fn poll(self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>) - -> std::task::Poll<#response_ident> + -> std::task::Poll> { unsafe { match std::pin::Pin::get_unchecked_mut(self) { @@ -682,6 +683,7 @@ impl<'a> ServiceGenerator<'a> { std::pin::Pin::new_unchecked(resp) .poll(cx) .map(#response_ident::#camel_case_idents) + .map(Ok), )* } } diff --git a/tarpc/examples/tracing.rs b/tarpc/examples/tracing.rs index 27561468e..589c16ffd 100644 --- a/tarpc/examples/tracing.rs +++ b/tarpc/examples/tracing.rs @@ -4,13 +4,32 @@ // license that can be found in the LICENSE file or at // https://opensource.org/licenses/MIT. -use crate::{add::Add as AddService, double::Double as DoubleService}; +#![feature(type_alias_impl_trait)] + +use crate::{ + add::{Add as AddService, AddStub}, + double::Double as DoubleService, +}; use futures::{future, prelude::*}; +use std::{ + io, + sync::{ + atomic::{AtomicBool, Ordering}, + Arc, + }, +}; use tarpc::{ - client, context, - server::{incoming::Incoming, BaseChannel}, + client::{ + self, + stub::{load_balance, retry}, + RpcError, + }, + context, serde_transport, + server::{incoming::Incoming, BaseChannel, Serve}, tokio_serde::formats::Json, + ClientMessage, Response, ServerError, Transport, }; +use tokio::net::TcpStream; use tracing_subscriber::prelude::*; pub mod add { @@ -40,12 +59,16 @@ impl AddService for AddServer { } #[derive(Clone)] -struct DoubleServer { - add_client: add::AddClient, +struct DoubleServer { + add_client: add::AddClient, } #[tarpc::server] -impl DoubleService for DoubleServer { +impl DoubleService for DoubleServer +where + Stub: AddStub + Clone + Send + Sync + 'static, + for<'a> Stub::RespFut<'a>: Send, +{ async fn double(self, _: context::Context, x: i32) -> Result { self.add_client .add(context::current(), x, x) @@ -70,22 +93,79 @@ fn init_tracing(service_name: &str) -> anyhow::Result<()> { Ok(()) } +async fn listen_on_random_port() -> anyhow::Result<( + impl Stream>>, + std::net::SocketAddr, +)> +where + Item: for<'de> serde::Deserialize<'de>, + SinkItem: serde::Serialize, +{ + let listener = tarpc::serde_transport::tcp::listen("localhost:0", Json::default) + .await? + .filter_map(|r| future::ready(r.ok())) + .take(1); + let addr = listener.get_ref().get_ref().local_addr(); + Ok((listener, addr)) +} + +fn make_stub( + backends: [impl Transport>, Response> + Send + Sync + 'static; N], +) -> retry::Retry< + impl Fn(&Result, u32) -> bool + Clone, + load_balance::RoundRobin, Resp>>, +> +where + Req: Send + Sync + 'static, + Resp: Send + Sync + 'static, +{ + let stub = load_balance::RoundRobin::new( + backends + .into_iter() + .map(|transport| tarpc::client::new(client::Config::default(), transport).spawn()) + .collect(), + ); + let stub = retry::Retry::new(stub, |resp, attempts| { + if let Err(e) = resp { + tracing::warn!("Got an error: {e:?}"); + attempts < 3 + } else { + false + } + }); + stub +} + #[tokio::main] async fn main() -> anyhow::Result<()> { init_tracing("tarpc_tracing_example")?; - let add_listener = tarpc::serde_transport::tcp::listen("localhost:0", Json::default) - .await? - .filter_map(|r| future::ready(r.ok())); - let addr = add_listener.get_ref().local_addr(); - let add_server = add_listener + let (add_listener1, addr1) = listen_on_random_port().await?; + let (add_listener2, addr2) = listen_on_random_port().await?; + let something_bad_happened = Arc::new(AtomicBool::new(false)); + let server = AddServer.serve().before(move |_: &mut _, _: &_| { + let something_bad_happened = something_bad_happened.clone(); + async move { + if something_bad_happened.fetch_xor(true, Ordering::Relaxed) { + Err(ServerError::new( + io::ErrorKind::NotFound, + "Gamma Ray!".into(), + )) + } else { + Ok(()) + } + } + }); + let add_server = add_listener1 + .chain(add_listener2) .map(BaseChannel::with_defaults) - .take(1) - .execute(AddServer.serve()); + .execute(server); tokio::spawn(add_server); - let to_add_server = tarpc::serde_transport::tcp::connect(addr, Json::default).await?; - let add_client = add::AddClient::new(client::Config::default(), to_add_server).spawn(); + let add_client = add::AddClient::from(make_stub([ + tarpc::serde_transport::tcp::connect(addr1, Json::default).await?, + tarpc::serde_transport::tcp::connect(addr2, Json::default).await?, + ])); let double_listener = tarpc::serde_transport::tcp::listen("localhost:0", Json::default) .await? diff --git a/tarpc/src/server.rs b/tarpc/src/server.rs index 7b1e49848..d5e5b5259 100644 --- a/tarpc/src/server.rs +++ b/tarpc/src/server.rs @@ -9,7 +9,7 @@ use crate::{ cancellations::{cancellations, CanceledRequests, RequestCancellation}, context::{self, SpanExt}, - trace, ClientMessage, Request, Response, Transport, + trace, ClientMessage, Request, Response, ServerError, Transport, }; use ::tokio::sync::mpsc; use futures::{ @@ -25,6 +25,7 @@ use std::{convert::TryFrom, error::Error, fmt, marker::PhantomData, pin::Pin, sy use tracing::{info_span, instrument::Instrument, Span}; mod in_flight_requests; +pub mod request_hook; #[cfg(test)] mod testing; @@ -39,6 +40,10 @@ pub mod incoming; #[cfg_attr(docsrs, doc(cfg(feature = "tokio1")))] pub mod tokio; +use request_hook::{ + AfterRequest, AfterRequestHook, BeforeAndAfterRequestHook, BeforeRequest, BeforeRequestHook, +}; + /// Settings that control the behavior of [channels](Channel). #[derive(Clone, Debug)] pub struct Config { @@ -67,32 +72,212 @@ impl Config { } /// Equivalent to a `FnOnce(Req) -> impl Future`. -pub trait Serve { +pub trait Serve { + /// Type of request. + type Req; + /// Type of response. type Resp; /// Type of response future. - type Fut: Future; + type Fut: Future>; + + /// Responds to a single request. + fn serve(self, ctx: context::Context, req: Self::Req) -> Self::Fut; /// Extracts a method name from the request. - fn method(&self, _request: &Req) -> Option<&'static str> { + fn method(&self, _request: &Self::Req) -> Option<&'static str> { None } - /// Responds to a single request. - fn serve(self, ctx: context::Context, req: Req) -> Self::Fut; + /// Runs a hook before execution of the request. + /// + /// If the hook returns an error, the request will not be executed and the error will be + /// returned instead. + /// + /// The hook can also modify the request context. This could be used, for example, to enforce a + /// maximum deadline on all requests. + /// + /// Any type that implements [`BeforeRequest`] can be used as the hook. Types that implement + /// `FnMut(&mut Context, &RequestType) -> impl Future>` can + /// also be used. + /// + /// # Example + /// + /// ```rust + /// use futures::{executor::block_on, future}; + /// use tarpc::{context, ServerError, server::{Serve, serve}}; + /// use std::io; + /// + /// let serve = serve(|_ctx, i| async move { Ok(i + 1) }) + /// .before(|_ctx: &mut context::Context, req: &i32| { + /// future::ready( + /// if *req == 1 { + /// Err(ServerError::new( + /// io::ErrorKind::Other, + /// format!("I don't like {req}"))) + /// } else { + /// Ok(()) + /// }) + /// }); + /// let response = serve.serve(context::current(), 1); + /// assert!(block_on(response).is_err()); + /// ``` + fn before(self, hook: Hook) -> BeforeRequestHook + where + Hook: BeforeRequest, + Self: Sized, + { + BeforeRequestHook::new(self, hook) + } + + /// Runs a hook after completion of a request. + /// + /// The hook can modify the request context and the response. + /// + /// Any type that implements [`AfterRequest`] can be used as the hook. Types that implement + /// `FnMut(&mut Context, &mut Result) -> impl Future` + /// can also be used. + /// + /// # Example + /// + /// ```rust + /// use futures::{executor::block_on, future}; + /// use tarpc::{context, ServerError, server::{Serve, serve}}; + /// use std::io; + /// + /// let serve = serve( + /// |_ctx, i| async move { + /// if i == 1 { + /// Err(ServerError::new( + /// io::ErrorKind::Other, + /// format!("{i} is the loneliest number"))) + /// } else { + /// Ok(i + 1) + /// } + /// }) + /// .after(|_ctx: &mut context::Context, resp: &mut Result| { + /// if let Err(e) = resp { + /// eprintln!("server error: {e:?}"); + /// } + /// future::ready(()) + /// }); + /// + /// let response = serve.serve(context::current(), 1); + /// assert!(block_on(response).is_err()); + /// ``` + fn after(self, hook: Hook) -> AfterRequestHook + where + Hook: AfterRequest, + Self: Sized, + { + AfterRequestHook::new(self, hook) + } + + /// Runs a hook before and after execution of the request. + /// + /// If the hook returns an error, the request will not be executed and the error will be + /// returned instead. + /// + /// The hook can also modify the request context and the response. This could be used, for + /// example, to enforce a maximum deadline on all requests. + /// + /// # Example + /// + /// ```rust + /// use futures::{executor::block_on, future}; + /// use tarpc::{ + /// context, ServerError, server::{Serve, serve, request_hook::{BeforeRequest, AfterRequest}} + /// }; + /// use std::{io, time::Instant}; + /// + /// struct PrintLatency(Instant); + /// + /// impl BeforeRequest for PrintLatency { + /// type Fut<'a> = future::Ready> where Self: 'a, Req: 'a; + /// + /// fn before<'a>(&'a mut self, _: &'a mut context::Context, _: &'a Req) -> Self::Fut<'a> { + /// self.0 = Instant::now(); + /// future::ready(Ok(())) + /// } + /// } + /// + /// impl AfterRequest for PrintLatency { + /// type Fut<'a> = future::Ready<()> where Self:'a, Resp:'a; + /// + /// fn after<'a>( + /// &'a mut self, + /// _: &'a mut context::Context, + /// _: &'a mut Result, + /// ) -> Self::Fut<'a> { + /// tracing::info!("Elapsed: {:?}", self.0.elapsed()); + /// future::ready(()) + /// } + /// } + /// + /// let serve = serve(|_ctx, i| async move { + /// Ok(i + 1) + /// }).before_and_after(PrintLatency(Instant::now())); + /// let response = serve.serve(context::current(), 1); + /// assert!(block_on(response).is_ok()); + /// ``` + fn before_and_after( + self, + hook: Hook, + ) -> BeforeAndAfterRequestHook + where + Hook: BeforeRequest + AfterRequest, + Self: Sized, + { + BeforeAndAfterRequestHook::new(self, hook) + } } -impl Serve for F +/// A Serve wrapper around a Fn. +#[derive(Debug)] +pub struct ServeFn { + f: F, + data: PhantomData Resp>, +} + +impl Clone for ServeFn +where + F: Clone, +{ + fn clone(&self) -> Self { + Self { + f: self.f.clone(), + data: PhantomData, + } + } +} + +impl Copy for ServeFn where F: Copy {} + +/// Creates a [`Serve`] wrapper around a `FnOnce(context::Context, Req) -> impl Future>`. +pub fn serve(f: F) -> ServeFn where F: FnOnce(context::Context, Req) -> Fut, - Fut: Future, + Fut: Future>, { + ServeFn { + f, + data: PhantomData, + } +} + +impl Serve for ServeFn +where + F: FnOnce(context::Context, Req) -> Fut, + Fut: Future>, +{ + type Req = Req; type Resp = Resp; type Fut = Fut; fn serve(self, ctx: context::Context, req: Req) -> Self::Fut { - self(ctx, req) + (self.f)(ctx, req) } } @@ -307,6 +492,34 @@ where /// This is a terminal operation. After calling `requests`, the channel cannot be retrieved, /// and the only way to complete requests is via [`Requests::execute`] or /// [`InFlightRequest::execute`]. + /// + /// # Example + /// + /// ```rust + /// use tarpc::{ + /// context, + /// client::{self, NewClient}, + /// server::{self, BaseChannel, Channel, serve}, + /// transport, + /// }; + /// use futures::prelude::*; + /// + /// #[tokio::main] + /// async fn main() { + /// let (tx, rx) = transport::channel::unbounded(); + /// let server = BaseChannel::new(server::Config::default(), rx); + /// let NewClient { client, dispatch } = client::new(client::Config::default(), tx); + /// tokio::spawn(dispatch); + /// + /// let mut requests = server.requests(); + /// tokio::spawn(async move { + /// while let Some(Ok(request)) = requests.next().await { + /// tokio::spawn(request.execute(serve(|_, i| async move { Ok(i + 1) }))); + /// } + /// }); + /// assert_eq!(client.call(context::current(), "AddOne", 1).await.unwrap(), 2); + /// } + /// ``` fn requests(self) -> Requests where Self: Sized, @@ -323,12 +536,28 @@ where /// Runs the channel until completion by executing all requests using the given service /// function. Request handlers are run concurrently by [spawning](::tokio::spawn) on tokio's /// default executor. + /// + /// # Example + /// + /// ```rust + /// use tarpc::{context, client, server::{self, BaseChannel, Channel, serve}, transport}; + /// use futures::prelude::*; + /// + /// #[tokio::main] + /// async fn main() { + /// let (tx, rx) = transport::channel::unbounded(); + /// let client = client::new(client::Config::default(), tx).spawn(); + /// let channel = BaseChannel::new(server::Config::default(), rx); + /// tokio::spawn(channel.execute(serve(|_, i| async move { Ok(i + 1) }))); + /// assert_eq!(client.call(context::current(), "AddOne", 1).await.unwrap(), 2); + /// } + /// ``` #[cfg(feature = "tokio1")] #[cfg_attr(docsrs, doc(cfg(feature = "tokio1")))] fn execute(self, serve: S) -> self::tokio::TokioChannelExecutor, S> where Self: Sized, - S: Serve + Send + 'static, + S: Serve + Send + 'static, S::Fut: Send, Self::Req: Send + 'static, Self::Resp: Send + 'static, @@ -737,13 +966,43 @@ impl InFlightRequest { /// /// If the returned Future is dropped before completion, a cancellation message will be sent to /// the Channel to clean up associated request state. + /// + /// # Example + /// + /// ```rust + /// use tarpc::{ + /// context, + /// client::{self, NewClient}, + /// server::{self, BaseChannel, Channel, serve}, + /// transport, + /// }; + /// use futures::prelude::*; + /// + /// #[tokio::main] + /// async fn main() { + /// let (tx, rx) = transport::channel::unbounded(); + /// let server = BaseChannel::new(server::Config::default(), rx); + /// let NewClient { client, dispatch } = client::new(client::Config::default(), tx); + /// tokio::spawn(dispatch); + /// + /// tokio::spawn(async move { + /// let mut requests = server.requests(); + /// while let Some(Ok(in_flight_request)) = requests.next().await { + /// in_flight_request.execute(serve(|_, i| async move { Ok(i + 1) })).await; + /// } + /// + /// }); + /// assert_eq!(client.call(context::current(), "AddOne", 1).await.unwrap(), 2); + /// } + /// ``` + /// pub async fn execute(self, serve: S) where - S: Serve, + S: Serve, { let Self { response_tx, - mut response_guard, + response_guard, abort_registration, span, request: @@ -760,12 +1019,11 @@ impl InFlightRequest { span.record("otel.name", &method.unwrap_or("")); let _ = Abortable::new( async move { - tracing::info!("BeginRequest"); - let response = serve.serve(context, message).await; + let message = serve.serve(context, message).await; tracing::info!("CompleteRequest"); let response = Response { request_id, - message: Ok(response), + message, }; let _ = response_tx.send(response).await; tracing::info!("BufferResponse"); @@ -809,11 +1067,14 @@ where #[cfg(test)] mod tests { - use super::{in_flight_requests::AlreadyExistsError, BaseChannel, Channel, Config, Requests}; + use super::{ + in_flight_requests::AlreadyExistsError, serve, AfterRequest, BaseChannel, BeforeRequest, + Channel, Config, Requests, Serve, + }; use crate::{ context, trace, transport::channel::{self, UnboundedChannel}, - ClientMessage, Request, Response, + ClientMessage, Request, Response, ServerError, }; use assert_matches::assert_matches; use futures::{ @@ -822,7 +1083,12 @@ mod tests { Future, }; use futures_test::task::noop_context; - use std::{pin::Pin, task::Poll}; + use std::{ + io, + pin::Pin, + task::Poll, + time::{Duration, Instant, SystemTime}, + }; fn test_channel() -> ( Pin, Response>>>>, @@ -883,6 +1149,101 @@ mod tests { Abortable::new(pending(), abort_registration) } + #[tokio::test] + async fn test_serve() { + let serve = serve(|_, i| async move { Ok(i) }); + assert_matches!(serve.serve(context::current(), 7).await, Ok(7)); + } + + #[tokio::test] + async fn serve_before_mutates_context() -> anyhow::Result<()> { + struct SetDeadline(SystemTime); + type SetDeadlineFut<'a, Req: 'a> = impl Future> + 'a; + impl BeforeRequest for SetDeadline { + type Fut<'a> = SetDeadlineFut<'a, Req> where Self: 'a, Req: 'a; + fn before<'a>( + &'a mut self, + ctx: &'a mut context::Context, + _: &'a Req, + ) -> Self::Fut<'a> { + async move { + ctx.deadline = self.0; + Ok(()) + } + } + } + + let some_time = SystemTime::UNIX_EPOCH + Duration::from_secs(37); + let some_other_time = SystemTime::UNIX_EPOCH + Duration::from_secs(83); + + let serve = serve(move |ctx: context::Context, i| async move { + assert_eq!(ctx.deadline, some_time); + Ok(i) + }); + let deadline_hook = serve.before(SetDeadline(some_time)); + let mut ctx = context::current(); + ctx.deadline = some_other_time; + deadline_hook.serve(ctx, 7).await?; + Ok(()) + } + + #[tokio::test] + async fn serve_before_and_after() -> anyhow::Result<()> { + let _ = tracing_subscriber::fmt::try_init(); + + struct PrintLatency { + start: Instant, + } + impl PrintLatency { + fn new() -> Self { + Self { + start: Instant::now(), + } + } + } + type StartFut<'a, Req: 'a> = impl Future> + 'a; + type EndFut<'a, Resp: 'a> = impl Future + 'a; + impl BeforeRequest for PrintLatency { + type Fut<'a> = StartFut<'a, Req> where Self: 'a, Req: 'a; + fn before<'a>(&'a mut self, _: &'a mut context::Context, _: &'a Req) -> Self::Fut<'a> { + async move { + self.start = Instant::now(); + Ok(()) + } + } + } + impl AfterRequest for PrintLatency { + type Fut<'a> = EndFut<'a, Resp> where Self: 'a, Resp: 'a; + fn after<'a>( + &'a mut self, + _: &'a mut context::Context, + _: &'a mut Result, + ) -> Self::Fut<'a> { + async move { + tracing::info!("Elapsed: {:?}", self.start.elapsed()); + } + } + } + + let serve = serve(move |_: context::Context, i| async move { Ok(i) }); + serve + .before_and_after(PrintLatency::new()) + .serve(context::current(), 7) + .await?; + Ok(()) + } + + #[tokio::test] + async fn serve_before_error_aborts_request() -> anyhow::Result<()> { + let serve = serve(|_, _| async { panic!("Shouldn't get here") }); + let deadline_hook = serve.before(|_: &mut context::Context, _: &i32| async { + Err(ServerError::new(io::ErrorKind::Other, "oops".into())) + }); + let resp: Result = deadline_hook.serve(context::current(), 7).await; + assert_matches!(resp, Err(_)); + Ok(()) + } + #[tokio::test] async fn base_channel_start_send_duplicate_request_returns_error() { let (mut channel, _tx) = test_channel::<(), ()>(); @@ -1083,7 +1444,7 @@ mod tests { Poll::Ready(Some(Ok(request))) => request, result => panic!("Unexpected result: {:?}", result), }; - request.execute(|_, _| async {}).await; + request.execute(serve(|_, _| async { Ok(()) })).await; assert!(requests .as_mut() .channel_pin_mut() diff --git a/tarpc/src/server/incoming.rs b/tarpc/src/server/incoming.rs index 445fc3e89..931e87669 100644 --- a/tarpc/src/server/incoming.rs +++ b/tarpc/src/server/incoming.rs @@ -35,7 +35,7 @@ where #[cfg_attr(docsrs, doc(cfg(feature = "tokio1")))] fn execute(self, serve: S) -> TokioServerExecutor where - S: Serve, + S: Serve, { TokioServerExecutor::new(self, serve) } diff --git a/tarpc/src/server/request_hook.rs b/tarpc/src/server/request_hook.rs new file mode 100644 index 000000000..ef23d73b4 --- /dev/null +++ b/tarpc/src/server/request_hook.rs @@ -0,0 +1,22 @@ +// Copyright 2022 Google LLC +// +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file or at +// https://opensource.org/licenses/MIT. + +//! Hooks for horizontal functionality that can run either before or after a request is executed. + +/// A request hook that runs before a request is executed. +mod before; + +/// A request hook that runs after a request is completed. +mod after; + +/// A request hook that runs both before a request is executed and after it is completed. +mod before_and_after; + +pub use { + after::{AfterRequest, AfterRequestHook}, + before::{BeforeRequest, BeforeRequestHook}, + before_and_after::BeforeAndAfterRequestHook, +}; diff --git a/tarpc/src/server/request_hook/after.rs b/tarpc/src/server/request_hook/after.rs new file mode 100644 index 000000000..a3803bade --- /dev/null +++ b/tarpc/src/server/request_hook/after.rs @@ -0,0 +1,89 @@ +// Copyright 2022 Google LLC +// +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file or at +// https://opensource.org/licenses/MIT. + +//! Provides a hook that runs after request execution. + +use crate::{context, server::Serve, ServerError}; +use futures::prelude::*; + +/// A hook that runs after request execution. +pub trait AfterRequest { + /// The type of future returned by the hook. + type Fut<'a>: Future + where + Self: 'a, + Resp: 'a; + + /// The function that is called after request execution. + /// + /// The hook can modify the request context and the response. + fn after<'a>( + &'a mut self, + ctx: &'a mut context::Context, + resp: &'a mut Result, + ) -> Self::Fut<'a>; +} + +impl AfterRequest for F +where + F: FnMut(&mut context::Context, &mut Result) -> Fut, + Fut: Future, +{ + type Fut<'a> = Fut where Self: 'a, Resp: 'a; + + fn after<'a>( + &'a mut self, + ctx: &'a mut context::Context, + resp: &'a mut Result, + ) -> Self::Fut<'a> { + self(ctx, resp) + } +} + +/// A Service function that runs a hook after request execution. +pub struct AfterRequestHook { + serve: Serv, + hook: Hook, +} + +impl AfterRequestHook { + pub(crate) fn new(serve: Serv, hook: Hook) -> Self { + Self { serve, hook } + } +} + +impl Clone for AfterRequestHook { + fn clone(&self) -> Self { + Self { + serve: self.serve.clone(), + hook: self.hook.clone(), + } + } +} + +impl Serve for AfterRequestHook +where + Serv: Serve, + Hook: AfterRequest, +{ + type Req = Serv::Req; + type Resp = Serv::Resp; + type Fut = AfterRequestHookFut; + + fn serve(self, mut ctx: context::Context, req: Serv::Req) -> Self::Fut { + async move { + let AfterRequestHook { + serve, mut hook, .. + } = self; + let mut resp = serve.serve(ctx, req).await; + hook.after(&mut ctx, &mut resp).await; + resp + } + } +} + +type AfterRequestHookFut> = + impl Future>; diff --git a/tarpc/src/server/request_hook/before.rs b/tarpc/src/server/request_hook/before.rs new file mode 100644 index 000000000..38ad54d01 --- /dev/null +++ b/tarpc/src/server/request_hook/before.rs @@ -0,0 +1,84 @@ +// Copyright 2022 Google LLC +// +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file or at +// https://opensource.org/licenses/MIT. + +//! Provides a hook that runs before request execution. + +use crate::{context, server::Serve, ServerError}; +use futures::prelude::*; + +/// A hook that runs before request execution. +pub trait BeforeRequest { + /// The type of future returned by the hook. + type Fut<'a>: Future> + where + Self: 'a, + Req: 'a; + + /// The function that is called before request execution. + /// + /// If this function returns an error, the request will not be executed and the error will be + /// returned instead. + /// + /// This function can also modify the request context. This could be used, for example, to + /// enforce a maximum deadline on all requests. + fn before<'a>(&'a mut self, ctx: &'a mut context::Context, req: &'a Req) -> Self::Fut<'a>; +} + +impl BeforeRequest for F +where + F: FnMut(&mut context::Context, &Req) -> Fut, + Fut: Future>, +{ + type Fut<'a> = Fut where Self: 'a, Req: 'a; + + fn before<'a>(&'a mut self, ctx: &'a mut context::Context, req: &'a Req) -> Self::Fut<'a> { + self(ctx, req) + } +} + +/// A Service function that runs a hook before request execution. +pub struct BeforeRequestHook { + serve: Serv, + hook: Hook, +} + +impl BeforeRequestHook { + pub(crate) fn new(serve: Serv, hook: Hook) -> Self { + Self { serve, hook } + } +} + +impl Clone for BeforeRequestHook { + fn clone(&self) -> Self { + Self { + serve: self.serve.clone(), + hook: self.hook.clone(), + } + } +} + +impl Serve for BeforeRequestHook +where + Serv: Serve, + Hook: BeforeRequest, +{ + type Req = Serv::Req; + type Resp = Serv::Resp; + type Fut = BeforeRequestHookFut; + + fn serve(self, mut ctx: context::Context, req: Self::Req) -> Self::Fut { + let BeforeRequestHook { + serve, mut hook, .. + } = self; + async move { + hook.before(&mut ctx, &req).await?; + serve.serve(ctx, req).await + } + } +} + +type BeforeRequestHookFut> = + impl Future>; diff --git a/tarpc/src/server/request_hook/before_and_after.rs b/tarpc/src/server/request_hook/before_and_after.rs new file mode 100644 index 000000000..ca42460bc --- /dev/null +++ b/tarpc/src/server/request_hook/before_and_after.rs @@ -0,0 +1,70 @@ +// Copyright 2022 Google LLC +// +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file or at +// https://opensource.org/licenses/MIT. + +//! Provides a hook that runs both before and after request execution. + +use super::{after::AfterRequest, before::BeforeRequest}; +use crate::{context, server::Serve, ServerError}; +use futures::prelude::*; +use std::marker::PhantomData; + +/// A Service function that runs a hook both before and after request execution. +pub struct BeforeAndAfterRequestHook { + serve: Serv, + hook: Hook, + fns: PhantomData<(fn(Req), fn(Resp))>, +} + +impl BeforeAndAfterRequestHook { + pub(crate) fn new(serve: Serv, hook: Hook) -> Self { + Self { + serve, + hook, + fns: PhantomData, + } + } +} + +impl Clone + for BeforeAndAfterRequestHook +{ + fn clone(&self) -> Self { + Self { + serve: self.serve.clone(), + hook: self.hook.clone(), + fns: PhantomData, + } + } +} + +impl Serve for BeforeAndAfterRequestHook +where + Serv: Serve, + Hook: BeforeRequest + AfterRequest, +{ + type Req = Req; + type Resp = Resp; + type Fut = BeforeAndAfterRequestHookFut; + + fn serve(self, mut ctx: context::Context, req: Req) -> Self::Fut { + async move { + let BeforeAndAfterRequestHook { + serve, mut hook, .. + } = self; + hook.before(&mut ctx, &req).await?; + let mut resp = serve.serve(ctx, req).await; + hook.after(&mut ctx, &mut resp).await; + resp + } + } +} + +type BeforeAndAfterRequestHookFut< + Req, + Resp, + Serv: Serve, + Hook: BeforeRequest + AfterRequest, +> = impl Future>; diff --git a/tarpc/src/server/tokio.rs b/tarpc/src/server/tokio.rs index a44e8469e..e9ad84221 100644 --- a/tarpc/src/server/tokio.rs +++ b/tarpc/src/server/tokio.rs @@ -55,9 +55,25 @@ where { /// Executes all requests using the given service function. Requests are handled concurrently /// by [spawning](::tokio::spawn) each handler on tokio's default executor. + /// + /// # Example + /// + /// ```rust + /// use tarpc::{context, client, server::{self, BaseChannel, Channel, serve}, transport}; + /// use futures::prelude::*; + /// + /// #[tokio::main] + /// async fn main() { + /// let (tx, rx) = transport::channel::unbounded(); + /// let requests = BaseChannel::new(server::Config::default(), rx).requests(); + /// let client = client::new(client::Config::default(), tx).spawn(); + /// tokio::spawn(requests.execute(serve(|_, i| async move { Ok(i + 1) }))); + /// assert_eq!(client.call(context::current(), "AddOne", 1).await.unwrap(), 2); + /// } + /// ``` pub fn execute(self, serve: S) -> TokioChannelExecutor where - S: Serve + Send + 'static, + S: Serve + Send + 'static, { TokioChannelExecutor { inner: self, serve } } @@ -69,7 +85,7 @@ where C: Channel + Send + 'static, C::Req: Send + 'static, C::Resp: Send + 'static, - Se: Serve + Send + 'static + Clone, + Se: Serve + Send + 'static + Clone, Se::Fut: Send, { type Output = (); @@ -88,7 +104,7 @@ where C: Channel + 'static, C::Req: Send + 'static, C::Resp: Send + 'static, - S: Serve + Send + 'static + Clone, + S: Serve + Send + 'static + Clone, S::Fut: Send, { type Output = (); diff --git a/tarpc/src/transport/channel.rs b/tarpc/src/transport/channel.rs index 529ae8f58..7f3035d14 100644 --- a/tarpc/src/transport/channel.rs +++ b/tarpc/src/transport/channel.rs @@ -150,12 +150,14 @@ impl Sink for Channel { #[cfg(feature = "tokio1")] mod tests { use crate::{ - client, context, - server::{incoming::Incoming, BaseChannel}, + client::{self, RpcError}, + context, + server::{incoming::Incoming, serve, BaseChannel}, transport::{ self, channel::{Channel, UnboundedChannel}, }, + ServerError, }; use assert_matches::assert_matches; use futures::{prelude::*, stream}; @@ -177,25 +179,25 @@ mod tests { tokio::spawn( stream::once(future::ready(server_channel)) .map(BaseChannel::with_defaults) - .execute(|_ctx, request: String| { - future::ready(request.parse::().map_err(|_| { - io::Error::new( + .execute(serve(|_ctx, request: String| async move { + request.parse::().map_err(|_| { + ServerError::new( io::ErrorKind::InvalidInput, format!("{request:?} is not an int"), ) - })) - }), + }) + })), ); let client = client::new(client::Config::default(), client_channel).spawn(); - let response1 = client.call(context::current(), "", "123".into()).await?; - let response2 = client.call(context::current(), "", "abc".into()).await?; + let response1 = client.call(context::current(), "", "123".into()).await; + let response2 = client.call(context::current(), "", "abc".into()).await; trace!("response1: {:?}, response2: {:?}", response1, response2); assert_matches!(response1, Ok(123)); - assert_matches!(response2, Err(ref e) if e.kind() == io::ErrorKind::InvalidInput); + assert_matches!(response2, Err(RpcError::Server(e)) if e.kind == io::ErrorKind::InvalidInput); Ok(()) } From 4e2a4993b561369200f1cf139d9ef1da5cc662bd Mon Sep 17 00:00:00 2001 From: Tim Kuehn Date: Sat, 12 Nov 2022 16:32:27 -0800 Subject: [PATCH 16/30] Use rust nightly for Github workflows. While using unstable feature type_alias_impl_trait. --- hooks/pre-push | 10 +++++----- .../tarpc_server_missing_async.stderr | 16 ++++++++++------ 2 files changed, 15 insertions(+), 11 deletions(-) diff --git a/hooks/pre-push b/hooks/pre-push index 7b527e0a8..1e5500d63 100755 --- a/hooks/pre-push +++ b/hooks/pre-push @@ -84,12 +84,12 @@ command -v rustup &>/dev/null if [ "$?" == 0 ]; then printf "${SUCCESS}\n" - try_run "Building ... " cargo +stable build --color=always - try_run "Testing ... " cargo +stable test --color=always - try_run "Testing with all features enabled ... " cargo +stable test --all-features --color=always - for EXAMPLE in $(cargo +stable run --example 2>&1 | grep ' ' | awk '{print $1}') + try_run "Building ... " cargo build --color=always + try_run "Testing ... " cargo test --color=always + try_run "Testing with all features enabled ... " cargo test --all-features --color=always + for EXAMPLE in $(cargo run --example 2>&1 | grep ' ' | awk '{print $1}') do - try_run "Running example \"$EXAMPLE\" ... " cargo +stable run --example $EXAMPLE + try_run "Running example \"$EXAMPLE\" ... " cargo run --example $EXAMPLE done check_toolchain nightly diff --git a/tarpc/tests/compile_fail/tarpc_server_missing_async.stderr b/tarpc/tests/compile_fail/tarpc_server_missing_async.stderr index 28106e63f..d96cda833 100644 --- a/tarpc/tests/compile_fail/tarpc_server_missing_async.stderr +++ b/tarpc/tests/compile_fail/tarpc_server_missing_async.stderr @@ -1,11 +1,15 @@ error: not all trait items implemented, missing: `HelloFut` - --> $DIR/tarpc_server_missing_async.rs:9:1 - | -9 | impl World for HelloServer { - | ^^^^ + --> tests/compile_fail/tarpc_server_missing_async.rs:9:1 + | +9 | / impl World for HelloServer { +10 | | fn hello(name: String) -> String { +11 | | format!("Hello, {name}!", name) +12 | | } +13 | | } + | |_^ error: hint: `#[tarpc::server]` only rewrites async fns, and `fn hello` is not async - --> $DIR/tarpc_server_missing_async.rs:10:5 + --> tests/compile_fail/tarpc_server_missing_async.rs:10:5 | 10 | fn hello(name: String) -> String { - | ^^ + | ^^^^^^^^ From e703da5a4ec5d4cc8871601ceeed5b3eca8a869f Mon Sep 17 00:00:00 2001 From: Tim Kuehn Date: Wed, 23 Nov 2022 01:36:51 -0800 Subject: [PATCH 17/30] Use async fn in generated traits!! Add helper fn to server::incoming module for spawning. --- example-service/src/lib.rs | 3 + example-service/src/server.rs | 10 +- plugins/src/lib.rs | 235 ++---------------- plugins/tests/server.rs | 101 +------- plugins/tests/service.rs | 42 ++-- tarpc/Cargo.toml | 3 +- tarpc/examples/compression.rs | 17 +- tarpc/examples/custom_transport.rs | 18 +- tarpc/examples/pubsub.rs | 13 +- tarpc/examples/readme.rs | 20 +- tarpc/examples/tracing.rs | 23 +- tarpc/src/lib.rs | 97 ++------ tarpc/src/server.rs | 145 ++++++++--- tarpc/src/server/incoming.rs | 67 ++++- tarpc/src/server/request_hook/after.rs | 24 +- tarpc/src/server/request_hook/before.rs | 16 +- .../server/request_hook/before_and_after.rs | 27 +- tarpc/src/server/tokio.rs | 129 ---------- tarpc/src/transport/channel.rs | 33 ++- tarpc/tests/compile_fail.rs | 2 - .../compile_fail/must_use_request_dispatch.rs | 3 + .../must_use_request_dispatch.stderr | 8 +- .../tarpc_server_missing_async.rs | 15 -- .../tarpc_server_missing_async.stderr | 15 -- .../tokio/must_use_channel_executor.rs | 29 --- .../tokio/must_use_channel_executor.stderr | 11 - .../tokio/must_use_server_executor.rs | 30 --- .../tokio/must_use_server_executor.stderr | 11 - tarpc/tests/dataservice.rs | 13 +- tarpc/tests/service_functional.rs | 81 +++--- 30 files changed, 412 insertions(+), 829 deletions(-) delete mode 100644 tarpc/src/server/tokio.rs delete mode 100644 tarpc/tests/compile_fail/tarpc_server_missing_async.rs delete mode 100644 tarpc/tests/compile_fail/tarpc_server_missing_async.stderr delete mode 100644 tarpc/tests/compile_fail/tokio/must_use_channel_executor.rs delete mode 100644 tarpc/tests/compile_fail/tokio/must_use_channel_executor.stderr delete mode 100644 tarpc/tests/compile_fail/tokio/must_use_server_executor.rs delete mode 100644 tarpc/tests/compile_fail/tokio/must_use_server_executor.stderr diff --git a/example-service/src/lib.rs b/example-service/src/lib.rs index bc38fe93e..822d8217b 100644 --- a/example-service/src/lib.rs +++ b/example-service/src/lib.rs @@ -4,6 +4,9 @@ // license that can be found in the LICENSE file or at // https://opensource.org/licenses/MIT. +#![allow(incomplete_features)] +#![feature(async_fn_in_trait)] + use std::env; use tracing_subscriber::{fmt::format::FmtSpan, prelude::*}; diff --git a/example-service/src/server.rs b/example-service/src/server.rs index b0281e983..6c78598be 100644 --- a/example-service/src/server.rs +++ b/example-service/src/server.rs @@ -4,6 +4,9 @@ // license that can be found in the LICENSE file or at // https://opensource.org/licenses/MIT. +#![allow(incomplete_features)] +#![feature(async_fn_in_trait)] + use clap::Parser; use futures::{future, prelude::*}; use rand::{ @@ -34,7 +37,6 @@ struct Flags { #[derive(Clone)] struct HelloServer(SocketAddr); -#[tarpc::server] impl World for HelloServer { async fn hello(self, _: context::Context, name: String) -> String { let sleep_time = @@ -44,6 +46,10 @@ impl World for HelloServer { } } +async fn spawn(fut: impl Future + Send + 'static) { + tokio::spawn(fut); +} + #[tokio::main] async fn main() -> anyhow::Result<()> { let flags = Flags::parse(); @@ -66,7 +72,7 @@ async fn main() -> anyhow::Result<()> { // the generated World trait. .map(|channel| { let server = HelloServer(channel.transport().peer_addr().unwrap()); - channel.execute(server.serve()) + channel.execute(server.serve()).for_each(spawn) }) // Max 10 channels. .buffer_unordered(10) diff --git a/plugins/src/lib.rs b/plugins/src/lib.rs index efab161bb..f33cea09e 100644 --- a/plugins/src/lib.rs +++ b/plugins/src/lib.rs @@ -12,18 +12,18 @@ extern crate quote; extern crate syn; use proc_macro::TokenStream; -use proc_macro2::{Span, TokenStream as TokenStream2}; +use proc_macro2::TokenStream as TokenStream2; use quote::{format_ident, quote, ToTokens}; use syn::{ braced, ext::IdentExt, parenthesized, parse::{Parse, ParseStream}, - parse_macro_input, parse_quote, parse_str, + parse_macro_input, parse_quote, spanned::Spanned, token::Comma, - Attribute, FnArg, Ident, ImplItem, ImplItemMethod, ImplItemType, ItemImpl, Lit, LitBool, - MetaNameValue, Pat, PatType, ReturnType, Token, Type, Visibility, + Attribute, FnArg, Ident, Lit, LitBool, MetaNameValue, Pat, PatType, ReturnType, Token, Type, + Visibility, }; /// Accumulates multiple errors into a result. @@ -257,7 +257,6 @@ pub fn service(attr: TokenStream, input: TokenStream) -> TokenStream { .map(|rpc| snake_to_camel(&rpc.ident.unraw().to_string())) .collect(); let args: &[&[PatType]] = &rpcs.iter().map(|rpc| &*rpc.args).collect::>(); - let response_fut_name = &format!("{}ResponseFut", ident.unraw()); let derive_serialize = if derive_serde.0 { Some( quote! {#[derive(tarpc::serde::Serialize, tarpc::serde::Deserialize)] @@ -274,11 +273,9 @@ pub fn service(attr: TokenStream, input: TokenStream) -> TokenStream { .collect::>(); ServiceGenerator { - response_fut_name, service_ident: ident, client_stub_ident: &format_ident!("{}Stub", ident), server_ident: &format_ident!("Serve{}", ident), - response_fut_ident: &Ident::new(response_fut_name, ident.span()), client_ident: &format_ident!("{}Client", ident), request_ident: &format_ident!("{}Request", ident), response_ident: &format_ident!("{}Response", ident), @@ -305,138 +302,18 @@ pub fn service(attr: TokenStream, input: TokenStream) -> TokenStream { .zip(camel_case_fn_names.iter()) .map(|(rpc, name)| Ident::new(name, rpc.ident.span())) .collect::>(), - future_types: &camel_case_fn_names - .iter() - .map(|name| parse_str(&format!("{name}Fut")).unwrap()) - .collect::>(), derive_serialize: derive_serialize.as_ref(), } .into_token_stream() .into() } -/// generate an identifier consisting of the method name to CamelCase with -/// Fut appended to it. -fn associated_type_for_rpc(method: &ImplItemMethod) -> String { - snake_to_camel(&method.sig.ident.unraw().to_string()) + "Fut" -} - -/// Transforms an async function into a sync one, returning a type declaration -/// for the return type (a future). -fn transform_method(method: &mut ImplItemMethod) -> ImplItemType { - method.sig.asyncness = None; - - // get either the return type or (). - let ret = match &method.sig.output { - ReturnType::Default => quote!(()), - ReturnType::Type(_, ret) => quote!(#ret), - }; - - let fut_name = associated_type_for_rpc(method); - let fut_name_ident = Ident::new(&fut_name, method.sig.ident.span()); - - // generate the updated return signature. - method.sig.output = parse_quote! { - -> ::core::pin::Pin + ::core::marker::Send - >> - }; - - // transform the body of the method into Box::pin(async move { body }). - let block = method.block.clone(); - method.block = parse_quote! [{ - Box::pin(async move - #block - ) - }]; - - // generate and return type declaration for return type. - let t: ImplItemType = parse_quote! { - type #fut_name_ident = ::core::pin::Pin + ::core::marker::Send>>; - }; - - t -} - -#[proc_macro_attribute] -pub fn server(_attr: TokenStream, input: TokenStream) -> TokenStream { - let mut item = syn::parse_macro_input!(input as ItemImpl); - let span = item.span(); - - // the generated type declarations - let mut types: Vec = Vec::new(); - let mut expected_non_async_types: Vec<(&ImplItemMethod, String)> = Vec::new(); - let mut found_non_async_types: Vec<&ImplItemType> = Vec::new(); - - for inner in &mut item.items { - match inner { - ImplItem::Method(method) => { - if method.sig.asyncness.is_some() { - // if this function is declared async, transform it into a regular function - let typedecl = transform_method(method); - types.push(typedecl); - } else { - // If it's not async, keep track of all required associated types for better - // error reporting. - expected_non_async_types.push((method, associated_type_for_rpc(method))); - } - } - ImplItem::Type(typedecl) => found_non_async_types.push(typedecl), - _ => {} - } - } - - if let Err(e) = - verify_types_were_provided(span, &expected_non_async_types, &found_non_async_types) - { - return TokenStream::from(e.to_compile_error()); - } - - // add the type declarations into the impl block - for t in types.into_iter() { - item.items.push(syn::ImplItem::Type(t)); - } - - TokenStream::from(quote!(#item)) -} - -fn verify_types_were_provided( - span: Span, - expected: &[(&ImplItemMethod, String)], - provided: &[&ImplItemType], -) -> syn::Result<()> { - let mut result = Ok(()); - for (method, expected) in expected { - if !provided.iter().any(|typedecl| typedecl.ident == expected) { - let mut e = syn::Error::new( - span, - format!("not all trait items implemented, missing: `{expected}`"), - ); - let fn_span = method.sig.fn_token.span(); - e.extend(syn::Error::new( - fn_span.join(method.sig.ident.span()).unwrap_or(fn_span), - format!( - "hint: `#[tarpc::server]` only rewrites async fns, and `fn {}` is not async", - method.sig.ident - ), - )); - match result { - Ok(_) => result = Err(e), - Err(ref mut error) => error.extend(Some(e)), - } - } - } - result -} - // Things needed to generate the service items: trait, serve impl, request/response enums, and // the client stub. struct ServiceGenerator<'a> { service_ident: &'a Ident, client_stub_ident: &'a Ident, server_ident: &'a Ident, - response_fut_ident: &'a Ident, - response_fut_name: &'a str, client_ident: &'a Ident, request_ident: &'a Ident, response_ident: &'a Ident, @@ -444,7 +321,6 @@ struct ServiceGenerator<'a> { attrs: &'a [Attribute], rpcs: &'a [RpcMethod], camel_case_idents: &'a [Ident], - future_types: &'a [Type], method_idents: &'a [&'a Ident], request_names: &'a [String], method_attrs: &'a [&'a [Attribute]], @@ -460,7 +336,6 @@ impl<'a> ServiceGenerator<'a> { attrs, rpcs, vis, - future_types, return_types, service_ident, client_stub_ident, @@ -470,27 +345,19 @@ impl<'a> ServiceGenerator<'a> { .. } = self; - let types_and_fns = rpcs + let rpc_fns = rpcs .iter() - .zip(future_types.iter()) .zip(return_types.iter()) .map( |( - ( - RpcMethod { - attrs, ident, args, .. - }, - future_type, - ), + RpcMethod { + attrs, ident, args, .. + }, output, )| { - let ty_doc = format!("The response future returned by [`{service_ident}::{ident}`]."); quote! { - #[doc = #ty_doc] - type #future_type: std::future::Future; - #( #attrs )* - fn #ident(self, context: tarpc::context::Context, #( #args ),*) -> Self::#future_type; + async fn #ident(self, context: tarpc::context::Context, #( #args ),*) -> #output; } }, ); @@ -499,7 +366,7 @@ impl<'a> ServiceGenerator<'a> { quote! { #( #attrs )* #vis trait #service_ident: Sized { - #( #types_and_fns )* + #( #rpc_fns )* /// Returns a serving function to use with /// [InFlightRequest::execute](tarpc::server::InFlightRequest::execute). @@ -539,7 +406,6 @@ impl<'a> ServiceGenerator<'a> { server_ident, service_ident, response_ident, - response_fut_ident, camel_case_idents, arg_pats, method_idents, @@ -553,7 +419,6 @@ impl<'a> ServiceGenerator<'a> { { type Req = #request_ident; type Resp = #response_ident; - type Fut = #response_fut_ident; fn method(&self, req: &#request_ident) -> Option<&'static str> { Some(match req { @@ -565,15 +430,16 @@ impl<'a> ServiceGenerator<'a> { }) } - fn serve(self, ctx: tarpc::context::Context, req: #request_ident) -> Self::Fut { + async fn serve(self, ctx: tarpc::context::Context, req: #request_ident) + -> Result<#response_ident, tarpc::ServerError> { match req { #( #request_ident::#camel_case_idents{ #( #arg_pats ),* } => { - #response_fut_ident::#camel_case_idents( + Ok(#response_ident::#camel_case_idents( #service_ident::#method_idents( self.service, ctx, #( #arg_pats ),* - ) - ) + ).await + )) } )* } @@ -624,74 +490,6 @@ impl<'a> ServiceGenerator<'a> { } } - fn enum_response_future(&self) -> TokenStream2 { - let &Self { - vis, - service_ident, - response_fut_ident, - camel_case_idents, - future_types, - .. - } = self; - - quote! { - /// A future resolving to a server response. - #[allow(missing_docs)] - #vis enum #response_fut_ident { - #( #camel_case_idents(::#future_types) ),* - } - } - } - - fn impl_debug_for_response_future(&self) -> TokenStream2 { - let &Self { - service_ident, - response_fut_ident, - response_fut_name, - .. - } = self; - - quote! { - impl std::fmt::Debug for #response_fut_ident { - fn fmt(&self, fmt: &mut std::fmt::Formatter) -> std::fmt::Result { - fmt.debug_struct(#response_fut_name).finish() - } - } - } - } - - fn impl_future_for_response_future(&self) -> TokenStream2 { - let &Self { - service_ident, - response_fut_ident, - response_ident, - camel_case_idents, - .. - } = self; - - quote! { - impl std::future::Future for #response_fut_ident { - type Output = Result<#response_ident, tarpc::ServerError>; - - fn poll(self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>) - -> std::task::Poll> - { - unsafe { - match std::pin::Pin::get_unchecked_mut(self) { - #( - #response_fut_ident::#camel_case_idents(resp) => - std::pin::Pin::new_unchecked(resp) - .poll(cx) - .map(#response_ident::#camel_case_idents) - .map(Ok), - )* - } - } - } - } - } - } - fn struct_client(&self) -> TokenStream2 { let &Self { vis, @@ -804,9 +602,6 @@ impl<'a> ToTokens for ServiceGenerator<'a> { self.impl_serve_for_server(), self.enum_request(), self.enum_response(), - self.enum_response_future(), - self.impl_debug_for_response_future(), - self.impl_future_for_response_future(), self.struct_client(), self.impl_client_new(), self.impl_client_rpc_methods(), diff --git a/plugins/tests/server.rs b/plugins/tests/server.rs index f0222ffd3..7fcec793e 100644 --- a/plugins/tests/server.rs +++ b/plugins/tests/server.rs @@ -1,7 +1,5 @@ -use assert_type_eq::assert_type_eq; -use futures::Future; -use std::pin::Pin; -use tarpc::context; +#![allow(incomplete_features)] +#![feature(async_fn_in_trait)] // these need to be out here rather than inside the function so that the // assert_type_eq macro can pick them up. @@ -12,42 +10,6 @@ trait Foo { async fn baz(); } -#[test] -fn type_generation_works() { - #[tarpc::server] - impl Foo for () { - async fn two_part(self, _: context::Context, s: String, i: i32) -> (String, i32) { - (s, i) - } - - async fn bar(self, _: context::Context, s: String) -> String { - s - } - - async fn baz(self, _: context::Context) {} - } - - // the assert_type_eq macro can only be used once per block. - { - assert_type_eq!( - <() as Foo>::TwoPartFut, - Pin + Send>> - ); - } - { - assert_type_eq!( - <() as Foo>::BarFut, - Pin + Send>> - ); - } - { - assert_type_eq!( - <() as Foo>::BazFut, - Pin + Send>> - ); - } -} - #[allow(non_camel_case_types)] #[test] fn raw_idents_work() { @@ -59,24 +21,6 @@ fn raw_idents_work() { async fn r#fn(r#impl: r#yield) -> r#yield; async fn r#async(); } - - #[tarpc::server] - impl r#trait for () { - async fn r#await( - self, - _: context::Context, - r#struct: r#yield, - r#enum: i32, - ) -> (r#yield, i32) { - (r#struct, r#enum) - } - - async fn r#fn(self, _: context::Context, r#impl: r#yield) -> r#yield { - r#impl - } - - async fn r#async(self, _: context::Context) {} - } } #[test] @@ -100,45 +44,4 @@ fn syntax() { #[doc = "attr"] async fn one_arg_implicit_return_error(one: String); } - - #[tarpc::server] - impl Syntax for () { - #[deny(warnings)] - #[allow(non_snake_case)] - async fn TestCamelCaseDoesntConflict(self, _: context::Context) {} - - async fn hello(self, _: context::Context) -> String { - String::new() - } - - async fn attr(self, _: context::Context, _s: String) -> String { - String::new() - } - - async fn no_args_no_return(self, _: context::Context) {} - - async fn no_args(self, _: context::Context) -> () {} - - async fn one_arg(self, _: context::Context, _one: String) -> i32 { - 0 - } - - async fn two_args_no_return(self, _: context::Context, _one: String, _two: u64) {} - - async fn two_args(self, _: context::Context, _one: String, _two: u64) -> String { - String::new() - } - - async fn no_args_ret_error(self, _: context::Context) -> i32 { - 0 - } - - async fn one_arg_ret_error(self, _: context::Context, _one: String) -> String { - String::new() - } - - async fn no_arg_implicit_return_error(self, _: context::Context) {} - - async fn one_arg_implicit_return_error(self, _: context::Context, _one: String) {} - } } diff --git a/plugins/tests/service.rs b/plugins/tests/service.rs index b37cbcead..38bd7f0dc 100644 --- a/plugins/tests/service.rs +++ b/plugins/tests/service.rs @@ -1,9 +1,10 @@ +#![allow(incomplete_features)] +#![feature(async_fn_in_trait)] + use tarpc::context; #[test] fn att_service_trait() { - use futures::future::{ready, Ready}; - #[tarpc::service] trait Foo { async fn two_part(s: String, i: i32) -> (String, i32); @@ -12,19 +13,16 @@ fn att_service_trait() { } impl Foo for () { - type TwoPartFut = Ready<(String, i32)>; - fn two_part(self, _: context::Context, s: String, i: i32) -> Self::TwoPartFut { - ready((s, i)) + async fn two_part(self, _: context::Context, s: String, i: i32) -> (String, i32) { + (s, i) } - type BarFut = Ready; - fn bar(self, _: context::Context, s: String) -> Self::BarFut { - ready(s) + async fn bar(self, _: context::Context, s: String) -> String { + s } - type BazFut = Ready<()>; - fn baz(self, _: context::Context) -> Self::BazFut { - ready(()) + async fn baz(self, _: context::Context) { + () } } } @@ -32,8 +30,6 @@ fn att_service_trait() { #[allow(non_camel_case_types)] #[test] fn raw_idents() { - use futures::future::{ready, Ready}; - type r#yield = String; #[tarpc::service] @@ -44,19 +40,21 @@ fn raw_idents() { } impl r#trait for () { - type AwaitFut = Ready<(r#yield, i32)>; - fn r#await(self, _: context::Context, r#struct: r#yield, r#enum: i32) -> Self::AwaitFut { - ready((r#struct, r#enum)) + async fn r#await( + self, + _: context::Context, + r#struct: r#yield, + r#enum: i32, + ) -> (r#yield, i32) { + (r#struct, r#enum) } - type FnFut = Ready; - fn r#fn(self, _: context::Context, r#impl: r#yield) -> Self::FnFut { - ready(r#impl) + async fn r#fn(self, _: context::Context, r#impl: r#yield) -> r#yield { + r#impl } - type AsyncFut = Ready<()>; - fn r#async(self, _: context::Context) -> Self::AsyncFut { - ready(()) + async fn r#async(self, _: context::Context) { + () } } } diff --git a/tarpc/Cargo.toml b/tarpc/Cargo.toml index c6f80644e..87808776e 100644 --- a/tarpc/Cargo.toml +++ b/tarpc/Cargo.toml @@ -75,7 +75,8 @@ opentelemetry-jaeger = { version = "0.17.0", features = ["rt-tokio"] } pin-utils = "0.1.0-alpha" serde_bytes = "0.11" tracing-subscriber = { version = "0.3", features = ["env-filter"] } -tokio = { version = "1", features = ["full", "test-util"] } +tokio = { version = "1", features = ["full", "test-util", "tracing"] } +console-subscriber = "0.1" tokio-serde = { version = "0.8", features = ["json", "bincode"] } trybuild = "1.0" tokio-rustls = "0.23" diff --git a/tarpc/examples/compression.rs b/tarpc/examples/compression.rs index 942fdc8af..cc993f0af 100644 --- a/tarpc/examples/compression.rs +++ b/tarpc/examples/compression.rs @@ -1,5 +1,14 @@ +// Copyright 2022 Google LLC +// +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file or at +// https://opensource.org/licenses/MIT. + +#![allow(incomplete_features)] +#![feature(async_fn_in_trait)] + use flate2::{read::DeflateDecoder, write::DeflateEncoder, Compression}; -use futures::{Sink, SinkExt, Stream, StreamExt, TryStreamExt}; +use futures::{prelude::*, Sink, SinkExt, Stream, StreamExt, TryStreamExt}; use serde::{Deserialize, Serialize}; use serde_bytes::ByteBuf; use std::{io, io::Read, io::Write}; @@ -99,13 +108,16 @@ pub trait World { #[derive(Clone, Debug)] struct HelloServer; -#[tarpc::server] impl World for HelloServer { async fn hello(self, _: context::Context, name: String) -> String { format!("Hey, {name}!") } } +async fn spawn(fut: impl Future + Send + 'static) { + tokio::spawn(fut); +} + #[tokio::main] async fn main() -> anyhow::Result<()> { let mut incoming = tcp::listen("localhost:0", Bincode::default).await?; @@ -114,6 +126,7 @@ async fn main() -> anyhow::Result<()> { let transport = incoming.next().await.unwrap().unwrap(); BaseChannel::with_defaults(add_compression(transport)) .execute(HelloServer.serve()) + .for_each(spawn) .await; }); diff --git a/tarpc/examples/custom_transport.rs b/tarpc/examples/custom_transport.rs index e7e2ce3d5..2c5fd4dc4 100644 --- a/tarpc/examples/custom_transport.rs +++ b/tarpc/examples/custom_transport.rs @@ -1,3 +1,13 @@ +// Copyright 2022 Google LLC +// +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file or at +// https://opensource.org/licenses/MIT. + +#![allow(incomplete_features)] +#![feature(async_fn_in_trait)] + +use futures::prelude::*; use tarpc::context::Context; use tarpc::serde_transport as transport; use tarpc::server::{BaseChannel, Channel}; @@ -13,7 +23,6 @@ pub trait PingService { #[derive(Clone)] struct Service; -#[tarpc::server] impl PingService for Service { async fn ping(self, _: Context) {} } @@ -26,13 +35,18 @@ async fn main() -> anyhow::Result<()> { let listener = UnixListener::bind(bind_addr).unwrap(); let codec_builder = LengthDelimitedCodec::builder(); + async fn spawn(fut: impl Future + Send + 'static) { + tokio::spawn(fut); + } tokio::spawn(async move { loop { let (conn, _addr) = listener.accept().await.unwrap(); let framed = codec_builder.new_framed(conn); let transport = transport::new(framed, Bincode::default()); - let fut = BaseChannel::with_defaults(transport).execute(Service.serve()); + let fut = BaseChannel::with_defaults(transport) + .execute(Service.serve()) + .for_each(spawn); tokio::spawn(fut); } }); diff --git a/tarpc/examples/pubsub.rs b/tarpc/examples/pubsub.rs index 910ab535f..5b5b2eedb 100644 --- a/tarpc/examples/pubsub.rs +++ b/tarpc/examples/pubsub.rs @@ -4,6 +4,9 @@ // license that can be found in the LICENSE file or at // https://opensource.org/licenses/MIT. +#![allow(incomplete_features)] +#![feature(async_fn_in_trait)] + /// - The PubSub server sets up TCP listeners on 2 ports, the "subscriber" port and the "publisher" /// port. Because both publishers and subscribers initiate their connections to the PubSub /// server, the server requires no prior knowledge of either publishers or subscribers. @@ -79,7 +82,6 @@ struct Subscriber { topics: Vec, } -#[tarpc::server] impl subscriber::Subscriber for Subscriber { async fn topics(self, _: context::Context) -> Vec { self.topics.clone() @@ -117,7 +119,8 @@ impl Subscriber { )) } }; - let (handler, abort_handle) = future::abortable(handler.execute(subscriber.serve())); + let (handler, abort_handle) = + future::abortable(handler.execute(subscriber.serve()).for_each(spawn)); tokio::spawn(async move { match handler.await { Ok(()) | Err(future::Aborted) => info!(?local_addr, "subscriber shutdown."), @@ -143,6 +146,10 @@ struct PublisherAddrs { subscriptions: SocketAddr, } +async fn spawn(fut: impl Future + Send + 'static) { + tokio::spawn(fut); +} + impl Publisher { async fn start(self) -> io::Result { let mut connecting_publishers = tcp::listen("localhost:0", Json::default).await?; @@ -162,6 +169,7 @@ impl Publisher { server::BaseChannel::with_defaults(publisher) .execute(self.serve()) + .for_each(spawn) .await }); @@ -257,7 +265,6 @@ impl Publisher { } } -#[tarpc::server] impl publisher::Publisher for Publisher { async fn publish(self, _: context::Context, topic: String, message: String) { info!("received message to publish."); diff --git a/tarpc/examples/readme.rs b/tarpc/examples/readme.rs index 80792314f..c6ef61eb4 100644 --- a/tarpc/examples/readme.rs +++ b/tarpc/examples/readme.rs @@ -4,7 +4,10 @@ // license that can be found in the LICENSE file or at // https://opensource.org/licenses/MIT. -use futures::future::{self, Ready}; +#![allow(incomplete_features)] +#![feature(async_fn_in_trait)] + +use futures::prelude::*; use tarpc::{ client, context, server::{self, Channel}, @@ -23,22 +26,21 @@ pub trait World { struct HelloServer; impl World for HelloServer { - // Each defined rpc generates two items in the trait, a fn that serves the RPC, and - // an associated type representing the future output by the fn. - - type HelloFut = Ready; - - fn hello(self, _: context::Context, name: String) -> Self::HelloFut { - future::ready(format!("Hello, {name}!")) + async fn hello(self, _: context::Context, name: String) -> String { + format!("Hello, {name}!") } } +async fn spawn(fut: impl Future + Send + 'static) { + tokio::spawn(fut); +} + #[tokio::main] async fn main() -> anyhow::Result<()> { let (client_transport, server_transport) = tarpc::transport::channel::unbounded(); let server = server::BaseChannel::with_defaults(server_transport); - tokio::spawn(server.execute(HelloServer.serve())); + tokio::spawn(server.execute(HelloServer.serve()).for_each(spawn)); // WorldClient is generated by the #[tarpc::service] attribute. It has a constructor `new` // that takes a config and any Transport as input. diff --git a/tarpc/examples/tracing.rs b/tarpc/examples/tracing.rs index 589c16ffd..d37fbabea 100644 --- a/tarpc/examples/tracing.rs +++ b/tarpc/examples/tracing.rs @@ -4,7 +4,8 @@ // license that can be found in the LICENSE file or at // https://opensource.org/licenses/MIT. -#![feature(type_alias_impl_trait)] +#![allow(incomplete_features)] +#![feature(async_fn_in_trait, type_alias_impl_trait)] use crate::{ add::{Add as AddService, AddStub}, @@ -25,7 +26,10 @@ use tarpc::{ RpcError, }, context, serde_transport, - server::{incoming::Incoming, BaseChannel, Serve}, + server::{ + incoming::{spawn_incoming, Incoming}, + BaseChannel, Serve, + }, tokio_serde::formats::Json, ClientMessage, Response, ServerError, Transport, }; @@ -51,7 +55,6 @@ pub mod double { #[derive(Clone)] struct AddServer; -#[tarpc::server] impl AddService for AddServer { async fn add(self, _: context::Context, x: i32, y: i32) -> i32 { x + y @@ -63,7 +66,6 @@ struct DoubleServer { add_client: add::AddClient, } -#[tarpc::server] impl DoubleService for DoubleServer where Stub: AddStub + Clone + Send + Sync + 'static, @@ -158,9 +160,8 @@ async fn main() -> anyhow::Result<()> { }); let add_server = add_listener1 .chain(add_listener2) - .map(BaseChannel::with_defaults) - .execute(server); - tokio::spawn(add_server); + .map(BaseChannel::with_defaults); + tokio::spawn(spawn_incoming(add_server.execute(server))); let add_client = add::AddClient::from(make_stub([ tarpc::serde_transport::tcp::connect(addr1, Json::default).await?, @@ -171,11 +172,9 @@ async fn main() -> anyhow::Result<()> { .await? .filter_map(|r| future::ready(r.ok())); let addr = double_listener.get_ref().local_addr(); - let double_server = double_listener - .map(BaseChannel::with_defaults) - .take(1) - .execute(DoubleServer { add_client }.serve()); - tokio::spawn(double_server); + let double_server = double_listener.map(BaseChannel::with_defaults).take(1); + let server = DoubleServer { add_client }.serve(); + tokio::spawn(spawn_incoming(double_server.execute(server))); let to_double_server = tarpc::serde_transport::tcp::connect(addr, Json::default).await?; let double_client = diff --git a/tarpc/src/lib.rs b/tarpc/src/lib.rs index 280da694e..391c6fcaf 100644 --- a/tarpc/src/lib.rs +++ b/tarpc/src/lib.rs @@ -80,6 +80,8 @@ //! First, let's set up the dependencies and service definition. //! //! ```rust +//! #![allow(incomplete_features)] +//! #![feature(async_fn_in_trait)] //! # extern crate futures; //! //! use futures::{ @@ -104,6 +106,8 @@ //! implement it for our Server struct. //! //! ```rust +//! # #![allow(incomplete_features)] +//! # #![feature(async_fn_in_trait)] //! # extern crate futures; //! # use futures::{ //! # future::{self, Ready}, @@ -126,13 +130,9 @@ //! struct HelloServer; //! //! impl World for HelloServer { -//! // Each defined rpc generates two items in the trait, a fn that serves the RPC, and -//! // an associated type representing the future output by the fn. -//! -//! type HelloFut = Ready; -//! -//! fn hello(self, _: context::Context, name: String) -> Self::HelloFut { -//! future::ready(format!("Hello, {name}!")) +//! // Each defined rpc generates an async fn that serves the RPC +//! async fn hello(self, _: context::Context, name: String) -> String { +//! format!("Hello, {name}!") //! } //! } //! ``` @@ -143,6 +143,8 @@ //! available behind the `tcp` feature. //! //! ```rust +//! # #![allow(incomplete_features)] +//! # #![feature(async_fn_in_trait)] //! # extern crate futures; //! # use futures::{ //! # future::{self, Ready}, @@ -164,11 +166,9 @@ //! # #[derive(Clone)] //! # struct HelloServer; //! # impl World for HelloServer { -//! # // Each defined rpc generates two items in the trait, a fn that serves the RPC, and -//! # // an associated type representing the future output by the fn. -//! # type HelloFut = Ready; -//! # fn hello(self, _: context::Context, name: String) -> Self::HelloFut { -//! # future::ready(format!("Hello, {name}!")) +//! // Each defined rpc generates an async fn that serves the RPC +//! # async fn hello(self, _: context::Context, name: String) -> String { +//! # format!("Hello, {name}!") //! # } //! # } //! # #[cfg(not(feature = "tokio1"))] @@ -179,7 +179,12 @@ //! let (client_transport, server_transport) = tarpc::transport::channel::unbounded(); //! //! let server = server::BaseChannel::with_defaults(server_transport); -//! tokio::spawn(server.execute(HelloServer.serve())); +//! tokio::spawn( +//! server.execute(HelloServer.serve()) +//! // Handle all requests concurrently. +//! .for_each(|response| async move { +//! tokio::spawn(response); +//! })); //! //! // WorldClient is generated by the #[tarpc::service] attribute. It has a constructor `new` //! // that takes a config and any Transport as input. @@ -200,7 +205,14 @@ //! //! Use `cargo doc` as you normally would to see the documentation created for all //! items expanded by a `service!` invocation. -#![feature(type_alias_impl_trait)] +// For async_fn_in_trait +#![allow(incomplete_features)] +#![feature( + iter_intersperse, + type_alias_impl_trait, + async_fn_in_trait, + return_position_impl_trait_in_trait +)] #![deny(missing_docs)] #![allow(clippy::type_complexity)] #![cfg_attr(docsrs, feature(doc_cfg))] @@ -226,6 +238,7 @@ pub use tarpc_plugins::derive_serde; /// Rpc methods are specified, mirroring trait syntax: /// /// ``` +/// #![feature(async_fn_in_trait)] /// #[tarpc::service] /// trait Service { /// /// Say hello @@ -245,62 +258,6 @@ pub use tarpc_plugins::derive_serde; /// * `fn new_stub` -- creates a new Client stub. pub use tarpc_plugins::service; -/// A utility macro that can be used for RPC server implementations. -/// -/// Syntactic sugar to make using async functions in the server implementation -/// easier. It does this by rewriting code like this, which would normally not -/// compile because async functions are disallowed in trait implementations: -/// -/// ```rust -/// # use tarpc::context; -/// # use std::net::SocketAddr; -/// #[tarpc::service] -/// trait World { -/// async fn hello(name: String) -> String; -/// } -/// -/// #[derive(Clone)] -/// struct HelloServer(SocketAddr); -/// -/// #[tarpc::server] -/// impl World for HelloServer { -/// async fn hello(self, _: context::Context, name: String) -> String { -/// format!("Hello, {name}! You are connected from {:?}.", self.0) -/// } -/// } -/// ``` -/// -/// Into code like this, which matches the service trait definition: -/// -/// ```rust -/// # use tarpc::context; -/// # use std::pin::Pin; -/// # use futures::Future; -/// # use std::net::SocketAddr; -/// #[derive(Clone)] -/// struct HelloServer(SocketAddr); -/// -/// #[tarpc::service] -/// trait World { -/// async fn hello(name: String) -> String; -/// } -/// -/// impl World for HelloServer { -/// type HelloFut = Pin + Send>>; -/// -/// fn hello(self, _: context::Context, name: String) -> Pin -/// + Send>> { -/// Box::pin(async move { -/// format!("Hello, {name}! You are connected from {:?}.", self.0) -/// }) -/// } -/// } -/// ``` -/// -/// Note that this won't touch functions unless they have been annotated with -/// `async`, meaning that this should not break existing code. -pub use tarpc_plugins::server; - pub(crate) mod cancellations; pub mod client; pub mod context; diff --git a/tarpc/src/server.rs b/tarpc/src/server.rs index d5e5b5259..2df2ab527 100644 --- a/tarpc/src/server.rs +++ b/tarpc/src/server.rs @@ -35,11 +35,6 @@ pub mod limits; /// Provides helper methods for streams of Channels. pub mod incoming; -/// Provides convenience functionality for tokio-enabled applications. -#[cfg(feature = "tokio1")] -#[cfg_attr(docsrs, doc(cfg(feature = "tokio1")))] -pub mod tokio; - use request_hook::{ AfterRequest, AfterRequestHook, BeforeAndAfterRequestHook, BeforeRequest, BeforeRequestHook, }; @@ -79,11 +74,8 @@ pub trait Serve { /// Type of response. type Resp; - /// Type of response future. - type Fut: Future>; - /// Responds to a single request. - fn serve(self, ctx: context::Context, req: Self::Req) -> Self::Fut; + async fn serve(self, ctx: context::Context, req: Self::Req) -> Result; /// Extracts a method name from the request. fn method(&self, _request: &Self::Req) -> Option<&'static str> { @@ -274,10 +266,9 @@ where { type Req = Req; type Resp = Resp; - type Fut = Fut; - fn serve(self, ctx: context::Context, req: Req) -> Self::Fut { - (self.f)(ctx, req) + async fn serve(self, ctx: context::Context, req: Req) -> Result { + (self.f)(ctx, req).await } } @@ -533,34 +524,42 @@ where } } - /// Runs the channel until completion by executing all requests using the given service - /// function. Request handlers are run concurrently by [spawning](::tokio::spawn) on tokio's - /// default executor. + /// Returns a stream of request execution futures. Each future represents an in-flight request + /// being responded to by the server. The futures must be awaited or spawned to complete their + /// requests. /// /// # Example /// /// ```rust /// use tarpc::{context, client, server::{self, BaseChannel, Channel, serve}, transport}; /// use futures::prelude::*; + /// use tracing_subscriber::prelude::*; + /// + /// #[derive(PartialEq, Eq, Debug)] + /// struct MyInt(i32); /// + /// # #[cfg(not(feature = "tokio1"))] + /// # fn main() {} + /// # #[cfg(feature = "tokio1")] /// #[tokio::main] /// async fn main() { /// let (tx, rx) = transport::channel::unbounded(); /// let client = client::new(client::Config::default(), tx).spawn(); - /// let channel = BaseChannel::new(server::Config::default(), rx); - /// tokio::spawn(channel.execute(serve(|_, i| async move { Ok(i + 1) }))); - /// assert_eq!(client.call(context::current(), "AddOne", 1).await.unwrap(), 2); + /// let channel = BaseChannel::with_defaults(rx); + /// tokio::spawn( + /// channel.execute(serve(|_, MyInt(i)| async move { Ok(MyInt(i + 1)) })) + /// .for_each(|response| async move { + /// tokio::spawn(response); + /// })); + /// assert_eq!( + /// client.call(context::current(), "AddOne", MyInt(1)).await.unwrap(), + /// MyInt(2)); /// } /// ``` - #[cfg(feature = "tokio1")] - #[cfg_attr(docsrs, doc(cfg(feature = "tokio1")))] - fn execute(self, serve: S) -> self::tokio::TokioChannelExecutor, S> + fn execute(self, serve: S) -> impl Stream> where Self: Sized, - S: Serve + Send + 'static, - S::Fut: Send, - Self::Req: Send + 'static, - Self::Resp: Send + 'static, + S: Serve + Clone, { self.requests().execute(serve) } @@ -573,10 +572,10 @@ where E: Error + Send + Sync + 'static, { /// An error occurred reading from, or writing to, the transport. - #[error("an error occurred in the transport: {0}")] + #[error("an error occurred in the transport")] Transport(#[source] E), /// An error occurred while polling expired requests. - #[error("an error occurred while polling expired requests: {0}")] + #[error("an error occurred while polling expired requests")] Timer(#[source] ::tokio::time::error::Error), } @@ -668,15 +667,17 @@ where Poll::Pending => Pending, }; + let status = cancellation_status + .combine(expiration_status) + .combine(request_status); + tracing::trace!( - "Expired requests: {:?}, Inbound: {:?}", - expiration_status, - request_status + "Cancellations: {cancellation_status:?}, \ + Expired requests: {expiration_status:?}, \ + Inbound: {request_status:?}, \ + Overall: {status:?}", ); - match cancellation_status - .combine(expiration_status) - .combine(request_status) - { + match status { Ready => continue, Closed => return Poll::Ready(None), Pending => return Poll::Pending, @@ -882,6 +883,51 @@ where } Poll::Ready(Some(Ok(()))) } + + /// Returns a stream of request execution futures. Each future represents an in-flight request + /// being responded to by the server. The futures must be awaited or spawned to complete their + /// requests. + /// + /// If the channel encounters an error, the stream is terminated and the error is logged. + /// + /// # Example + /// + /// ```rust + /// use tarpc::{context, client, server::{self, BaseChannel, Channel, serve}, transport}; + /// use futures::prelude::*; + /// + /// # #[cfg(not(feature = "tokio1"))] + /// # fn main() {} + /// # #[cfg(feature = "tokio1")] + /// #[tokio::main] + /// async fn main() { + /// let (tx, rx) = transport::channel::unbounded(); + /// let requests = BaseChannel::new(server::Config::default(), rx).requests(); + /// let client = client::new(client::Config::default(), tx).spawn(); + /// tokio::spawn( + /// requests.execute(serve(|_, i| async move { Ok(i + 1) })) + /// .for_each(|response| async move { + /// tokio::spawn(response); + /// })); + /// assert_eq!(client.call(context::current(), "AddOne", 1).await.unwrap(), 2); + /// } + /// ``` + pub fn execute(self, serve: S) -> impl Stream> + where + S: Serve + Clone, + { + self.take_while(|result| { + if let Err(e) = result { + tracing::warn!("Requests stream errored out: {}", e); + } + futures::future::ready(result.is_ok()) + }) + .filter_map(|result| async move { result.ok() }) + .map(move |request| { + let serve = serve.clone(); + request.execute(serve) + }) + } } impl fmt::Debug for Requests @@ -1039,6 +1085,13 @@ impl InFlightRequest { } } +fn print_err(e: &(dyn Error + 'static)) -> String { + anyhow::Chain::new(e) + .map(|e| e.to_string()) + .intersperse(": ".into()) + .collect::() +} + impl Stream for Requests where C: Channel, @@ -1047,17 +1100,33 @@ where fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { loop { - let read = self.as_mut().pump_read(cx)?; + let read = self.as_mut().pump_read(cx).map_err(|e| { + tracing::trace!("read: {}", print_err(&e)); + e + })?; let read_closed = matches!(read, Poll::Ready(None)); - match (read, self.as_mut().pump_write(cx, read_closed)?) { + let write = self.as_mut().pump_write(cx, read_closed).map_err(|e| { + tracing::trace!("write: {}", print_err(&e)); + e + })?; + match (read, write) { (Poll::Ready(None), Poll::Ready(None)) => { + tracing::trace!("read: Poll::Ready(None), write: Poll::Ready(None)"); return Poll::Ready(None); } (Poll::Ready(Some(request_handler)), _) => { + tracing::trace!("read: Poll::Ready(Some), write: _"); return Poll::Ready(Some(Ok(request_handler))); } - (_, Poll::Ready(Some(()))) => {} - _ => { + (_, Poll::Ready(Some(()))) => { + tracing::trace!("read: _, write: Poll::Ready(Some)"); + } + (read @ Poll::Pending, write) | (read, write @ Poll::Pending) => { + tracing::trace!( + "read pending: {}, write pending: {}", + read.is_pending(), + write.is_pending() + ); return Poll::Pending; } } diff --git a/tarpc/src/server/incoming.rs b/tarpc/src/server/incoming.rs index 931e87669..9195ee301 100644 --- a/tarpc/src/server/incoming.rs +++ b/tarpc/src/server/incoming.rs @@ -1,13 +1,10 @@ use super::{ limits::{channels_per_key::MaxChannelsPerKey, requests_per_channel::MaxRequestsPerChannel}, - Channel, + Channel, Serve, }; use futures::prelude::*; use std::{fmt, hash::Hash}; -#[cfg(feature = "tokio1")] -use super::{tokio::TokioServerExecutor, Serve}; - /// An extension trait for [streams](futures::prelude::Stream) of [`Channels`](Channel). pub trait Incoming where @@ -28,16 +25,62 @@ where MaxRequestsPerChannel::new(self, n) } - /// [Executes](Channel::execute) each incoming channel. Each channel will be handled - /// concurrently by spawning on tokio's default executor, and each request will be also - /// be spawned on tokio's default executor. - #[cfg(feature = "tokio1")] - #[cfg_attr(docsrs, doc(cfg(feature = "tokio1")))] - fn execute(self, serve: S) -> TokioServerExecutor + /// Returns a stream of channels in execution. Each channel in execution is a stream of + /// futures, where each future is an in-flight request being rsponded to. + fn execute( + self, + serve: S, + ) -> impl Stream>> where - S: Serve, + S: Serve + Clone, { - TokioServerExecutor::new(self, serve) + self.map(move |channel| channel.execute(serve.clone())) + } +} + +#[cfg(feature = "tokio1")] +/// Spawns all channels-in-execution, delegating to the tokio runtime to manage their completion. +/// Each channel is spawned, and each request from each channel is spawned. +/// Note that this function is generic over any stream-of-streams-of-futures, but it is intended +/// for spawning streams of channels. +/// +/// # Example +/// ```rust +/// use tarpc::{ +/// context, +/// client::{self, NewClient}, +/// server::{self, BaseChannel, Channel, incoming::{Incoming, spawn_incoming}, serve}, +/// transport, +/// }; +/// use futures::prelude::*; +/// +/// #[tokio::main] +/// async fn main() { +/// let (tx, rx) = transport::channel::unbounded(); +/// let NewClient { client, dispatch } = client::new(client::Config::default(), tx); +/// tokio::spawn(dispatch); +/// +/// let incoming = stream::once(async move { +/// BaseChannel::new(server::Config::default(), rx) +/// }).execute(serve(|_, i| async move { Ok(i + 1) })); +/// tokio::spawn(spawn_incoming(incoming)); +/// assert_eq!(client.call(context::current(), "AddOne", 1).await.unwrap(), 2); +/// } +/// ``` +pub async fn spawn_incoming( + incoming: impl Stream< + Item = impl Stream + Send + 'static> + Send + 'static, + >, +) { + use futures::pin_mut; + pin_mut!(incoming); + while let Some(channel) = incoming.next().await { + tokio::spawn(async move { + pin_mut!(channel); + while let Some(request) = channel.next().await { + tokio::spawn(request); + } + }); } } diff --git a/tarpc/src/server/request_hook/after.rs b/tarpc/src/server/request_hook/after.rs index a3803bade..4fd48dd4b 100644 --- a/tarpc/src/server/request_hook/after.rs +++ b/tarpc/src/server/request_hook/after.rs @@ -71,19 +71,17 @@ where { type Req = Serv::Req; type Resp = Serv::Resp; - type Fut = AfterRequestHookFut; - fn serve(self, mut ctx: context::Context, req: Serv::Req) -> Self::Fut { - async move { - let AfterRequestHook { - serve, mut hook, .. - } = self; - let mut resp = serve.serve(ctx, req).await; - hook.after(&mut ctx, &mut resp).await; - resp - } + async fn serve( + self, + mut ctx: context::Context, + req: Serv::Req, + ) -> Result { + let AfterRequestHook { + serve, mut hook, .. + } = self; + let mut resp = serve.serve(ctx, req).await; + hook.after(&mut ctx, &mut resp).await; + resp } } - -type AfterRequestHookFut> = - impl Future>; diff --git a/tarpc/src/server/request_hook/before.rs b/tarpc/src/server/request_hook/before.rs index 38ad54d01..2c478dbb1 100644 --- a/tarpc/src/server/request_hook/before.rs +++ b/tarpc/src/server/request_hook/before.rs @@ -67,18 +67,16 @@ where { type Req = Serv::Req; type Resp = Serv::Resp; - type Fut = BeforeRequestHookFut; - fn serve(self, mut ctx: context::Context, req: Self::Req) -> Self::Fut { + async fn serve( + self, + mut ctx: context::Context, + req: Self::Req, + ) -> Result { let BeforeRequestHook { serve, mut hook, .. } = self; - async move { - hook.before(&mut ctx, &req).await?; - serve.serve(ctx, req).await - } + hook.before(&mut ctx, &req).await?; + serve.serve(ctx, req).await } } - -type BeforeRequestHookFut> = - impl Future>; diff --git a/tarpc/src/server/request_hook/before_and_after.rs b/tarpc/src/server/request_hook/before_and_after.rs index ca42460bc..ff61a53ea 100644 --- a/tarpc/src/server/request_hook/before_and_after.rs +++ b/tarpc/src/server/request_hook/before_and_after.rs @@ -8,7 +8,6 @@ use super::{after::AfterRequest, before::BeforeRequest}; use crate::{context, server::Serve, ServerError}; -use futures::prelude::*; use std::marker::PhantomData; /// A Service function that runs a hook both before and after request execution. @@ -47,24 +46,14 @@ where { type Req = Req; type Resp = Resp; - type Fut = BeforeAndAfterRequestHookFut; - fn serve(self, mut ctx: context::Context, req: Req) -> Self::Fut { - async move { - let BeforeAndAfterRequestHook { - serve, mut hook, .. - } = self; - hook.before(&mut ctx, &req).await?; - let mut resp = serve.serve(ctx, req).await; - hook.after(&mut ctx, &mut resp).await; - resp - } + async fn serve(self, mut ctx: context::Context, req: Req) -> Result { + let BeforeAndAfterRequestHook { + serve, mut hook, .. + } = self; + hook.before(&mut ctx, &req).await?; + let mut resp = serve.serve(ctx, req).await; + hook.after(&mut ctx, &mut resp).await; + resp } } - -type BeforeAndAfterRequestHookFut< - Req, - Resp, - Serv: Serve, - Hook: BeforeRequest + AfterRequest, -> = impl Future>; diff --git a/tarpc/src/server/tokio.rs b/tarpc/src/server/tokio.rs deleted file mode 100644 index e9ad84221..000000000 --- a/tarpc/src/server/tokio.rs +++ /dev/null @@ -1,129 +0,0 @@ -use super::{Channel, Requests, Serve}; -use futures::{prelude::*, ready, task::*}; -use pin_project::pin_project; -use std::pin::Pin; - -/// A future that drives the server by [spawning](tokio::spawn) a [`TokioChannelExecutor`](TokioChannelExecutor) -/// for each new channel. Returned by -/// [`Incoming::execute`](crate::server::incoming::Incoming::execute). -#[must_use] -#[pin_project] -#[derive(Debug)] -pub struct TokioServerExecutor { - #[pin] - inner: T, - serve: S, -} - -impl TokioServerExecutor { - pub(crate) fn new(inner: T, serve: S) -> Self { - Self { inner, serve } - } -} - -/// A future that drives the server by [spawning](tokio::spawn) each [response -/// handler](super::InFlightRequest::execute) on tokio's default executor. Returned by -/// [`Channel::execute`](crate::server::Channel::execute). -#[must_use] -#[pin_project] -#[derive(Debug)] -pub struct TokioChannelExecutor { - #[pin] - inner: T, - serve: S, -} - -impl TokioServerExecutor { - fn inner_pin_mut<'a>(self: &'a mut Pin<&mut Self>) -> Pin<&'a mut T> { - self.as_mut().project().inner - } -} - -impl TokioChannelExecutor { - fn inner_pin_mut<'a>(self: &'a mut Pin<&mut Self>) -> Pin<&'a mut T> { - self.as_mut().project().inner - } -} - -// Send + 'static execution helper methods. - -impl Requests -where - C: Channel, - C::Req: Send + 'static, - C::Resp: Send + 'static, -{ - /// Executes all requests using the given service function. Requests are handled concurrently - /// by [spawning](::tokio::spawn) each handler on tokio's default executor. - /// - /// # Example - /// - /// ```rust - /// use tarpc::{context, client, server::{self, BaseChannel, Channel, serve}, transport}; - /// use futures::prelude::*; - /// - /// #[tokio::main] - /// async fn main() { - /// let (tx, rx) = transport::channel::unbounded(); - /// let requests = BaseChannel::new(server::Config::default(), rx).requests(); - /// let client = client::new(client::Config::default(), tx).spawn(); - /// tokio::spawn(requests.execute(serve(|_, i| async move { Ok(i + 1) }))); - /// assert_eq!(client.call(context::current(), "AddOne", 1).await.unwrap(), 2); - /// } - /// ``` - pub fn execute(self, serve: S) -> TokioChannelExecutor - where - S: Serve + Send + 'static, - { - TokioChannelExecutor { inner: self, serve } - } -} - -impl Future for TokioServerExecutor -where - St: Sized + Stream, - C: Channel + Send + 'static, - C::Req: Send + 'static, - C::Resp: Send + 'static, - Se: Serve + Send + 'static + Clone, - Se::Fut: Send, -{ - type Output = (); - - fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> { - while let Some(channel) = ready!(self.inner_pin_mut().poll_next(cx)) { - tokio::spawn(channel.execute(self.serve.clone())); - } - tracing::info!("Server shutting down."); - Poll::Ready(()) - } -} - -impl Future for TokioChannelExecutor, S> -where - C: Channel + 'static, - C::Req: Send + 'static, - C::Resp: Send + 'static, - S: Serve + Send + 'static + Clone, - S::Fut: Send, -{ - type Output = (); - - fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - while let Some(response_handler) = ready!(self.inner_pin_mut().poll_next(cx)) { - match response_handler { - Ok(resp) => { - let server = self.serve.clone(); - tokio::spawn(async move { - resp.execute(server).await; - }); - } - Err(e) => { - tracing::warn!("Requests stream errored out: {}", e); - break; - } - } - } - Poll::Ready(()) - } -} diff --git a/tarpc/src/transport/channel.rs b/tarpc/src/transport/channel.rs index 7f3035d14..98ea0aac7 100644 --- a/tarpc/src/transport/channel.rs +++ b/tarpc/src/transport/channel.rs @@ -14,9 +14,15 @@ use tokio::sync::mpsc; /// Errors that occur in the sending or receiving of messages over a channel. #[derive(thiserror::Error, Debug)] pub enum ChannelError { - /// An error occurred sending over the channel. - #[error("an error occurred sending over the channel")] + /// An error occurred readying to send into the channel. + #[error("an error occurred readying to send into the channel")] + Ready(#[source] Box), + /// An error occurred sending into the channel. + #[error("an error occurred sending into the channel")] Send(#[source] Box), + /// An error occurred receiving from the channel. + #[error("an error occurred receiving from the channel")] + Receive(#[source] Box), } /// Returns two unbounded channel peers. Each [`Stream`] yields items sent through the other's @@ -48,7 +54,10 @@ impl Stream for UnboundedChannel { mut self: Pin<&mut Self>, cx: &mut Context<'_>, ) -> Poll>> { - self.rx.poll_recv(cx).map(|option| option.map(Ok)) + self.rx + .poll_recv(cx) + .map(|option| option.map(Ok)) + .map_err(ChannelError::Receive) } } @@ -59,7 +68,7 @@ impl Sink for UnboundedChannel { fn poll_ready(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll> { Poll::Ready(if self.tx.is_closed() { - Err(ChannelError::Send(CLOSED_MESSAGE.into())) + Err(ChannelError::Ready(CLOSED_MESSAGE.into())) } else { Ok(()) }) @@ -110,7 +119,11 @@ impl Stream for Channel { self: Pin<&mut Self>, cx: &mut Context<'_>, ) -> Poll>> { - self.project().rx.poll_next(cx).map(|option| option.map(Ok)) + self.project() + .rx + .poll_next(cx) + .map(|option| option.map(Ok)) + .map_err(ChannelError::Receive) } } @@ -121,7 +134,7 @@ impl Sink for Channel { self.project() .tx .poll_ready(cx) - .map_err(|e| ChannelError::Send(Box::new(e))) + .map_err(|e| ChannelError::Ready(Box::new(e))) } fn start_send(self: Pin<&mut Self>, item: SinkItem) -> Result<(), Self::Error> { @@ -146,8 +159,7 @@ impl Sink for Channel { } } -#[cfg(test)] -#[cfg(feature = "tokio1")] +#[cfg(all(test, feature = "tokio1"))] mod tests { use crate::{ client::{self, RpcError}, @@ -186,7 +198,10 @@ mod tests { format!("{request:?} is not an int"), ) }) - })), + })) + .for_each(|channel| async move { + tokio::spawn(channel.for_each(|response| response)); + }), ); let client = client::new(client::Config::default(), client_channel).spawn(); diff --git a/tarpc/tests/compile_fail.rs b/tarpc/tests/compile_fail.rs index 4c5a28ec9..c28fe2fa1 100644 --- a/tarpc/tests/compile_fail.rs +++ b/tarpc/tests/compile_fail.rs @@ -2,8 +2,6 @@ fn ui() { let t = trybuild::TestCases::new(); t.compile_fail("tests/compile_fail/*.rs"); - #[cfg(feature = "tokio1")] - t.compile_fail("tests/compile_fail/tokio/*.rs"); #[cfg(all(feature = "serde-transport", feature = "tcp"))] t.compile_fail("tests/compile_fail/serde_transport/*.rs"); } diff --git a/tarpc/tests/compile_fail/must_use_request_dispatch.rs b/tarpc/tests/compile_fail/must_use_request_dispatch.rs index 2915d3237..18cda0d90 100644 --- a/tarpc/tests/compile_fail/must_use_request_dispatch.rs +++ b/tarpc/tests/compile_fail/must_use_request_dispatch.rs @@ -1,3 +1,6 @@ +#![allow(incomplete_features)] +#![feature(async_fn_in_trait)] + use tarpc::client; #[tarpc::service] diff --git a/tarpc/tests/compile_fail/must_use_request_dispatch.stderr b/tarpc/tests/compile_fail/must_use_request_dispatch.stderr index f7aa3ea6c..d12912a86 100644 --- a/tarpc/tests/compile_fail/must_use_request_dispatch.stderr +++ b/tarpc/tests/compile_fail/must_use_request_dispatch.stderr @@ -1,11 +1,11 @@ error: unused `RequestDispatch` that must be used - --> tests/compile_fail/must_use_request_dispatch.rs:13:9 + --> tests/compile_fail/must_use_request_dispatch.rs:16:9 | -13 | WorldClient::new(client::Config::default(), client_transport).dispatch; +16 | WorldClient::new(client::Config::default(), client_transport).dispatch; | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ | note: the lint level is defined here - --> tests/compile_fail/must_use_request_dispatch.rs:11:12 + --> tests/compile_fail/must_use_request_dispatch.rs:14:12 | -11 | #[deny(unused_must_use)] +14 | #[deny(unused_must_use)] | ^^^^^^^^^^^^^^^ diff --git a/tarpc/tests/compile_fail/tarpc_server_missing_async.rs b/tarpc/tests/compile_fail/tarpc_server_missing_async.rs deleted file mode 100644 index 99d858b6d..000000000 --- a/tarpc/tests/compile_fail/tarpc_server_missing_async.rs +++ /dev/null @@ -1,15 +0,0 @@ -#[tarpc::service(derive_serde = false)] -trait World { - async fn hello(name: String) -> String; -} - -struct HelloServer; - -#[tarpc::server] -impl World for HelloServer { - fn hello(name: String) -> String { - format!("Hello, {name}!", name) - } -} - -fn main() {} diff --git a/tarpc/tests/compile_fail/tarpc_server_missing_async.stderr b/tarpc/tests/compile_fail/tarpc_server_missing_async.stderr deleted file mode 100644 index d96cda833..000000000 --- a/tarpc/tests/compile_fail/tarpc_server_missing_async.stderr +++ /dev/null @@ -1,15 +0,0 @@ -error: not all trait items implemented, missing: `HelloFut` - --> tests/compile_fail/tarpc_server_missing_async.rs:9:1 - | -9 | / impl World for HelloServer { -10 | | fn hello(name: String) -> String { -11 | | format!("Hello, {name}!", name) -12 | | } -13 | | } - | |_^ - -error: hint: `#[tarpc::server]` only rewrites async fns, and `fn hello` is not async - --> tests/compile_fail/tarpc_server_missing_async.rs:10:5 - | -10 | fn hello(name: String) -> String { - | ^^^^^^^^ diff --git a/tarpc/tests/compile_fail/tokio/must_use_channel_executor.rs b/tarpc/tests/compile_fail/tokio/must_use_channel_executor.rs deleted file mode 100644 index 6fc2f2bf3..000000000 --- a/tarpc/tests/compile_fail/tokio/must_use_channel_executor.rs +++ /dev/null @@ -1,29 +0,0 @@ -use tarpc::{ - context, - server::{self, Channel}, -}; - -#[tarpc::service] -trait World { - async fn hello(name: String) -> String; -} - -#[derive(Clone)] -struct HelloServer; - -#[tarpc::server] -impl World for HelloServer { - async fn hello(self, _: context::Context, name: String) -> String { - format!("Hello, {name}!") - } -} - -fn main() { - let (_, server_transport) = tarpc::transport::channel::unbounded(); - let server = server::BaseChannel::with_defaults(server_transport); - - #[deny(unused_must_use)] - { - server.execute(HelloServer.serve()); - } -} diff --git a/tarpc/tests/compile_fail/tokio/must_use_channel_executor.stderr b/tarpc/tests/compile_fail/tokio/must_use_channel_executor.stderr deleted file mode 100644 index 446f224f6..000000000 --- a/tarpc/tests/compile_fail/tokio/must_use_channel_executor.stderr +++ /dev/null @@ -1,11 +0,0 @@ -error: unused `TokioChannelExecutor` that must be used - --> tests/compile_fail/tokio/must_use_channel_executor.rs:27:9 - | -27 | server.execute(HelloServer.serve()); - | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - | -note: the lint level is defined here - --> tests/compile_fail/tokio/must_use_channel_executor.rs:25:12 - | -25 | #[deny(unused_must_use)] - | ^^^^^^^^^^^^^^^ diff --git a/tarpc/tests/compile_fail/tokio/must_use_server_executor.rs b/tarpc/tests/compile_fail/tokio/must_use_server_executor.rs deleted file mode 100644 index 950cf74e6..000000000 --- a/tarpc/tests/compile_fail/tokio/must_use_server_executor.rs +++ /dev/null @@ -1,30 +0,0 @@ -use futures::stream::once; -use tarpc::{ - context, - server::{self, incoming::Incoming}, -}; - -#[tarpc::service] -trait World { - async fn hello(name: String) -> String; -} - -#[derive(Clone)] -struct HelloServer; - -#[tarpc::server] -impl World for HelloServer { - async fn hello(self, _: context::Context, name: String) -> String { - format!("Hello, {name}!") - } -} - -fn main() { - let (_, server_transport) = tarpc::transport::channel::unbounded(); - let server = once(async move { server::BaseChannel::with_defaults(server_transport) }); - - #[deny(unused_must_use)] - { - server.execute(HelloServer.serve()); - } -} diff --git a/tarpc/tests/compile_fail/tokio/must_use_server_executor.stderr b/tarpc/tests/compile_fail/tokio/must_use_server_executor.stderr deleted file mode 100644 index 07d4b5a9b..000000000 --- a/tarpc/tests/compile_fail/tokio/must_use_server_executor.stderr +++ /dev/null @@ -1,11 +0,0 @@ -error: unused `TokioServerExecutor` that must be used - --> tests/compile_fail/tokio/must_use_server_executor.rs:28:9 - | -28 | server.execute(HelloServer.serve()); - | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - | -note: the lint level is defined here - --> tests/compile_fail/tokio/must_use_server_executor.rs:26:12 - | -26 | #[deny(unused_must_use)] - | ^^^^^^^^^^^^^^^ diff --git a/tarpc/tests/dataservice.rs b/tarpc/tests/dataservice.rs index 365594bd4..7cd3cb8c7 100644 --- a/tarpc/tests/dataservice.rs +++ b/tarpc/tests/dataservice.rs @@ -1,3 +1,6 @@ +#![allow(incomplete_features)] +#![feature(async_fn_in_trait)] + use futures::prelude::*; use tarpc::serde_transport; use tarpc::{ @@ -21,7 +24,6 @@ pub trait ColorProtocol { #[derive(Clone)] struct ColorServer; -#[tarpc::server] impl ColorProtocol for ColorServer { async fn get_opposite_color(self, _: context::Context, color: TestData) -> TestData { match color { @@ -31,6 +33,11 @@ impl ColorProtocol for ColorServer { } } +#[cfg(test)] +async fn spawn(fut: impl Future + Send + 'static) { + tokio::spawn(fut); +} + #[tokio::test] async fn test_call() -> anyhow::Result<()> { let transport = tarpc::serde_transport::tcp::listen("localhost:56797", Json::default).await?; @@ -40,7 +47,9 @@ async fn test_call() -> anyhow::Result<()> { .take(1) .filter_map(|r| async { r.ok() }) .map(BaseChannel::with_defaults) - .execute(ColorServer.serve()), + .execute(ColorServer.serve()) + .map(|channel| channel.for_each(spawn)) + .for_each(spawn), ); let transport = serde_transport::tcp::connect(addr, Json::default).await?; diff --git a/tarpc/tests/service_functional.rs b/tarpc/tests/service_functional.rs index 50d19b0e9..9041aae73 100644 --- a/tarpc/tests/service_functional.rs +++ b/tarpc/tests/service_functional.rs @@ -1,13 +1,16 @@ +#![allow(incomplete_features)] +#![feature(async_fn_in_trait)] + use assert_matches::assert_matches; use futures::{ - future::{join_all, ready, Ready}, + future::{join_all, ready}, prelude::*, }; use std::time::{Duration, SystemTime}; use tarpc::{ client::{self}, context, - server::{self, incoming::Incoming, BaseChannel, Channel}, + server::{incoming::Incoming, BaseChannel, Channel}, transport::channel, }; use tokio::join; @@ -22,39 +25,29 @@ trait Service { struct Server; impl Service for Server { - type AddFut = Ready; - - fn add(self, _: context::Context, x: i32, y: i32) -> Self::AddFut { - ready(x + y) + async fn add(self, _: context::Context, x: i32, y: i32) -> i32 { + x + y } - type HeyFut = Ready; - - fn hey(self, _: context::Context, name: String) -> Self::HeyFut { - ready(format!("Hey, {name}.")) + async fn hey(self, _: context::Context, name: String) -> String { + format!("Hey, {name}.") } } #[tokio::test] -async fn sequential() -> anyhow::Result<()> { - let _ = tracing_subscriber::fmt::try_init(); - - let (tx, rx) = channel::unbounded(); - +async fn sequential() { + let (tx, rx) = tarpc::transport::channel::unbounded(); + let client = client::new(client::Config::default(), tx).spawn(); + let channel = BaseChannel::with_defaults(rx); tokio::spawn( - BaseChannel::new(server::Config::default(), rx) - .requests() - .execute(Server.serve()), + channel + .execute(tarpc::server::serve(|_, i| async move { Ok(i + 1) })) + .for_each(|response| response), + ); + assert_eq!( + client.call(context::current(), "AddOne", 1).await.unwrap(), + 2 ); - - let client = ServiceClient::new(client::Config::default(), tx).spawn(); - - assert_matches!(client.add(context::current(), 1, 2).await, Ok(3)); - assert_matches!( - client.hey(context::current(), "Tim".into()).await, - Ok(ref s) if s == "Hey, Tim."); - - Ok(()) } #[tokio::test] @@ -70,7 +63,6 @@ async fn dropped_channel_aborts_in_flight_requests() -> anyhow::Result<()> { #[derive(Debug)] struct AllHandlersComplete; - #[tarpc::server] impl Loop for LoopServer { async fn r#loop(self, _: context::Context) { loop { @@ -121,7 +113,9 @@ async fn serde_tcp() -> anyhow::Result<()> { .take(1) .filter_map(|r| async { r.ok() }) .map(BaseChannel::with_defaults) - .execute(Server.serve()), + .execute(Server.serve()) + .map(|channel| channel.for_each(spawn)) + .for_each(spawn), ); let transport = serde_transport::tcp::connect(addr, Json::default).await?; @@ -151,7 +145,9 @@ async fn serde_uds() -> anyhow::Result<()> { .take(1) .filter_map(|r| async { r.ok() }) .map(BaseChannel::with_defaults) - .execute(Server.serve()), + .execute(Server.serve()) + .map(|channel| channel.for_each(spawn)) + .for_each(spawn), ); let transport = serde_transport::unix::connect(&sock, Json::default).await?; @@ -175,7 +171,9 @@ async fn concurrent() -> anyhow::Result<()> { tokio::spawn( stream::once(ready(rx)) .map(BaseChannel::with_defaults) - .execute(Server.serve()), + .execute(Server.serve()) + .map(|channel| channel.for_each(spawn)) + .for_each(spawn), ); let client = ServiceClient::new(client::Config::default(), tx).spawn(); @@ -199,7 +197,9 @@ async fn concurrent_join() -> anyhow::Result<()> { tokio::spawn( stream::once(ready(rx)) .map(BaseChannel::with_defaults) - .execute(Server.serve()), + .execute(Server.serve()) + .map(|channel| channel.for_each(spawn)) + .for_each(spawn), ); let client = ServiceClient::new(client::Config::default(), tx).spawn(); @@ -216,15 +216,20 @@ async fn concurrent_join() -> anyhow::Result<()> { Ok(()) } +#[cfg(test)] +async fn spawn(fut: impl Future + Send + 'static) { + tokio::spawn(fut); +} + #[tokio::test] async fn concurrent_join_all() -> anyhow::Result<()> { let _ = tracing_subscriber::fmt::try_init(); let (tx, rx) = channel::unbounded(); tokio::spawn( - stream::once(ready(rx)) - .map(BaseChannel::with_defaults) - .execute(Server.serve()), + BaseChannel::with_defaults(rx) + .execute(Server.serve()) + .for_each(spawn), ); let client = ServiceClient::new(client::Config::default(), tx).spawn(); @@ -249,11 +254,9 @@ async fn counter() -> anyhow::Result<()> { struct CountService(u32); impl Counter for &mut CountService { - type CountFut = futures::future::Ready; - - fn count(self, _: context::Context) -> Self::CountFut { + async fn count(self, _: context::Context) -> u32 { self.0 += 1; - futures::future::ready(self.0) + self.0 } } From 4d24dde3ec0e7f0c6e251ec66c4868eed745c360 Mon Sep 17 00:00:00 2001 From: Tim Kuehn Date: Wed, 23 Nov 2022 16:39:29 -0800 Subject: [PATCH 18/30] Replace actions-rs --- .github/workflows/main.yml | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index b0cc136c4..198475a23 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -18,7 +18,7 @@ jobs: with: access_token: ${{ github.token }} - uses: actions/checkout@v3 - - uses: dtolnay/rust-toolchain@stable + - uses: dtolnay/rust-toolchain@nightly with: targets: mipsel-unknown-linux-gnu - run: cargo check --all-features @@ -33,7 +33,7 @@ jobs: with: access_token: ${{ github.token }} - uses: actions/checkout@v3 - - uses: dtolnay/rust-toolchain@stable + - uses: dtolnay/rust-toolchain@nightly - run: cargo test - run: cargo test --manifest-path tarpc/Cargo.toml --features serde1 - run: cargo test --manifest-path tarpc/Cargo.toml --features tokio1 @@ -50,7 +50,7 @@ jobs: with: access_token: ${{ github.token }} - uses: actions/checkout@v3 - - uses: dtolnay/rust-toolchain@stable + - uses: dtolnay/rust-toolchain@nightly with: components: rustfmt - run: cargo fmt --all -- --check @@ -64,7 +64,7 @@ jobs: with: access_token: ${{ github.token }} - uses: actions/checkout@v3 - - uses: dtolnay/rust-toolchain@stable + - uses: dtolnay/rust-toolchain@nightly with: components: clippy - run: cargo clippy --all-features -- -D warnings From 403128103da7c8716b2eb36c7c10e8d2bcb27428 Mon Sep 17 00:00:00 2001 From: Tim Kuehn Date: Wed, 23 Nov 2022 17:45:43 -0800 Subject: [PATCH 19/30] Remove bad mem::forget usage. mem::forget is a dangerous tool, and it was being used carelessly for things that have safer alternatives. There was at least one bug where a cloned tokio::sync::mpsc::UnboundedSender used for request cancellation was being leaked on every successful server response, so its refcounts were never decremented. Because these are atomic refcounts, they'll wrap around rather than overflow when reaching the maximum value, so I don't believe this could lead to panics or unsoundness. --- tarpc/src/client.rs | 49 +++++++++++++++++++++++++++++++++++++-------- tarpc/src/server.rs | 2 +- 2 files changed, 42 insertions(+), 9 deletions(-) diff --git a/tarpc/src/client.rs b/tarpc/src/client.rs index fc376e4d6..6740873d2 100644 --- a/tarpc/src/client.rs +++ b/tarpc/src/client.rs @@ -18,6 +18,7 @@ use in_flight_requests::InFlightRequests; use pin_project::pin_project; use std::{ convert::TryFrom, + error::Error, fmt, pin::Pin, sync::{ @@ -166,7 +167,7 @@ impl Channel { /// A server response that is completed by request dispatch when the corresponding response /// arrives off the wire. struct ResponseGuard<'a, Resp> { - response: &'a mut oneshot::Receiver>, + response: &'a mut oneshot::Receiver, RpcError>>, cancellation: &'a RequestCancellation, request_id: u64, cancel: bool, @@ -174,7 +175,8 @@ struct ResponseGuard<'a, Resp> { /// An error that can occur in the processing of an RPC. This is not request-specific errors but /// rather cross-cutting errors that can always occur. -#[derive(thiserror::Error, Debug)] +#[derive(thiserror::Error, Clone, Debug, PartialEq, Eq, Hash)] +#[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))] pub enum RpcError { /// The client disconnected from the server. #[error("the connection to the server was already shutdown")] @@ -193,6 +195,12 @@ pub enum RpcError { Server(#[from] ServerError), } +impl From for RpcError { + fn from(_: DeadlineExceededError) -> Self { + RpcError::DeadlineExceeded + } +} + impl ResponseGuard<'_, Resp> { async fn response(mut self) -> Result { let response = (&mut self.response).await; @@ -241,6 +249,7 @@ where { let (to_dispatch, pending_requests) = mpsc::channel(config.pending_request_buffer); let (cancellation, canceled_requests) = cancellations(); + let canceled_requests = canceled_requests; NewClient { client: Channel { @@ -277,13 +286,37 @@ pub struct RequestDispatch { config: Config, } +/// Critical errors that result in a Channel disconnecting. +#[derive(thiserror::Error, Debug)] +pub enum ChannelError +where + E: Error + Send + Sync + 'static, +{ + /// Could not read from the transport. + #[error("could not read from the transport")] + Read(#[source] E), + /// Could not ready the transport for writes. + #[error("could not ready the transport for writes")] + Ready(#[source] E), + /// Could not write to the transport. + #[error("could not write to the transport")] + Write(#[source] E), + /// Could not flush the transport. + #[error("could not flush the transport")] + Flush(#[source] E), + /// Could not close the write end of the transport. + #[error("could not close the write end of the transport")] + Close(#[source] E), + /// Could not poll expired requests. + #[error("could not poll expired requests")] + Timer(#[source] tokio::time::error::Error), +} + impl RequestDispatch where C: Transport, Response>, { - fn in_flight_requests<'a>( - self: &'a mut Pin<&mut Self>, - ) -> &'a mut InFlightRequests> { + fn in_flight_requests<'a>(self: &'a mut Pin<&mut Self>) -> &'a mut InFlightRequests { self.as_mut().project().in_flight_requests } @@ -647,7 +680,7 @@ mod tests { .await .unwrap(); assert_matches!(dispatch.as_mut().poll(cx), Poll::Pending); - assert_matches!(rx.try_recv(), Ok(Ok(resp)) if resp == "Resp"); + assert_matches!(rx.try_recv(), Ok(Ok(Response { request_id: 0, message: Ok(resp) })) if resp == "Resp"); } #[tokio::test] @@ -999,8 +1032,8 @@ mod tests { async fn send_request<'a>( channel: &'a mut Channel, request: &str, - response_completion: oneshot::Sender>, - response: &'a mut oneshot::Receiver>, + response_completion: oneshot::Sender, RpcError>>, + response: &'a mut oneshot::Receiver, RpcError>>, ) -> ResponseGuard<'a, String> { let request_id = u64::try_from(channel.next_request_id.fetch_add(1, Ordering::Relaxed)).unwrap(); diff --git a/tarpc/src/server.rs b/tarpc/src/server.rs index 2df2ab527..43e82f4ce 100644 --- a/tarpc/src/server.rs +++ b/tarpc/src/server.rs @@ -1048,7 +1048,7 @@ impl InFlightRequest { { let Self { response_tx, - response_guard, + mut response_guard, abort_registration, span, request: From b500b727c58e6edda68dc3c909aa7f4167d8549d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C3=81kos=20Vandra-Meyer?= Date: Mon, 10 Apr 2023 14:28:06 +0200 Subject: [PATCH 20/30] fix rebase errors --- tarpc/examples/tls_over_tcp.rs | 6 ++-- tarpc/src/client.rs | 50 +++++----------------------------- tarpc/src/client/stub.rs | 5 +--- tarpc/src/lib.rs | 24 ++++++++++++++++ tarpc/src/server.rs | 28 ++++--------------- 5 files changed, 41 insertions(+), 72 deletions(-) diff --git a/tarpc/examples/tls_over_tcp.rs b/tarpc/examples/tls_over_tcp.rs index 92d76c989..7dd314799 100644 --- a/tarpc/examples/tls_over_tcp.rs +++ b/tarpc/examples/tls_over_tcp.rs @@ -1,9 +1,12 @@ +#![feature(async_fn_in_trait)] + use rustls_pemfile::certs; use std::io::{BufReader, Cursor}; use std::net::{IpAddr, Ipv4Addr}; use tokio_rustls::rustls::server::AllowAnyAuthenticatedClient; use std::sync::Arc; +use futures::StreamExt; use tokio::net::TcpListener; use tokio::net::TcpStream; use tokio_rustls::rustls::{self, Certificate, OwnedTrustAnchor, RootCertStore}; @@ -23,7 +26,6 @@ pub trait PingService { #[derive(Clone)] struct Service; -#[tarpc::server] impl PingService for Service { async fn ping(self, _: Context) -> String { "🔒".to_owned() @@ -103,7 +105,7 @@ async fn main() -> anyhow::Result<()> { let transport = transport::new(framed, Bincode::default()); let fut = BaseChannel::with_defaults(transport).execute(Service.serve()); - tokio::spawn(fut); + tokio::spawn(fut.into_future()); } }); diff --git a/tarpc/src/client.rs b/tarpc/src/client.rs index 6740873d2..aa90f458e 100644 --- a/tarpc/src/client.rs +++ b/tarpc/src/client.rs @@ -9,10 +9,7 @@ mod in_flight_requests; pub mod stub; -use crate::{ - cancellations::{cancellations, CanceledRequests, RequestCancellation}, - context, trace, ChannelError, ClientMessage, Request, Response, ServerError, Transport, -}; +use crate::{cancellations::{cancellations, CanceledRequests, RequestCancellation}, context, trace, ClientMessage, Request, Response, ServerError, Transport, ChannelError}; use futures::{prelude::*, ready, stream::Fuse, task::*}; use in_flight_requests::InFlightRequests; use pin_project::pin_project; @@ -167,7 +164,7 @@ impl Channel { /// A server response that is completed by request dispatch when the corresponding response /// arrives off the wire. struct ResponseGuard<'a, Resp> { - response: &'a mut oneshot::Receiver, RpcError>>, + response: &'a mut oneshot::Receiver>, cancellation: &'a RequestCancellation, request_id: u64, cancel: bool, @@ -175,8 +172,7 @@ struct ResponseGuard<'a, Resp> { /// An error that can occur in the processing of an RPC. This is not request-specific errors but /// rather cross-cutting errors that can always occur. -#[derive(thiserror::Error, Clone, Debug, PartialEq, Eq, Hash)] -#[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))] +#[derive(thiserror::Error, Debug)] pub enum RpcError { /// The client disconnected from the server. #[error("the connection to the server was already shutdown")] @@ -195,12 +191,6 @@ pub enum RpcError { Server(#[from] ServerError), } -impl From for RpcError { - fn from(_: DeadlineExceededError) -> Self { - RpcError::DeadlineExceeded - } -} - impl ResponseGuard<'_, Resp> { async fn response(mut self) -> Result { let response = (&mut self.response).await; @@ -286,37 +276,11 @@ pub struct RequestDispatch { config: Config, } -/// Critical errors that result in a Channel disconnecting. -#[derive(thiserror::Error, Debug)] -pub enum ChannelError -where - E: Error + Send + Sync + 'static, -{ - /// Could not read from the transport. - #[error("could not read from the transport")] - Read(#[source] E), - /// Could not ready the transport for writes. - #[error("could not ready the transport for writes")] - Ready(#[source] E), - /// Could not write to the transport. - #[error("could not write to the transport")] - Write(#[source] E), - /// Could not flush the transport. - #[error("could not flush the transport")] - Flush(#[source] E), - /// Could not close the write end of the transport. - #[error("could not close the write end of the transport")] - Close(#[source] E), - /// Could not poll expired requests. - #[error("could not poll expired requests")] - Timer(#[source] tokio::time::error::Error), -} - impl RequestDispatch where C: Transport, Response>, { - fn in_flight_requests<'a>(self: &'a mut Pin<&mut Self>) -> &'a mut InFlightRequests { + fn in_flight_requests<'a>(self: &'a mut Pin<&mut Self>) -> &'a mut InFlightRequests> { self.as_mut().project().in_flight_requests } @@ -680,7 +644,7 @@ mod tests { .await .unwrap(); assert_matches!(dispatch.as_mut().poll(cx), Poll::Pending); - assert_matches!(rx.try_recv(), Ok(Ok(Response { request_id: 0, message: Ok(resp) })) if resp == "Resp"); + assert_matches!(rx.try_recv(), Ok(Ok(resp)) if resp == "Resp"); } #[tokio::test] @@ -1032,8 +996,8 @@ mod tests { async fn send_request<'a>( channel: &'a mut Channel, request: &str, - response_completion: oneshot::Sender, RpcError>>, - response: &'a mut oneshot::Receiver, RpcError>>, + response_completion: oneshot::Sender>, + response: &'a mut oneshot::Receiver>, ) -> ResponseGuard<'a, String> { let request_id = u64::try_from(channel.next_request_id.fetch_add(1, Ordering::Relaxed)).unwrap(); diff --git a/tarpc/src/client/stub.rs b/tarpc/src/client/stub.rs index a8b72a20f..64b0df637 100644 --- a/tarpc/src/client/stub.rs +++ b/tarpc/src/client/stub.rs @@ -1,9 +1,6 @@ //! Provides a Stub trait, implemented by types that can call remote services. -use crate::{ - client::{Channel, RpcError}, - context, -}; +use crate::{client::{Channel, RpcError}, context}; use futures::prelude::*; pub mod load_balance; diff --git a/tarpc/src/lib.rs b/tarpc/src/lib.rs index 391c6fcaf..f7efeada3 100644 --- a/tarpc/src/lib.rs +++ b/tarpc/src/lib.rs @@ -270,6 +270,7 @@ pub use crate::transport::sealed::Transport; use anyhow::Context as _; use futures::task::*; use std::{error::Error, fmt::Display, io, time::SystemTime}; +use std::sync::Arc; /// A message from a client to a server. #[derive(Debug)] @@ -348,6 +349,29 @@ impl ServerError { } } +/// Critical errors that result in a Channel disconnecting. +#[derive(thiserror::Error, Debug, PartialEq, Eq)] +pub enum ChannelError + where + E: Error + Send + Sync + 'static, +{ + /// Could not read from the transport. + #[error("could not read from the transport")] + Read(#[source] Arc), + /// Could not ready the transport for writes. + #[error("could not ready the transport for writes")] + Ready(#[source] E), + /// Could not write to the transport. + #[error("could not write to the transport")] + Write(#[source] E), + /// Could not flush the transport. + #[error("could not flush the transport")] + Flush(#[source] E), + /// Could not close the write end of the transport. + #[error("could not close the write end of the transport")] + Close(#[source] E), +} + impl Request { /// Returns the deadline for this request. pub fn deadline(&self) -> &SystemTime { diff --git a/tarpc/src/server.rs b/tarpc/src/server.rs index 43e82f4ce..01d0642c7 100644 --- a/tarpc/src/server.rs +++ b/tarpc/src/server.rs @@ -6,11 +6,7 @@ //! Provides a server that concurrently handles many connections sending multiplexed requests. -use crate::{ - cancellations::{cancellations, CanceledRequests, RequestCancellation}, - context::{self, SpanExt}, - trace, ClientMessage, Request, Response, ServerError, Transport, -}; +use crate::{cancellations::{cancellations, CanceledRequests, RequestCancellation}, context::{self, SpanExt}, trace, ClientMessage, Request, Response, ServerError, Transport, ChannelError}; use ::tokio::sync::mpsc; use futures::{ future::{AbortRegistration, Abortable}, @@ -565,20 +561,6 @@ where } } -/// Critical errors that result in a Channel disconnecting. -#[derive(thiserror::Error, Debug)] -pub enum ChannelError -where - E: Error + Send + Sync + 'static, -{ - /// An error occurred reading from, or writing to, the transport. - #[error("an error occurred in the transport")] - Transport(#[source] E), - /// An error occurred while polling expired requests. - #[error("an error occurred while polling expired requests")] - Timer(#[source] ::tokio::time::error::Error), -} - impl Stream for BaseChannel where T: Transport, ClientMessage>, @@ -635,7 +617,7 @@ where let request_status = match self .transport_pin_mut() .poll_next(cx) - .map_err(ChannelError::Transport)? + .map_err(|e| ChannelError::Read(Arc::new(e)))? { Poll::Ready(Some(message)) => match message { ClientMessage::Request(request) => { @@ -729,7 +711,7 @@ where self.project() .transport .poll_close(cx) - .map_err(ChannelError::Transport) + .map_err(ChannelError::Close) } } @@ -979,7 +961,7 @@ impl InFlightRequest { pub async fn respond(self, response: Result) { let Self { response_tx, - response_guard, + mut response_guard, request: Request { id: request_id, .. }, span, .. @@ -995,7 +977,7 @@ impl InFlightRequest { // Request processing has completed, meaning either the channel canceled the request or // a request was sent back to the channel. Either way, the channel will clean up the // request data, so the request does not need to be canceled. - mem::forget(response_guard); + response_guard.cancel = true } /// Returns a [future](Future) that executes the request using the given [service From 404c65718a8d3fc95f88fbc70b3414373f8f99ea Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C3=81kos=20Vandra-Meyer?= Date: Mon, 10 Apr 2023 19:28:30 +0200 Subject: [PATCH 21/30] fix --- tarpc/src/client.rs | 40 +++++++++----------------- tarpc/src/client/stub.rs | 7 ----- tarpc/src/lib.rs | 1 - tarpc/src/server.rs | 11 ++----- tarpc/src/server/base_channel.rs | 10 +++---- tarpc/src/server/contextual_channel.rs | 19 ++++++------ 6 files changed, 29 insertions(+), 59 deletions(-) diff --git a/tarpc/src/client.rs b/tarpc/src/client.rs index fcbe52066..5bfca6e76 100644 --- a/tarpc/src/client.rs +++ b/tarpc/src/client.rs @@ -9,23 +9,14 @@ mod in_flight_requests; pub mod stub; -use crate::{ - cancellations::{cancellations, CanceledRequests, RequestCancellation}, - context, trace, ClientMessage, Request, Response, ServerError, Transport, -}; +use crate::{cancellations::{cancellations, CanceledRequests, RequestCancellation}, context, trace, ClientMessage, Request, Response, ServerError, Transport, ChannelError}; use futures::{prelude::*, ready, stream::Fuse, task::*}; -use in_flight_requests::{DeadlineExceededError, InFlightRequests}; +use in_flight_requests::InFlightRequests; use pin_project::pin_project; -use std::{ - convert::TryFrom, - error::Error, - fmt, mem, - pin::Pin, - sync::{ - atomic::{AtomicUsize, Ordering}, - Arc, - }, -}; +use std::{convert::TryFrom, fmt, pin::Pin, sync::{ + atomic::{AtomicUsize, Ordering}, + Arc, +}}; use tokio::sync::{mpsc, oneshot}; use tracing::Span; @@ -644,22 +635,17 @@ mod tests { }; use assert_matches::assert_matches; use futures::{prelude::*, task::*}; - use std::{ - convert::TryFrom, - fmt::Display, - marker::PhantomData, - pin::Pin, - sync::{ - atomic::{AtomicUsize, Ordering}, - Arc, - }, - }; + use std::{fmt::Display, marker::PhantomData, pin::Pin, sync::{ + atomic::{AtomicUsize}, + Arc, + }}; use thiserror::Error; use tokio::sync::{ mpsc::{self}, oneshot, }; use tracing::Span; + use crate::client::DefaultSequencer; #[tokio::test] async fn response_completes_request_future() { @@ -926,7 +912,7 @@ mod tests { let channel = Channel { to_dispatch, cancellation, - next_request_id: Arc::new(AtomicUsize::new(0)), + request_sequencer: Arc::new(DefaultSequencer::default()) }; let cx = Context::from_waker(noop_waker_ref()); (dispatch, channel, cx) @@ -1067,7 +1053,7 @@ mod tests { impl PollTest for Poll>> where - E: fmt::Display, + E: Display, { type T = Option; diff --git a/tarpc/src/client/stub.rs b/tarpc/src/client/stub.rs index 0c4075f40..64b0df637 100644 --- a/tarpc/src/client/stub.rs +++ b/tarpc/src/client/stub.rs @@ -1,13 +1,6 @@ //! Provides a Stub trait, implemented by types that can call remote services. -<<<<<<< HEAD -use crate::{ - client::{Channel, RpcError}, - context, -}; -======= use crate::{client::{Channel, RpcError}, context}; ->>>>>>> tikuemaster use futures::prelude::*; pub mod load_balance; diff --git a/tarpc/src/lib.rs b/tarpc/src/lib.rs index 1d537de47..c7ed9c5c7 100644 --- a/tarpc/src/lib.rs +++ b/tarpc/src/lib.rs @@ -312,7 +312,6 @@ pub use tarpc_plugins::service; /// /// Note that this won't touch functions unless they have been annotated with /// `async`, meaning that this should not break existing code. -pub use tarpc_plugins::server; pub(crate) mod cancellations; pub mod client; diff --git a/tarpc/src/server.rs b/tarpc/src/server.rs index d20f877e6..21d19c3ad 100644 --- a/tarpc/src/server.rs +++ b/tarpc/src/server.rs @@ -6,11 +6,7 @@ //! Provides a server that concurrently handles many connections sending multiplexed requests. -use crate::{ - cancellations::{cancellations, CanceledRequests, RequestCancellation}, - context::{self, SpanExt}, - trace, ClientMessage, Request, Response, ServerError, Transport, -}; +use crate::{cancellations::{cancellations, CanceledRequests, RequestCancellation}, context::{self, SpanExt}, trace, ClientMessage, Request, Response, ServerError, Transport, ChannelError}; use ::tokio::sync::mpsc; use futures::{ future::{AbortRegistration, Abortable}, @@ -280,7 +276,6 @@ where { type Req = Req; type Resp = Resp; - type Fut = Fut; async fn serve(self, ctx: context::Context, req: Req) -> Result { (self.f)(ctx, req).await @@ -794,14 +789,14 @@ where /// Returns the inner channel over which messages are sent and received. pub fn pending_responses_mut<'a>( self: &'a mut Pin<&mut Self>, - ) -> &'a mut mpsc::Receiver> { + ) -> &'a mut mpsc::Receiver<((), Response)> { self.as_mut().project().pending_responses } fn pump_read( mut self: Pin<&mut Self>, cx: &mut Context<'_>, - ) -> Poll, C::Error>>> { + ) -> Poll, C::Error>>> { self.channel_pin_mut().poll_next(cx).map_ok( |TrackedRequest { request, diff --git a/tarpc/src/server/base_channel.rs b/tarpc/src/server/base_channel.rs index 157576480..c179aff1c 100644 --- a/tarpc/src/server/base_channel.rs +++ b/tarpc/src/server/base_channel.rs @@ -198,7 +198,7 @@ impl Stream for BaseChannel let request_status = match self .transport_pin_mut() .poll_next(cx) - .map_err(ChannelError::Transport)? + .map_err(|e| ChannelError::Read(Arc::new(e)))? { Poll::Ready(Some(message)) => match message { ClientMessage::Request(request) => { @@ -260,7 +260,7 @@ impl Sink> for BaseChannel self.project() .transport .poll_ready(cx) - .map_err(ChannelError::Transport) + .map_err(ChannelError::Ready) } fn start_send(mut self: Pin<&mut Self>, response: Response) -> Result<(), Self::Error> { @@ -273,7 +273,7 @@ impl Sink> for BaseChannel self.project() .transport .start_send(response) - .map_err(ChannelError::Transport) + .map_err(ChannelError::Write) } else { // If the request isn't tracked anymore, there's no need to send the response. Ok(()) @@ -285,14 +285,14 @@ impl Sink> for BaseChannel self.project() .transport .poll_flush(cx) - .map_err(ChannelError::Transport) + .map_err(ChannelError::Flush) } fn poll_close(self: Pin<&mut Self>, cx: &mut Context) -> Poll> { self.project() .transport .poll_close(cx) - .map_err(ChannelError::Transport) + .map_err(ChannelError::Close) } } diff --git a/tarpc/src/server/contextual_channel.rs b/tarpc/src/server/contextual_channel.rs index a3ae37be0..478f5b38e 100644 --- a/tarpc/src/server/contextual_channel.rs +++ b/tarpc/src/server/contextual_channel.rs @@ -1,8 +1,4 @@ -use crate::{ - cancellations::{cancellations, CanceledRequests, RequestCancellation}, - context::{SpanExt}, - trace, ClientMessage, Request, Response, Transport, -}; +use crate::{cancellations::{cancellations, CanceledRequests, RequestCancellation}, context::{SpanExt}, trace, ClientMessage, Request, Response, Transport, ChannelError}; use futures::{ prelude::*, stream::Fuse, @@ -11,8 +7,9 @@ use futures::{ use super::in_flight_requests::{AlreadyExistsError, InFlightRequests}; use pin_project::pin_project; use std::{convert::TryFrom, error::Error, fmt, marker::PhantomData, pin::Pin}; +use std::sync::Arc; use tracing::{info_span}; -use crate::server::{Channel, ChannelError, Config, ResponseGuard, TrackedRequest}; +use crate::server::{Channel, Config, ResponseGuard, TrackedRequest}; /// BaseChannel is the standard implementation of a [`Channel`]. /// @@ -199,7 +196,7 @@ impl Stream for ContextualChannel let request_status = match self .transport_pin_mut() .poll_next(cx) - .map_err(ChannelError::Transport)? + .map_err(|e| ChannelError::Read(Arc::new(e)))? { Poll::Ready(Some(message)) => match message { (ctx, ClientMessage::Request(request)) => { @@ -261,7 +258,7 @@ impl Sink> for ContextualChannel, response: Response) -> Result<(), Self::Error> { @@ -274,7 +271,7 @@ impl Sink> for ContextualChannel Sink> for ContextualChannel, cx: &mut Context) -> Poll> { self.project() .transport .poll_close(cx) - .map_err(ChannelError::Transport) + .map_err(ChannelError::Close) } } From f198cfb97b4d057d7726aa2f8a14ce50e8ba340e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C3=81kos=20Vandra-Meyer?= Date: Fri, 5 May 2023 00:05:43 +0200 Subject: [PATCH 22/30] wip --- example-service/src/server.rs | 2 +- plugins/src/lib.rs | 4 +- tarpc/Cargo.toml | 1 + tarpc/src/client.rs | 11 +- tarpc/src/client/stub/retry.rs | 2 +- tarpc/src/context.rs | 86 ++++- tarpc/src/lib.rs | 13 +- tarpc/src/server.rs | 68 ++-- tarpc/src/server/base_channel.rs | 15 +- tarpc/src/server/contextual_channel.rs | 323 ------------------ .../src/server/limits/requests_per_channel.rs | 6 +- tarpc/src/server/request_hook/after.rs | 4 +- tarpc/src/server/request_hook/before.rs | 4 +- .../server/request_hook/before_and_after.rs | 6 +- tarpc/src/server/testing.rs | 2 +- 15 files changed, 139 insertions(+), 408 deletions(-) delete mode 100644 tarpc/src/server/contextual_channel.rs diff --git a/example-service/src/server.rs b/example-service/src/server.rs index 6c78598be..d3f1a2600 100644 --- a/example-service/src/server.rs +++ b/example-service/src/server.rs @@ -38,7 +38,7 @@ struct Flags { struct HelloServer(SocketAddr); impl World for HelloServer { - async fn hello(self, _: context::Context, name: String) -> String { + async fn hello(self, _: &mut context::Context, name: String) -> String { let sleep_time = Duration::from_millis(Uniform::new_inclusive(1, 10).sample(&mut thread_rng())); time::sleep(sleep_time).await; diff --git a/plugins/src/lib.rs b/plugins/src/lib.rs index f33cea09e..154900174 100644 --- a/plugins/src/lib.rs +++ b/plugins/src/lib.rs @@ -357,7 +357,7 @@ impl<'a> ServiceGenerator<'a> { )| { quote! { #( #attrs )* - async fn #ident(self, context: tarpc::context::Context, #( #args ),*) -> #output; + async fn #ident(self, context: &mut tarpc::context::Context, #( #args ),*) -> #output; } }, ); @@ -430,7 +430,7 @@ impl<'a> ServiceGenerator<'a> { }) } - async fn serve(self, ctx: tarpc::context::Context, req: #request_ident) + async fn serve(self, ctx: &mut tarpc::context::Context, req: #request_ident) -> Result<#response_ident, tarpc::ServerError> { match req { #( diff --git a/tarpc/Cargo.toml b/tarpc/Cargo.toml index 87808776e..49472ce41 100644 --- a/tarpc/Cargo.toml +++ b/tarpc/Cargo.toml @@ -42,6 +42,7 @@ travis-ci = { repository = "google/tarpc" } [dependencies] anyhow = "1.0" +anymap = "0.12.1" fnv = "1.0" futures = "0.3" humantime = "2.0" diff --git a/tarpc/src/client.rs b/tarpc/src/client.rs index 5bfca6e76..24423a423 100644 --- a/tarpc/src/client.rs +++ b/tarpc/src/client.rs @@ -140,7 +140,7 @@ impl Channel { skip(self, ctx, request_name, request), fields( rpc.trace_id = tracing::field::Empty, - rpc.deadline = %humantime::format_rfc3339(ctx.deadline), + rpc.deadline = %humantime::format_rfc3339(*ctx.deadline), otel.kind = "client", otel.name = request_name) )] @@ -525,12 +525,9 @@ where // buffer. let request_id = request_id; let request = ClientMessage::Request(Request { - id: request_id, + request_id: request_id, message: request, - context: context::Context { - deadline: ctx.deadline, - trace_context: ctx.trace_context, - }, + context: ctx.clone(), }); self.in_flight_requests() @@ -557,7 +554,7 @@ where let _entered = span.enter(); let cancel = ClientMessage::Cancel { - trace_context: context.trace_context, + context: context.trace_context, request_id, }; self.start_send(cancel)?; diff --git a/tarpc/src/client/stub/retry.rs b/tarpc/src/client/stub/retry.rs index 46ad09685..23cc41a69 100644 --- a/tarpc/src/client/stub/retry.rs +++ b/tarpc/src/client/stub/retry.rs @@ -62,7 +62,7 @@ where for i in 1.. { let result = self .stub - .call(ctx, request_name, Arc::clone(&request)) + .call(ctx.clone(), request_name, Arc::clone(&request)) .await; if (self.should_retry)(&result, i) { tracing::trace!("Retrying on attempt {i}"); diff --git a/tarpc/src/context.rs b/tarpc/src/context.rs index e3a6aff19..0b76fb0ef 100644 --- a/tarpc/src/context.rs +++ b/tarpc/src/context.rs @@ -14,6 +14,9 @@ use std::{ convert::TryFrom, time::{Duration, SystemTime}, }; +use std::hash::{Hash, Hasher}; +use std::ops::{Deref, DerefMut}; +use anymap::any::CloneAny; use tracing_opentelemetry::OpenTelemetrySpanExt; /// A request context that carries request-scoped information like deadlines and trace information. @@ -21,44 +24,63 @@ use tracing_opentelemetry::OpenTelemetrySpanExt; /// /// The context should not be stored directly in a server implementation, because the context will /// be different for each request in scope. -#[derive(Clone, Copy, Debug)] +#[derive(Clone, Debug, Default)] #[non_exhaustive] #[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))] pub struct Context { /// When the client expects the request to be complete by. The server should cancel the request /// if it is not complete by this time. - #[cfg_attr(feature = "serde1", serde(default = "ten_seconds_from_now"))] // Serialized as a Duration to prevent clock skew issues. #[cfg_attr(feature = "serde1", serde(with = "absolute_to_relative_time"))] - pub deadline: SystemTime, + pub deadline: Deadline, /// Uniquely identifies requests originating from the same source. /// When a service handles a request by making requests itself, those requests should /// include the same `trace_id` as that included on the original request. This way, /// users can trace related actions across a distributed system. pub trace_context: trace::Context, + + /// Any extra information can be requested + #[cfg_attr(feature = "serde1", serde(skip))] + pub extensions: Extensions +} + +impl PartialEq for Context { + fn eq(&self, other: &Self) -> bool { + self.trace_context == other.trace_context && self.deadline == other.deadline + } +} + +impl Eq for Context { } + +impl Hash for Context { + fn hash(&self, state: &mut H) { + self.trace_context.hash(state); + self.deadline.hash(state); + } } #[cfg(feature = "serde1")] mod absolute_to_relative_time { pub use serde::{Deserialize, Deserializer, Serialize, Serializer}; pub use std::time::{Duration, SystemTime}; + use crate::context::Deadline; - pub fn serialize(deadline: &SystemTime, serializer: S) -> Result + pub fn serialize(deadline: &Deadline, serializer: S) -> Result where S: Serializer, { - let deadline = deadline + let deadline = deadline.0 .duration_since(SystemTime::now()) .unwrap_or(Duration::ZERO); deadline.serialize(serializer) } - pub fn deserialize<'de, D>(deserializer: D) -> Result + pub fn deserialize<'de, D>(deserializer: D) -> Result where D: Deserializer<'de>, { let deadline = Duration::deserialize(deserializer)?; - Ok(SystemTime::now() + deadline) + Ok(Deadline(SystemTime::now() + deadline)) } #[cfg(test)] @@ -88,8 +110,8 @@ mod absolute_to_relative_time { assert_impl_all!(Context: Send, Sync); -fn ten_seconds_from_now() -> SystemTime { - SystemTime::now() + Duration::from_secs(10) +fn ten_seconds_from_now() -> Deadline { + Deadline(SystemTime::now() + Duration::from_secs(10)) } /// Returns the context for the current request, or a default Context if no request is active. @@ -97,12 +119,45 @@ pub fn current() -> Context { Context::current() } -#[derive(Clone)] -struct Deadline(SystemTime); +/// Deadline for executing a request +#[derive(Clone, Debug, PartialEq, Eq, Hash)] +pub struct Deadline(pub(self) SystemTime); impl Default for Deadline { fn default() -> Self { - Self(ten_seconds_from_now()) + ten_seconds_from_now() + } +} + +impl Deref for Deadline { + type Target = SystemTime; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +/// Extensions associated with a request +#[derive(Clone, Debug)] +pub struct Extensions(anymap::Map); + +impl Default for Extensions { + fn default() -> Self { + Self(anymap::Map::new()) + } +} + +impl Deref for Extensions { + type Target = anymap::Map; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +impl DerefMut for Extensions { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.0 } } @@ -110,15 +165,16 @@ impl Context { /// Returns the context for the current request, or a default Context if no request is active. pub fn current() -> Self { let span = tracing::Span::current(); + Self { trace_context: trace::Context::try_from(&span) .unwrap_or_else(|_| trace::Context::default()), + extensions: Default::default(), // span is always cloned so saving this doesn't make sense. deadline: span .context() .get::() .cloned() - .unwrap_or_default() - .0, + .unwrap_or_default(), } } @@ -146,7 +202,7 @@ impl SpanExt for tracing::Span { true, opentelemetry::trace::TraceState::default(), )) - .with_value(Deadline(context.deadline)), + .with_value(context.deadline.clone()) ); } } diff --git a/tarpc/src/lib.rs b/tarpc/src/lib.rs index c7ed9c5c7..a823c1629 100644 --- a/tarpc/src/lib.rs +++ b/tarpc/src/lib.rs @@ -347,30 +347,33 @@ pub enum ClientMessage { /// The trace context associates the message with a specific chain of causally-related actions, /// possibly orchestrated across many distributed systems. #[cfg_attr(feature = "serde1", serde(default))] - trace_context: trace::Context, + context: trace::Context, /// The ID of the request to cancel. request_id: u64, }, } /// A request from a client to a server. -#[derive(Clone, Copy, Debug)] +#[derive(Clone, Debug)] #[non_exhaustive] #[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))] pub struct Request { /// Trace context, deadline, and other cross-cutting concerns. pub context: context::Context, /// Uniquely identifies the request across all requests sent over a single channel. - pub id: u64, + pub request_id: u64, /// The request body. pub message: T, } -/// A response from a server to a client. +/// A response from a server to a client.c #[derive(Clone, Debug, PartialEq, Eq, Hash)] #[non_exhaustive] #[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))] pub struct Response { + /// Trace context, deadline, and other cross-cutting concerns. + #[cfg_attr(feature = "serde1", serde(skip))] + pub context: context::Context, /// The ID of the request being responded to. pub request_id: u64, /// The response body, or an error if the request failed. @@ -411,7 +414,7 @@ pub enum ChannelError E: Error + Send + Sync + 'static, { /// Could not read from the transport. - #[error("could not read from the transport")] + #[error("could not read from the transport: {0}")] Read(#[source] Arc), /// Could not ready the transport for writes. #[error("could not ready the transport for writes")] diff --git a/tarpc/src/server.rs b/tarpc/src/server.rs index 21d19c3ad..7294dbb20 100644 --- a/tarpc/src/server.rs +++ b/tarpc/src/server.rs @@ -19,6 +19,7 @@ use in_flight_requests::{AlreadyExistsError, InFlightRequests}; use pin_project::pin_project; use std::{convert::TryFrom, error::Error, fmt, marker::PhantomData, pin::Pin, sync::Arc}; use tracing::{info_span, instrument::Instrument, Span}; +use tracing_opentelemetry::OpenTelemetrySpanExt; mod in_flight_requests; pub mod request_hook; @@ -27,11 +28,9 @@ mod testing; /// Provides functionality to apply server limits. pub mod limits; -// mod base_channel; -mod contextual_channel; +mod base_channel; -// pub use base_channel::*; -pub use contextual_channel::*; +pub use base_channel::*; /// Provides helper methods for streams of Channels. pub mod incoming; @@ -65,15 +64,6 @@ impl Config { { BaseChannel::new(self, transport) } - - /// Returns a contextual channel backed by `transport` and configured with `self`. - pub fn contextual_channel(self, transport: T) -> ContextualChannel - where - T: Transport<(C, Response), (C, ClientMessage)>, - { - ContextualChannel::new(self, transport) - } - } /// Equivalent to a `FnOnce(Req) -> impl Future`. @@ -85,7 +75,7 @@ pub trait Serve { type Resp; /// Responds to a single request. - async fn serve(self, ctx: context::Context, req: Self::Req) -> Result; + async fn serve(self, ctx: &mut context::Context, req: Self::Req) -> Result; /// Extracts a method name from the request. fn method(&self, _request: &Self::Req) -> Option<&'static str> { @@ -271,13 +261,13 @@ where impl Serve for ServeFn where - F: FnOnce(context::Context, Req) -> Fut, + F: FnOnce(&mut context::Context, Req) -> Fut, Fut: Future>, { type Req = Req; type Resp = Resp; - async fn serve(self, ctx: context::Context, req: Req) -> Result { + async fn serve(self, ctx: &mut context::Context, req: Req) -> Result { (self.f)(ctx, req).await } } @@ -362,7 +352,7 @@ where let span = info_span!( "RPC", rpc.trace_id = %request.context.trace_id(), - rpc.deadline = %humantime::format_rfc3339(request.context.deadline), + rpc.deadline = %humantime::format_rfc3339(*request.context.deadline), otel.kind = "server", otel.name = tracing::field::Empty, ); @@ -377,8 +367,8 @@ where let entered = span.enter(); tracing::info!("ReceiveRequest"); let start = self.in_flight_requests_mut().start_request( - request.id, - request.context.deadline, + request.request_id, + *request.context.deadline, (), span.clone(), ); @@ -389,7 +379,7 @@ where abort_registration, span, response_guard: ResponseGuard { - request_id: request.id, + request_id: request.request_id, request_cancellation: self.request_cancellation.clone(), cancel: false, }, @@ -648,12 +638,12 @@ where } } ClientMessage::Cancel { - trace_context, + context, request_id, } => { if !self.in_flight_requests_mut().cancel_request(request_id) { tracing::trace!( - rpc.trace_id = %trace_context.trace_id, + rpc.trace_id = %context.trace_id, "Received cancellation, but response handler is already complete.", ); } @@ -703,6 +693,7 @@ where .remove_request(response.request_id) { let _entered = span.enter(); + tracing::error!("RSPAN = {:?}", span.metadata()); tracing::info!("SendResponse"); self.project() .transport @@ -1030,9 +1021,9 @@ impl InFlightRequest { span, request: Request { - context, + mut context, message, - id: request_id, + request_id: request_id, }, } = self; let method = serve.method(&message); @@ -1042,9 +1033,10 @@ impl InFlightRequest { span.record("otel.name", &method.unwrap_or("")); let _ = Abortable::new( async move { - let message = serve.serve(context, message).await; + let message = serve.serve(&mut context, message).await; tracing::info!("CompleteRequest"); let response = Response { + context, request_id, message, }; @@ -1184,7 +1176,7 @@ mod tests { fn fake_request(req: Req) -> ClientMessage { ClientMessage::Request(Request { context: context::current(), - id: 0, + request_id: 0, message: req, }) } @@ -1297,14 +1289,14 @@ mod tests { channel .as_mut() .start_request(Request { - id: 0, + request_id: 0, context: context::current(), message: (), }) .unwrap(); assert_matches!( channel.as_mut().start_request(Request { - id: 0, + request_id: 0, context: context::current(), message: () }), @@ -1320,7 +1312,7 @@ mod tests { let req0 = channel .as_mut() .start_request(Request { - id: 0, + request_id: 0, context: context::current(), message: (), }) @@ -1328,7 +1320,7 @@ mod tests { let req1 = channel .as_mut() .start_request(Request { - id: 1, + request_id: 1, context: context::current(), message: (), }) @@ -1351,7 +1343,7 @@ mod tests { let req = channel .as_mut() .start_request(Request { - id: 0, + request_id: 0, context: context::current(), message: (), }) @@ -1380,7 +1372,7 @@ mod tests { let _abort_registration = channel .as_mut() .start_request(Request { - id: 0, + request_id: 0, context: context::current(), message: (), }) @@ -1422,7 +1414,7 @@ mod tests { let req = channel .as_mut() .start_request(Request { - id: 0, + request_id: 0, context: context::current(), message: (), }) @@ -1445,7 +1437,7 @@ mod tests { channel .as_mut() .start_request(Request { - id: 0, + request_id: 0, context: context::current(), message: (), }) @@ -1508,7 +1500,7 @@ mod tests { .as_mut() .channel_pin_mut() .start_request(Request { - id: 0, + request_id: 0, context: context::current(), message: (), }) @@ -1538,7 +1530,7 @@ mod tests { .as_mut() .channel_pin_mut() .start_request(Request { - id: 1, + request_id: 1, context: context::current(), message: (), }) @@ -1559,7 +1551,7 @@ mod tests { .as_mut() .channel_pin_mut() .start_request(Request { - id: 0, + request_id: 0, context: context::current(), message: (), }) @@ -1578,7 +1570,7 @@ mod tests { .as_mut() .channel_pin_mut() .start_request(Request { - id: 1, + request_id: 1, context: context::current(), message: (), }) diff --git a/tarpc/src/server/base_channel.rs b/tarpc/src/server/base_channel.rs index c179aff1c..f5ae4c136 100644 --- a/tarpc/src/server/base_channel.rs +++ b/tarpc/src/server/base_channel.rs @@ -11,7 +11,10 @@ use futures::{ use super::in_flight_requests::{AlreadyExistsError, InFlightRequests}; use pin_project::pin_project; use std::{convert::TryFrom, error::Error, fmt, marker::PhantomData, pin::Pin}; +use std::sync::Arc; +use opentelemetry::trace::TraceContextExt; use tracing::{info_span}; +use tracing_opentelemetry::OpenTelemetrySpanExt; use crate::server::{Channel, ChannelError, Config, ResponseGuard, TrackedRequest}; /// BaseChannel is the standard implementation of a [`Channel`]. @@ -94,7 +97,7 @@ impl BaseChannel let span = info_span!( "RPC", rpc.trace_id = %request.context.trace_id(), - rpc.deadline = %humantime::format_rfc3339(request.context.deadline), + rpc.deadline = %humantime::format_rfc3339(*request.context.deadline), otel.kind = "server", otel.name = tracing::field::Empty, ); @@ -109,8 +112,8 @@ impl BaseChannel let entered = span.enter(); tracing::info!("ReceiveRequest"); let start = self.in_flight_requests_mut().start_request( - request.id, - request.context.deadline, + request.request_id, + *request.context.deadline, (), span.clone(), ); @@ -121,7 +124,7 @@ impl BaseChannel abort_registration, span, response_guard: ResponseGuard { - request_id: request.id, + request_id: request.request_id, request_cancellation: self.request_cancellation.clone(), cancel: false, }, @@ -214,12 +217,12 @@ impl Stream for BaseChannel } } ClientMessage::Cancel { - trace_context, + context, request_id, } => { if !self.in_flight_requests_mut().cancel_request(request_id) { tracing::trace!( - rpc.trace_id = %trace_context.trace_id, + rpc.trace_id = %context.trace_id, "Received cancellation, but response handler is already complete.", ); } diff --git a/tarpc/src/server/contextual_channel.rs b/tarpc/src/server/contextual_channel.rs deleted file mode 100644 index 478f5b38e..000000000 --- a/tarpc/src/server/contextual_channel.rs +++ /dev/null @@ -1,323 +0,0 @@ -use crate::{cancellations::{cancellations, CanceledRequests, RequestCancellation}, context::{SpanExt}, trace, ClientMessage, Request, Response, Transport, ChannelError}; -use futures::{ - prelude::*, - stream::Fuse, - task::*, -}; -use super::in_flight_requests::{AlreadyExistsError, InFlightRequests}; -use pin_project::pin_project; -use std::{convert::TryFrom, error::Error, fmt, marker::PhantomData, pin::Pin}; -use std::sync::Arc; -use tracing::{info_span}; -use crate::server::{Channel, Config, ResponseGuard, TrackedRequest}; - -/// BaseChannel is the standard implementation of a [`Channel`]. -/// -/// BaseChannel manages a [`Transport`](Transport) of client [`messages`](ClientMessage) and -/// implements a [`Stream`] of [requests](TrackedRequest). See the [`Channel`] documentation for -/// how to use channels. -/// -/// Besides requests, the other type of client message handled by `BaseChannel` is [cancellation -/// messages](ClientMessage::Cancel). `BaseChannel` does not allow direct access to cancellation -/// messages. Instead, it internally handles them by cancelling corresponding requests (removing -/// the corresponding in-flight requests and aborting their handlers). -#[pin_project] -pub struct ContextualChannel { - config: Config, - /// Writes responses to the wire and reads requests off the wire. - #[pin] - transport: Fuse, - /// In-flight requests that were dropped by the server before completion. - #[pin] - canceled_requests: CanceledRequests, - /// Notifies `canceled_requests` when a request is canceled. - request_cancellation: RequestCancellation, - /// Holds data necessary to clean up in-flight requests. - in_flight_requests: InFlightRequests, - /// Types the request and response. - ghost: PhantomData<(fn() -> Req, fn(Resp))>, -} - -impl ContextualChannel - where - T: Transport<(C, Response), (C, ClientMessage)>, -{ - /// Creates a new channel backed by `transport` and configured with `config`. - pub fn new(config: Config, transport: T) -> Self { - let (request_cancellation, canceled_requests) = cancellations(); - ContextualChannel { - config, - transport: transport.fuse(), - canceled_requests, - request_cancellation, - in_flight_requests: InFlightRequests::default(), - ghost: PhantomData, - } - } - - /// Creates a new channel backed by `transport` and configured with the defaults. - pub fn with_defaults(transport: T) -> Self { - Self::new(Config::default(), transport) - } - - /// Returns the inner transport over which messages are sent and received. - pub fn get_ref(&self) -> &T { - self.transport.get_ref() - } - - /// Returns the inner transport over which messages are sent and received. - pub fn get_pin_ref(self: Pin<&mut Self>) -> Pin<&mut T> { - self.project().transport.get_pin_mut() - } - - fn in_flight_requests_mut<'a>(self: &'a mut Pin<&mut Self>) -> &'a mut InFlightRequests { - self.as_mut().project().in_flight_requests - } - - fn canceled_requests_pin_mut<'a>( - self: &'a mut Pin<&mut Self>, - ) -> Pin<&'a mut CanceledRequests> { - self.as_mut().project().canceled_requests - } - - fn transport_pin_mut<'a>(self: &'a mut Pin<&mut Self>) -> Pin<&'a mut Fuse> { - self.as_mut().project().transport - } - - pub(super) fn start_request( - mut self: Pin<&mut Self>, - request: (C, Request), - ) -> Result, AlreadyExistsError> { - let (context, mut request) = request; - let span = info_span!( - "RPC", - rpc.trace_id = %request.context.trace_id(), - rpc.deadline = %humantime::format_rfc3339(request.context.deadline), - otel.kind = "server", - otel.name = tracing::field::Empty, - ); - span.set_context(&request.context); - request.context.trace_context = trace::Context::try_from(&span).unwrap_or_else(|_| { - tracing::trace!( - "OpenTelemetry subscriber not installed; making unsampled \ - child context." - ); - request.context.trace_context.new_child() - }); - let entered = span.enter(); - tracing::info!("ReceiveRequest"); - let start = self.in_flight_requests_mut().start_request( - request.id, - request.context.deadline, - context, - span.clone(), - ); - match start { - Ok(abort_registration) => { - drop(entered); - Ok(TrackedRequest { - abort_registration, - span, - response_guard: ResponseGuard { - request_id: request.id, - request_cancellation: self.request_cancellation.clone(), - cancel: false, - }, - request, - }) - } - Err(AlreadyExistsError) => { - tracing::trace!("DuplicateRequest"); - Err(AlreadyExistsError) - } - } - } -} - -impl fmt::Debug for ContextualChannel { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(f, "BaseChannel") - } -} - -impl Stream for ContextualChannel - where - T: Transport<(C, Response), (C, ClientMessage)>, -{ - type Item = Result, ChannelError>; - - fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll> { - #[derive(Clone, Copy, Debug)] - enum ReceiverStatus { - Ready, - Pending, - Closed, - } - - impl ReceiverStatus { - fn combine(self, other: Self) -> Self { - use ReceiverStatus::*; - match (self, other) { - (Ready, _) | (_, Ready) => Ready, - (Closed, Closed) => Closed, - (Pending, Closed) | (Closed, Pending) | (Pending, Pending) => Pending, - } - } - } - - use ReceiverStatus::*; - - loop { - let cancellation_status = match self.canceled_requests_pin_mut().poll_recv(cx) { - Poll::Ready(Some(request_id)) => { - if let Some((_ctx, span)) = self.in_flight_requests_mut().remove_request(request_id) { - let _entered = span.enter(); - tracing::info!("ResponseCancelled"); - } - Ready - } - // Pending cancellations don't block Channel closure, because all they do is ensure - // the Channel's internal state is cleaned up. But Channel closure also cleans up - // the Channel state, so there's no reason to wait on a cancellation before - // closing. - // - // Ready(None) can't happen, since `self` holds a Cancellation. - Poll::Pending | Poll::Ready(None) => Closed, - }; - - let expiration_status = match self.in_flight_requests_mut().poll_expired(cx) { - // No need to send a response, since the client wouldn't be waiting for one - // anymore. - Poll::Ready(Some(_)) => Ready, - Poll::Ready(None) => Closed, - Poll::Pending => Pending, - }; - - let request_status = match self - .transport_pin_mut() - .poll_next(cx) - .map_err(|e| ChannelError::Read(Arc::new(e)))? - { - Poll::Ready(Some(message)) => match message { - (ctx, ClientMessage::Request(request)) => { - match self.as_mut().start_request((ctx,request)) { - Ok(request) => return Poll::Ready(Some(Ok(request))), - Err(AlreadyExistsError) => { - // Instead of closing the channel if a duplicate request is sent, - // just ignore it, since it's already being processed. Note that we - // cannot return Poll::Pending here, since nothing has scheduled a - // wakeup yet. - continue; - } - } - } - (_ctx, ClientMessage::Cancel { - trace_context, - request_id, - }) => { - if !self.in_flight_requests_mut().cancel_request(request_id) { - tracing::trace!( - rpc.trace_id = %trace_context.trace_id, - "Received cancellation, but response handler is already complete.", - ); - } - Ready - } - }, - Poll::Ready(None) => Closed, - Poll::Pending => Pending, - }; - - let status = cancellation_status - .combine(expiration_status) - .combine(request_status); - - tracing::trace!( - "Cancellations: {cancellation_status:?}, \ - Expired requests: {expiration_status:?}, \ - Inbound: {request_status:?}, \ - Overall: {status:?}", - ); - match status { - Ready => continue, - Closed => return Poll::Ready(None), - Pending => return Poll::Pending, - } - } - } -} - -impl Sink> for ContextualChannel - where - T: Transport<(C, Response), (C, ClientMessage)>, - T::Error: Error, -{ - type Error = ChannelError; - - fn poll_ready(self: Pin<&mut Self>, cx: &mut Context) -> Poll> { - self.project() - .transport - .poll_ready(cx) - .map_err(ChannelError::Ready) - } - - fn start_send(mut self: Pin<&mut Self>, response: Response) -> Result<(), Self::Error> { - if let Some((ctx, span)) = self - .in_flight_requests_mut() - .remove_request(response.request_id) - { - let _entered = span.enter(); - tracing::info!("SendResponse"); - self.project() - .transport - .start_send((ctx, response)) - .map_err(ChannelError::Write) - } else { - // If the request isn't tracked anymore, there's no need to send the response. - Ok(()) - } - } - - fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll> { - tracing::trace!("poll_flush"); - self.project() - .transport - .poll_flush(cx) - .map_err(ChannelError::Flush) - } - - fn poll_close(self: Pin<&mut Self>, cx: &mut Context) -> Poll> { - self.project() - .transport - .poll_close(cx) - .map_err(ChannelError::Close) - } -} - -impl AsRef for ContextualChannel { - fn as_ref(&self) -> &T { - self.transport.get_ref() - } -} - -implChannel for ContextualChannel - where - T: Transport<(C, Response), (C, ClientMessage)>, -{ - type Req = Req; - type Resp = Resp; - type Transport = T; - - - fn config(&self) -> &Config { - &self.config - } - - fn in_flight_requests(&self) -> usize { - self.in_flight_requests.len() - } - - fn transport(&self) -> &Self::Transport { - self.get_ref() - } -} diff --git a/tarpc/src/server/limits/requests_per_channel.rs b/tarpc/src/server/limits/requests_per_channel.rs index 3f668878c..211ff9700 100644 --- a/tarpc/src/server/limits/requests_per_channel.rs +++ b/tarpc/src/server/limits/requests_per_channel.rs @@ -60,13 +60,15 @@ where match ready!(self.as_mut().project().inner.poll_next(cx)?) { Some(r) => { let _entered = r.span.enter(); + tracing::info!( in_flight_requests = self.as_mut().in_flight_requests(), "ThrottleRequest", ); self.as_mut().start_send(Response { - request_id: r.request.id, + context: r.request.context, + request_id: r.request.request_id, message: Err(ServerError { kind: io::ErrorKind::WouldBlock, detail: "server throttled the request.".into(), @@ -237,7 +239,7 @@ mod tests { throttler .as_mut() .poll_next(&mut testing::cx())? - .map(|r| r.map(|r| (r.request.id, r.request.message))), + .map(|r| r.map(|r| (r.request.request_id, r.request.message))), Poll::Ready(Some((0, 1))) ); Ok(()) diff --git a/tarpc/src/server/request_hook/after.rs b/tarpc/src/server/request_hook/after.rs index 4fd48dd4b..4c108a641 100644 --- a/tarpc/src/server/request_hook/after.rs +++ b/tarpc/src/server/request_hook/after.rs @@ -74,14 +74,14 @@ where async fn serve( self, - mut ctx: context::Context, + ctx: &mut context::Context, req: Serv::Req, ) -> Result { let AfterRequestHook { serve, mut hook, .. } = self; let mut resp = serve.serve(ctx, req).await; - hook.after(&mut ctx, &mut resp).await; + hook.after(ctx, &mut resp).await; resp } } diff --git a/tarpc/src/server/request_hook/before.rs b/tarpc/src/server/request_hook/before.rs index 2c478dbb1..a0040e636 100644 --- a/tarpc/src/server/request_hook/before.rs +++ b/tarpc/src/server/request_hook/before.rs @@ -70,13 +70,13 @@ where async fn serve( self, - mut ctx: context::Context, + ctx: &mut context::Context, req: Self::Req, ) -> Result { let BeforeRequestHook { serve, mut hook, .. } = self; - hook.before(&mut ctx, &req).await?; + hook.before(ctx, &req).await?; serve.serve(ctx, req).await } } diff --git a/tarpc/src/server/request_hook/before_and_after.rs b/tarpc/src/server/request_hook/before_and_after.rs index ff61a53ea..5acdaea6f 100644 --- a/tarpc/src/server/request_hook/before_and_after.rs +++ b/tarpc/src/server/request_hook/before_and_after.rs @@ -47,13 +47,13 @@ where type Req = Req; type Resp = Resp; - async fn serve(self, mut ctx: context::Context, req: Req) -> Result { + async fn serve(self, ctx: &mut context::Context, req: Req) -> Result { let BeforeAndAfterRequestHook { serve, mut hook, .. } = self; - hook.before(&mut ctx, &req).await?; + hook.before(ctx, &req).await?; let mut resp = serve.serve(ctx, req).await; - hook.after(&mut ctx, &mut resp).await; + hook.after(ctx, &mut resp).await; resp } } diff --git a/tarpc/src/server/testing.rs b/tarpc/src/server/testing.rs index 4c91d0730..657d54414 100644 --- a/tarpc/src/server/testing.rs +++ b/tarpc/src/server/testing.rs @@ -96,7 +96,7 @@ impl FakeChannel>, Response> { deadline: SystemTime::UNIX_EPOCH, trace_context: Default::default(), }, - id, + request_id: id, message, }, abort_registration, From f12acc0527f11495f47531b7b0310020973a839e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C3=81kos=20Vandra-Meyer?= Date: Fri, 5 May 2023 00:28:05 +0200 Subject: [PATCH 23/30] remove request sequencer --- tarpc/src/client.rs | 42 ++++++------------------------------------ 1 file changed, 6 insertions(+), 36 deletions(-) diff --git a/tarpc/src/client.rs b/tarpc/src/client.rs index 24423a423..79410a45d 100644 --- a/tarpc/src/client.rs +++ b/tarpc/src/client.rs @@ -32,26 +32,13 @@ pub struct Config { /// `pending_requests_buffer` controls the size of the channel clients use /// to communicate with the request dispatch task. pub pending_request_buffer: usize, - - /// An implementation of RequestSequencer, to provide a unique series of request ids. - /// The default implementation generates 0,1,2,3,4,5,..., but this option can be leveraged - /// to generate less predictable results, using a block cipher for example. - pub request_sequencer: Arc } impl Default for Config { fn default() -> Self { - Self::with_sequencer(DefaultSequencer::default()) - } -} - -impl Config { - /// Create a default config with a specific sequencer - pub fn with_sequencer(s: S) -> Self { Config { max_in_flight_requests: 1_000, - pending_request_buffer: 100, - request_sequencer: Arc::new(s) + pending_request_buffer: 100 } } } @@ -89,28 +76,10 @@ impl fmt::Debug for NewClient { } } -/// Provides a stream of unique u64 numbers -pub trait RequestSequencer: fmt::Debug + Send + Sync + 'static { - - /// Generates the next number. - fn next_id(&self) -> u64; -} - const _CHECK_USIZE: () = assert!( std::mem::size_of::() <= std::mem::size_of::(), "usize is too big to fit in u64" ); -/// Default sequencer producing the numbers 0,1,2,3,4... -#[derive(Clone, Default, Debug)] -pub struct DefaultSequencer(Arc); - -impl RequestSequencer for DefaultSequencer { - fn next_id(&self) -> u64 { - //_CHECK_USIZE verifies that usize fits into an u64, and usize atomics are more likely(?) be present - // than u64 on smaller architectures. - self.0.fetch_add(1, Ordering::Relaxed) as u64 - } -} /// Handles communication from the client to request dispatch. #[derive(Debug)] @@ -119,7 +88,7 @@ pub struct Channel { /// Channel to send a cancel message to the dispatcher. cancellation: RequestCancellation, /// The ID to use for the next request to stage. - request_sequencer: Arc, + next_request_id: Arc, } impl Clone for Channel { @@ -127,7 +96,7 @@ impl Clone for Channel { Self { to_dispatch: self.to_dispatch.clone(), cancellation: self.cancellation.clone(), - request_sequencer: self.request_sequencer.clone(), + next_request_id: self.next_request_id.clone() } } } @@ -159,7 +128,8 @@ impl Channel { }); span.record("rpc.trace_id", &tracing::field::display(ctx.trace_id())); let (response_completion, mut response) = oneshot::channel(); - let request_id = self.request_sequencer.next_id(); + let request_id = + u64::try_from(self.next_request_id.fetch_add(1, Ordering::Relaxed)).unwrap(); // ResponseGuard impls Drop to cancel in-flight requests. It should be created before // sending out the request; otherwise, the response future could be dropped after the @@ -269,7 +239,7 @@ where client: Channel { to_dispatch, cancellation, - request_sequencer: config.request_sequencer.clone(), + next_request_id: Arc::new(AtomicUsize::new(0)), }, dispatch: RequestDispatch { config, From 5f52e5c18963a1c12668d3bab97c2910c8d75973 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C3=81kos=20Vandra-Meyer?= Date: Sat, 6 May 2023 20:50:32 +0200 Subject: [PATCH 24/30] wip internal mutability --- tarpc/src/context.rs | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/tarpc/src/context.rs b/tarpc/src/context.rs index 0b76fb0ef..50634cc99 100644 --- a/tarpc/src/context.rs +++ b/tarpc/src/context.rs @@ -16,6 +16,7 @@ use std::{ }; use std::hash::{Hash, Hasher}; use std::ops::{Deref, DerefMut}; +use std::sync::{Arc, Mutex}; use anymap::any::CloneAny; use tracing_opentelemetry::OpenTelemetrySpanExt; @@ -139,11 +140,11 @@ impl Deref for Deadline { /// Extensions associated with a request #[derive(Clone, Debug)] -pub struct Extensions(anymap::Map); +pub struct Extensions(Arc>>); impl Default for Extensions { fn default() -> Self { - Self(anymap::Map::new()) + Self(Arc::new(Mutex::new(anymap::Map::new()))) } } From 03a5bf096f3cc5232542367fe735e6c60649f0ff Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C3=81kos=20Vandra-Meyer?= Date: Sat, 6 May 2023 21:38:18 +0200 Subject: [PATCH 25/30] revert merge of google/master --- README.md | 2 +- RELEASES.md | 24 -- example-service/Cargo.toml | 4 +- example-service/src/client.rs | 8 +- tarpc/Cargo.toml | 16 +- tarpc/examples/certs/eddsa/client.cert | 11 - tarpc/examples/certs/eddsa/client.chain | 19 - tarpc/examples/certs/eddsa/client.key | 3 - tarpc/examples/certs/eddsa/end.cert | 12 - tarpc/examples/certs/eddsa/end.chain | 19 - tarpc/examples/certs/eddsa/end.key | 3 - tarpc/examples/pubsub.rs | 2 +- tarpc/examples/tls_over_tcp.rs | 154 -------- tarpc/src/client.rs | 341 +++++------------- tarpc/src/client/in_flight_requests.rs | 49 ++- tarpc/src/client/stub.rs | 5 +- tarpc/src/context.rs | 4 +- tarpc/src/lib.rs | 80 +--- tarpc/src/server.rs | 77 ++-- tarpc/src/server/base_channel.rs | 328 ----------------- tarpc/src/server/in_flight_requests.rs | 37 +- .../src/server/limits/requests_per_channel.rs | 3 - tarpc/src/server/testing.rs | 2 +- 23 files changed, 184 insertions(+), 1019 deletions(-) delete mode 100644 tarpc/examples/certs/eddsa/client.cert delete mode 100644 tarpc/examples/certs/eddsa/client.chain delete mode 100644 tarpc/examples/certs/eddsa/client.key delete mode 100644 tarpc/examples/certs/eddsa/end.cert delete mode 100644 tarpc/examples/certs/eddsa/end.chain delete mode 100644 tarpc/examples/certs/eddsa/end.key delete mode 100644 tarpc/examples/tls_over_tcp.rs delete mode 100644 tarpc/src/server/base_channel.rs diff --git a/README.md b/README.md index 0b5b9f46d..ea363d2c8 100644 --- a/README.md +++ b/README.md @@ -67,7 +67,7 @@ Some other features of tarpc: Add to your `Cargo.toml` dependencies: ```toml -tarpc = "0.33" +tarpc = "0.31" ``` The `tarpc::service` attribute expands to a collection of items that form an rpc service. diff --git a/RELEASES.md b/RELEASES.md index 8ea6ca378..a6ce438be 100644 --- a/RELEASES.md +++ b/RELEASES.md @@ -1,27 +1,3 @@ -## 0.33.0 (2023-04-01) - -### Breaking Changes - -Opentelemetry dependency version increased to 0.18. - -## 0.32.0 (2023-03-24) - -### Breaking Changes - -- As part of a fix to return more channel errors in RPC results, a few error types have changed: - - 0. `client::RpcError::Disconnected` was split into the following errors: - - Shutdown: the client was shutdown, either intentionally or due to an error. If due to an - error, pending RPCs should see the more specific errors below. - - Send: an RPC message failed to send over the transport. Only the RPC that failed to be sent - will see this error. - - Receive: a fatal error occurred while receiving from the transport. All in-flight RPCs will - receive this error. - 0. `client::ChannelError` and `server::ChannelError` are unified in `tarpc::ChannelError`. - Previously, server transport errors would not indicate during which activity the transport - error occurred. Now, just like the client already was, it will be specific: reading, readying, - sending, flushing, or closing. - ## 0.31.0 (2022-11-03) ### New Features diff --git a/example-service/Cargo.toml b/example-service/Cargo.toml index 8b325a4f9..e76011176 100644 --- a/example-service/Cargo.toml +++ b/example-service/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "tarpc-example-service" -version = "0.15.0" +version = "0.13.0" rust-version = "1.56" authors = ["Tim Kuehn "] edition = "2021" @@ -21,7 +21,7 @@ futures = "0.3" opentelemetry = { version = "0.17", features = ["rt-tokio"] } opentelemetry-jaeger = { version = "0.16", features = ["rt-tokio"] } rand = "0.8" -tarpc = { version = "0.33", path = "../tarpc", features = ["full"] } +tarpc = { version = "0.31", path = "../tarpc", features = ["full"] } tokio = { version = "1", features = ["macros", "net", "rt-multi-thread"] } tracing = { version = "0.1" } tracing-opentelemetry = "0.17" diff --git a/example-service/src/client.rs b/example-service/src/client.rs index 2877c8157..f59003bde 100644 --- a/example-service/src/client.rs +++ b/example-service/src/client.rs @@ -26,8 +26,7 @@ async fn main() -> anyhow::Result<()> { let flags = Flags::parse(); init_tracing("Tarpc Example Client")?; - let mut transport = tarpc::serde_transport::tcp::connect(flags.server_addr, Json::default); - transport.config_mut().max_frame_length(usize::MAX); + let transport = tarpc::serde_transport::tcp::connect(flags.server_addr, Json::default); // WorldClient is generated by the service attribute. It has a constructor `new` that takes a // config and any Transport as input. @@ -43,10 +42,7 @@ async fn main() -> anyhow::Result<()> { .instrument(tracing::info_span!("Two Hellos")) .await; - match hello { - Ok(hello) => tracing::info!("{hello:?}"), - Err(e) => tracing::warn!("{:?}", anyhow::Error::from(e)), - } + tracing::info!("{:?}", hello); // Let the background span processor finish. sleep(Duration::from_micros(1)).await; diff --git a/tarpc/Cargo.toml b/tarpc/Cargo.toml index 49472ce41..a36a62e78 100644 --- a/tarpc/Cargo.toml +++ b/tarpc/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "tarpc" -version = "0.33.0" +version = "0.31.0" rust-version = "1.58.0" authors = [ "Adam Wright ", @@ -59,8 +59,8 @@ tracing = { version = "0.1", default-features = false, features = [ "attributes", "log", ] } -tracing-opentelemetry = { version = "0.18.0", default-features = false } -opentelemetry = { version = "0.18.0", default-features = false } +tracing-opentelemetry = { version = "0.17.2", default-features = false } +opentelemetry = { version = "0.17.0", default-features = false } [dev-dependencies] @@ -69,10 +69,10 @@ bincode = "1.3" bytes = { version = "1", features = ["serde"] } flate2 = "1.0" futures-test = "0.3" -opentelemetry = { version = "0.18.0", default-features = false, features = [ +opentelemetry = { version = "0.17.0", default-features = false, features = [ "rt-tokio", ] } -opentelemetry-jaeger = { version = "0.17.0", features = ["rt-tokio"] } +opentelemetry-jaeger = { version = "0.16.0", features = ["rt-tokio"] } pin-utils = "0.1.0-alpha" serde_bytes = "0.11" tracing-subscriber = { version = "0.3", features = ["env-filter"] } @@ -80,8 +80,6 @@ tokio = { version = "1", features = ["full", "test-util", "tracing"] } console-subscriber = "0.1" tokio-serde = { version = "0.8", features = ["json", "bincode"] } trybuild = "1.0" -tokio-rustls = "0.23" -rustls-pemfile = "1.0" [package.metadata.docs.rs] all-features = true @@ -107,10 +105,6 @@ required-features = ["full"] name = "custom_transport" required-features = ["serde1", "tokio1", "serde-transport"] -[[example]] -name = "tls_over_tcp" -required-features = ["full"] - [[test]] name = "service_functional" required-features = ["serde-transport"] diff --git a/tarpc/examples/certs/eddsa/client.cert b/tarpc/examples/certs/eddsa/client.cert deleted file mode 100644 index 0d3144581..000000000 --- a/tarpc/examples/certs/eddsa/client.cert +++ /dev/null @@ -1,11 +0,0 @@ ------BEGIN CERTIFICATE----- -MIIBlDCCAUagAwIBAgICAxUwBQYDK2VwMC4xLDAqBgNVBAMMI3Bvbnl0b3duIEVk -RFNBIGxldmVsIDIgaW50ZXJtZWRpYXRlMB4XDTIzMDMxNzEwMTEwNFoXDTI4MDkw -NjEwMTEwNFowGjEYMBYGA1UEAwwPcG9ueXRvd24gY2xpZW50MCowBQYDK2VwAyEA -NTKuLume19IhJfEFd/5OZUuYDKZH6xvy4AGver17OoejgZswgZgwDAYDVR0TAQH/ -BAIwADALBgNVHQ8EBAMCBsAwFgYDVR0lAQH/BAwwCgYIKwYBBQUHAwIwHQYDVR0O -BBYEFDjdrlMu4tyw5MHtbg7WnzSGRBpFMEQGA1UdIwQ9MDuAFHIl7fHKWP6/l8FE -fI2YEIM3oHxKoSCkHjAcMRowGAYDVQQDDBFwb255dG93biBFZERTQSBDQYIBezAF -BgMrZXADQQCaahfj/QLxoCOpvl6y0ZQ9CpojPqBnxV3460j5nUOp040Va2MpF137 -izCBY7LwgUE/YG6E+kH30G4jMEnqVEYK ------END CERTIFICATE----- diff --git a/tarpc/examples/certs/eddsa/client.chain b/tarpc/examples/certs/eddsa/client.chain deleted file mode 100644 index cd760dc29..000000000 --- a/tarpc/examples/certs/eddsa/client.chain +++ /dev/null @@ -1,19 +0,0 @@ ------BEGIN CERTIFICATE----- -MIIBeDCCASqgAwIBAgIBezAFBgMrZXAwHDEaMBgGA1UEAwwRcG9ueXRvd24gRWRE -U0EgQ0EwHhcNMjMwMzE3MTAxMTA0WhcNMzMwMzE0MTAxMTA0WjAuMSwwKgYDVQQD -DCNwb255dG93biBFZERTQSBsZXZlbCAyIGludGVybWVkaWF0ZTAqMAUGAytlcAMh -AEFsAexz4x2R4k4+PnTbvRVn0r3F/qw/zVnNBxfGcoEpo38wfTAdBgNVHQ4EFgQU -ciXt8cpY/r+XwUR8jZgQgzegfEowIAYDVR0lAQH/BBYwFAYIKwYBBQUHAwEGCCsG -AQUFBwMCMAwGA1UdEwQFMAMBAf8wCwYDVR0PBAQDAgH+MB8GA1UdIwQYMBaAFKYU -oLdKeY7mp7QgMZKrkVtSWYBKMAUGAytlcANBAHVpNpCV8nu4fkH3Smikx5A9qtHc -zgLIyp+wrF1a4YSa6sfTvuQmJd5aF23OXgq5grCOPXtdpHO50Mx5Qy74zQg= ------END CERTIFICATE----- ------BEGIN CERTIFICATE----- -MIIBTDCB/6ADAgECAhRZLuF0TWjDs/31OO8VeKHkNIJQaDAFBgMrZXAwHDEaMBgG -A1UEAwwRcG9ueXRvd24gRWREU0EgQ0EwHhcNMjMwMzE3MTAxMTA0WhcNMzMwMzE0 -MTAxMTA0WjAcMRowGAYDVQQDDBFwb255dG93biBFZERTQSBDQTAqMAUGAytlcAMh -ABRPZ4TiuBE8CqAFByZvqpMo/unjnnryfG2AkkWGXpa3o1MwUTAdBgNVHQ4EFgQU -phSgt0p5juantCAxkquRW1JZgEowHwYDVR0jBBgwFoAUphSgt0p5juantCAxkquR -W1JZgEowDwYDVR0TAQH/BAUwAwEB/zAFBgMrZXADQQB29o8erJA0/a8/xOHilOCC -t/s5wPHHnS5NSKx/m2N2nRn3zPxEnETlrAmGulJoeKOx8OblwmPi9rBT2K+QY2UB ------END CERTIFICATE----- diff --git a/tarpc/examples/certs/eddsa/client.key b/tarpc/examples/certs/eddsa/client.key deleted file mode 100644 index a407ea841..000000000 --- a/tarpc/examples/certs/eddsa/client.key +++ /dev/null @@ -1,3 +0,0 @@ ------BEGIN PRIVATE KEY----- -MC4CAQAwBQYDK2VwBCIEIIJX9ThTHpVS1SNZb6HP4myg4fRInIVGunTRdgnc+weH ------END PRIVATE KEY----- diff --git a/tarpc/examples/certs/eddsa/end.cert b/tarpc/examples/certs/eddsa/end.cert deleted file mode 100644 index b2eb159f5..000000000 --- a/tarpc/examples/certs/eddsa/end.cert +++ /dev/null @@ -1,12 +0,0 @@ ------BEGIN CERTIFICATE----- -MIIBuDCCAWqgAwIBAgICAcgwBQYDK2VwMC4xLDAqBgNVBAMMI3Bvbnl0b3duIEVk -RFNBIGxldmVsIDIgaW50ZXJtZWRpYXRlMB4XDTIzMDMxNzEwMTEwNFoXDTI4MDkw -NjEwMTEwNFowGTEXMBUGA1UEAwwOdGVzdHNlcnZlci5jb20wKjAFBgMrZXADIQDc -RLl3/N2tPoWnzBV3noVn/oheEl8IUtiY11Vg/QXTUKOBwDCBvTAMBgNVHRMBAf8E -AjAAMAsGA1UdDwQEAwIGwDAdBgNVHQ4EFgQUk7U2mnxedNWBAH84BsNy5si3ZQow -RAYDVR0jBD0wO4AUciXt8cpY/r+XwUR8jZgQgzegfEqhIKQeMBwxGjAYBgNVBAMM -EXBvbnl0b3duIEVkRFNBIENBggF7MDsGA1UdEQQ0MDKCDnRlc3RzZXJ2ZXIuY29t -ghVzZWNvbmQudGVzdHNlcnZlci5jb22CCWxvY2FsaG9zdDAFBgMrZXADQQCFWIcF -9FiztCuUNzgXDNu5kshuflt0RjkjWpGlWzQjGoYM2IvYhNVPeqnCiY92gqwDSBtq -amD2TBup4eNUCsQB ------END CERTIFICATE----- diff --git a/tarpc/examples/certs/eddsa/end.chain b/tarpc/examples/certs/eddsa/end.chain deleted file mode 100644 index cd760dc29..000000000 --- a/tarpc/examples/certs/eddsa/end.chain +++ /dev/null @@ -1,19 +0,0 @@ ------BEGIN CERTIFICATE----- -MIIBeDCCASqgAwIBAgIBezAFBgMrZXAwHDEaMBgGA1UEAwwRcG9ueXRvd24gRWRE -U0EgQ0EwHhcNMjMwMzE3MTAxMTA0WhcNMzMwMzE0MTAxMTA0WjAuMSwwKgYDVQQD -DCNwb255dG93biBFZERTQSBsZXZlbCAyIGludGVybWVkaWF0ZTAqMAUGAytlcAMh -AEFsAexz4x2R4k4+PnTbvRVn0r3F/qw/zVnNBxfGcoEpo38wfTAdBgNVHQ4EFgQU -ciXt8cpY/r+XwUR8jZgQgzegfEowIAYDVR0lAQH/BBYwFAYIKwYBBQUHAwEGCCsG -AQUFBwMCMAwGA1UdEwQFMAMBAf8wCwYDVR0PBAQDAgH+MB8GA1UdIwQYMBaAFKYU -oLdKeY7mp7QgMZKrkVtSWYBKMAUGAytlcANBAHVpNpCV8nu4fkH3Smikx5A9qtHc -zgLIyp+wrF1a4YSa6sfTvuQmJd5aF23OXgq5grCOPXtdpHO50Mx5Qy74zQg= ------END CERTIFICATE----- ------BEGIN CERTIFICATE----- -MIIBTDCB/6ADAgECAhRZLuF0TWjDs/31OO8VeKHkNIJQaDAFBgMrZXAwHDEaMBgG -A1UEAwwRcG9ueXRvd24gRWREU0EgQ0EwHhcNMjMwMzE3MTAxMTA0WhcNMzMwMzE0 -MTAxMTA0WjAcMRowGAYDVQQDDBFwb255dG93biBFZERTQSBDQTAqMAUGAytlcAMh -ABRPZ4TiuBE8CqAFByZvqpMo/unjnnryfG2AkkWGXpa3o1MwUTAdBgNVHQ4EFgQU -phSgt0p5juantCAxkquRW1JZgEowHwYDVR0jBBgwFoAUphSgt0p5juantCAxkquR -W1JZgEowDwYDVR0TAQH/BAUwAwEB/zAFBgMrZXADQQB29o8erJA0/a8/xOHilOCC -t/s5wPHHnS5NSKx/m2N2nRn3zPxEnETlrAmGulJoeKOx8OblwmPi9rBT2K+QY2UB ------END CERTIFICATE----- diff --git a/tarpc/examples/certs/eddsa/end.key b/tarpc/examples/certs/eddsa/end.key deleted file mode 100644 index f5541b32e..000000000 --- a/tarpc/examples/certs/eddsa/end.key +++ /dev/null @@ -1,3 +0,0 @@ ------BEGIN PRIVATE KEY----- -MC4CAQAwBQYDK2VwBCIEIMU6xGVe8JTpZ3bN/wajHfw6pEHt0Rd7wPBxds9eEFy2 ------END PRIVATE KEY----- diff --git a/tarpc/examples/pubsub.rs b/tarpc/examples/pubsub.rs index 5b5b2eedb..e254b294f 100644 --- a/tarpc/examples/pubsub.rs +++ b/tarpc/examples/pubsub.rs @@ -289,7 +289,7 @@ impl publisher::Publisher for Publisher { /// Initializes an OpenTelemetry tracing subscriber with a Jaeger backend. fn init_tracing(service_name: &str) -> anyhow::Result<()> { env::set_var("OTEL_BSP_MAX_EXPORT_BATCH_SIZE", "12"); - let tracer = opentelemetry_jaeger::new_agent_pipeline() + let tracer = opentelemetry_jaeger::new_pipeline() .with_service_name(service_name) .with_max_packet_size(2usize.pow(13)) .install_batch(opentelemetry::runtime::Tokio)?; diff --git a/tarpc/examples/tls_over_tcp.rs b/tarpc/examples/tls_over_tcp.rs deleted file mode 100644 index 7dd314799..000000000 --- a/tarpc/examples/tls_over_tcp.rs +++ /dev/null @@ -1,154 +0,0 @@ -#![feature(async_fn_in_trait)] - -use rustls_pemfile::certs; -use std::io::{BufReader, Cursor}; -use std::net::{IpAddr, Ipv4Addr}; -use tokio_rustls::rustls::server::AllowAnyAuthenticatedClient; - -use std::sync::Arc; -use futures::StreamExt; -use tokio::net::TcpListener; -use tokio::net::TcpStream; -use tokio_rustls::rustls::{self, Certificate, OwnedTrustAnchor, RootCertStore}; -use tokio_rustls::{webpki, TlsAcceptor, TlsConnector}; - -use tarpc::context::Context; -use tarpc::serde_transport as transport; -use tarpc::server::{BaseChannel, Channel}; -use tarpc::tokio_serde::formats::Bincode; -use tarpc::tokio_util::codec::length_delimited::LengthDelimitedCodec; - -#[tarpc::service] -pub trait PingService { - async fn ping() -> String; -} - -#[derive(Clone)] -struct Service; - -impl PingService for Service { - async fn ping(self, _: Context) -> String { - "🔒".to_owned() - } -} - -// certs were generated with openssl 3 https://github.com/rustls/rustls/tree/main/test-ca -// used on client-side for server tls -const END_CHAIN: &[u8] = include_bytes!("certs/eddsa/end.chain"); -// used on client-side for client-auth -const CLIENT_PRIVATEKEY_CLIENT_AUTH: &str = include_str!("certs/eddsa/client.key"); -const CLIENT_CERT_CLIENT_AUTH: &str = include_str!("certs/eddsa/client.cert"); - -// used on server-side for server tls -const END_CERT: &str = include_str!("certs/eddsa/end.cert"); -const END_PRIVATEKEY: &str = include_str!("certs/eddsa/end.key"); -// used on server-side for client-auth -const CLIENT_CHAIN_CLIENT_AUTH: &str = include_str!("certs/eddsa/client.chain"); - -pub fn load_private_key(key: &str) -> rustls::PrivateKey { - let mut reader = BufReader::new(Cursor::new(key)); - loop { - match rustls_pemfile::read_one(&mut reader).expect("cannot parse private key .pem file") { - Some(rustls_pemfile::Item::RSAKey(key)) => return rustls::PrivateKey(key), - Some(rustls_pemfile::Item::PKCS8Key(key)) => return rustls::PrivateKey(key), - Some(rustls_pemfile::Item::ECKey(key)) => return rustls::PrivateKey(key), - None => break, - _ => {} - } - } - panic!("no keys found in {:?} (encrypted keys not supported)", key); -} - -#[tokio::main] -async fn main() -> anyhow::Result<()> { - // -------------------- start here to setup tls tcp tokio stream -------------------------- - // ref certs and loading from: https://github.com/tokio-rs/tls/blob/master/tokio-rustls/tests/test.rs - // ref basic tls server setup from: https://github.com/tokio-rs/tls/blob/master/tokio-rustls/examples/server/src/main.rs - let cert = certs(&mut BufReader::new(Cursor::new(END_CERT))) - .unwrap() - .into_iter() - .map(rustls::Certificate) - .collect(); - let key = load_private_key(END_PRIVATEKEY); - let server_addr = (IpAddr::V4(Ipv4Addr::LOCALHOST), 5000); - - // ------------- server side client_auth cert loading start - let roots: Vec = certs(&mut BufReader::new(Cursor::new(CLIENT_CHAIN_CLIENT_AUTH))) - .unwrap() - .into_iter() - .map(rustls::Certificate) - .collect(); - let mut client_auth_roots = RootCertStore::empty(); - for root in roots { - client_auth_roots.add(&root).unwrap(); - } - let client_auth = AllowAnyAuthenticatedClient::new(client_auth_roots); - // ------------- server side client_auth cert loading end - - let config = rustls::ServerConfig::builder() - .with_safe_defaults() - .with_client_cert_verifier(client_auth) // use .with_no_client_auth() instead if you don't want client-auth - .with_single_cert(cert, key) - .unwrap(); - let acceptor = TlsAcceptor::from(Arc::new(config)); - let listener = TcpListener::bind(&server_addr).await.unwrap(); - let codec_builder = LengthDelimitedCodec::builder(); - - // ref ./custom_transport.rs server side - tokio::spawn(async move { - loop { - let (stream, _peer_addr) = listener.accept().await.unwrap(); - let acceptor = acceptor.clone(); - let tls_stream = acceptor.accept(stream).await.unwrap(); - let framed = codec_builder.new_framed(tls_stream); - - let transport = transport::new(framed, Bincode::default()); - - let fut = BaseChannel::with_defaults(transport).execute(Service.serve()); - tokio::spawn(fut.into_future()); - } - }); - - // ---------------------- client connection --------------------- - // cert loading from: https://github.com/tokio-rs/tls/blob/357bc562483dcf04c1f8d08bd1a831b144bf7d4c/tokio-rustls/tests/test.rs#L113 - // tls client connection from https://github.com/tokio-rs/tls/blob/master/tokio-rustls/examples/client/src/main.rs - let chain = certs(&mut std::io::Cursor::new(END_CHAIN)).unwrap(); - let mut root_store = rustls::RootCertStore::empty(); - root_store.add_server_trust_anchors(chain.iter().map(|cert| { - let ta = webpki::TrustAnchor::try_from_cert_der(&cert[..]).unwrap(); - OwnedTrustAnchor::from_subject_spki_name_constraints( - ta.subject, - ta.spki, - ta.name_constraints, - ) - })); - - let client_auth_private_key = load_private_key(CLIENT_PRIVATEKEY_CLIENT_AUTH); - let client_auth_certs: Vec = - certs(&mut BufReader::new(Cursor::new(CLIENT_CERT_CLIENT_AUTH))) - .unwrap() - .into_iter() - .map(rustls::Certificate) - .collect(); - - let config = rustls::ClientConfig::builder() - .with_safe_defaults() - .with_root_certificates(root_store) - .with_single_cert(client_auth_certs, client_auth_private_key)?; // use .with_no_client_auth() instead if you don't want client-auth - - let domain = rustls::ServerName::try_from("localhost")?; - let connector = TlsConnector::from(Arc::new(config)); - - let stream = TcpStream::connect(server_addr).await?; - let stream = connector.connect(domain, stream).await?; - - let transport = transport::new(codec_builder.new_framed(stream), Bincode::default()); - let answer = PingServiceClient::new(Default::default(), transport) - .spawn() - .ping(tarpc::context::current()) - .await?; - - println!("ping answer: {answer}"); - - Ok(()) -} diff --git a/tarpc/src/client.rs b/tarpc/src/client.rs index 79410a45d..dfc11a641 100644 --- a/tarpc/src/client.rs +++ b/tarpc/src/client.rs @@ -9,14 +9,23 @@ mod in_flight_requests; pub mod stub; -use crate::{cancellations::{cancellations, CanceledRequests, RequestCancellation}, context, trace, ClientMessage, Request, Response, ServerError, Transport, ChannelError}; +use crate::{ + cancellations::{cancellations, CanceledRequests, RequestCancellation}, + context, trace, ClientMessage, Request, Response, ServerError, Transport, +}; use futures::{prelude::*, ready, stream::Fuse, task::*}; -use in_flight_requests::InFlightRequests; +use in_flight_requests::{DeadlineExceededError, InFlightRequests}; use pin_project::pin_project; -use std::{convert::TryFrom, fmt, pin::Pin, sync::{ - atomic::{AtomicUsize, Ordering}, - Arc, -}}; +use std::{ + convert::TryFrom, + error::Error, + fmt, + pin::Pin, + sync::{ + atomic::{AtomicUsize, Ordering}, + Arc, + }, +}; use tokio::sync::{mpsc, oneshot}; use tracing::Span; @@ -38,7 +47,7 @@ impl Default for Config { fn default() -> Self { Config { max_in_flight_requests: 1_000, - pending_request_buffer: 100 + pending_request_buffer: 100, } } } @@ -96,7 +105,7 @@ impl Clone for Channel { Self { to_dispatch: self.to_dispatch.clone(), cancellation: self.cancellation.clone(), - next_request_id: self.next_request_id.clone() + next_request_id: self.next_request_id.clone(), } } } @@ -117,7 +126,7 @@ impl Channel { &self, mut ctx: context::Context, request_name: &'static str, - request: Req + request: Req, ) -> Result { let span = Span::current(); ctx.trace_context = trace::Context::try_from(&span).unwrap_or_else(|_| { @@ -150,7 +159,7 @@ impl Channel { response_completion, }) .await - .map_err(|mpsc::error::SendError(_)| RpcError::Shutdown)?; + .map_err(|mpsc::error::SendError(_)| RpcError::Disconnected)?; response_guard.response().await } } @@ -158,7 +167,7 @@ impl Channel { /// A server response that is completed by request dispatch when the corresponding response /// arrives off the wire. struct ResponseGuard<'a, Resp> { - response: &'a mut oneshot::Receiver>, + response: &'a mut oneshot::Receiver, DeadlineExceededError>>, cancellation: &'a RequestCancellation, request_id: u64, cancel: bool, @@ -166,17 +175,12 @@ struct ResponseGuard<'a, Resp> { /// An error that can occur in the processing of an RPC. This is not request-specific errors but /// rather cross-cutting errors that can always occur. -#[derive(thiserror::Error, Debug)] +#[derive(thiserror::Error, Clone, Debug, PartialEq, Eq, Hash)] +#[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))] pub enum RpcError { /// The client disconnected from the server. - #[error("the connection to the server was already shutdown")] - Shutdown, - /// The client failed to send the request. - #[error("the client failed to send the request")] - Send(#[source] Box), - /// An error occurred while waiting for the server response. - #[error("an error occurred while waiting for the server response")] - Receive(#[source] Arc), + #[error("the client disconnected from the server")] + Disconnected, /// The request exceeded its deadline. #[error("the request exceeded its deadline")] DeadlineExceeded, @@ -185,18 +189,24 @@ pub enum RpcError { Server(#[from] ServerError), } +impl From for RpcError { + fn from(_: DeadlineExceededError) -> Self { + RpcError::DeadlineExceeded + } +} + impl ResponseGuard<'_, Resp> { async fn response(mut self) -> Result { let response = (&mut self.response).await; // Cancel drop logic once a response has been received. self.cancel = false; match response { - Ok(response) => response, + Ok(resp) => Ok(resp?.message?), Err(oneshot::error::RecvError { .. }) => { // The oneshot is Canceled when the dispatch task ends. In that case, // there's nothing listening on the other side, so there's no point in // propagating cancellation. - Err(RpcError::Shutdown) + Err(RpcError::Disconnected) } } } @@ -265,17 +275,42 @@ pub struct RequestDispatch { /// Requests that were dropped. canceled_requests: CanceledRequests, /// Requests already written to the wire that haven't yet received responses. - in_flight_requests: InFlightRequests>, + in_flight_requests: InFlightRequests, /// Configures limits to prevent unlimited resource usage. config: Config, } +/// Critical errors that result in a Channel disconnecting. +#[derive(thiserror::Error, Debug)] +pub enum ChannelError +where + E: Error + Send + Sync + 'static, +{ + /// Could not read from the transport. + #[error("could not read from the transport")] + Read(#[source] E), + /// Could not ready the transport for writes. + #[error("could not ready the transport for writes")] + Ready(#[source] E), + /// Could not write to the transport. + #[error("could not write to the transport")] + Write(#[source] E), + /// Could not flush the transport. + #[error("could not flush the transport")] + Flush(#[source] E), + /// Could not close the write end of the transport. + #[error("could not close the write end of the transport")] + Close(#[source] E), + /// Could not poll expired requests. + #[error("could not poll expired requests")] + Timer(#[source] tokio::time::error::Error), +} impl RequestDispatch where C: Transport, Response>, { - fn in_flight_requests<'a>(self: &'a mut Pin<&mut Self>) -> &'a mut InFlightRequests> { + fn in_flight_requests<'a>(self: &'a mut Pin<&mut Self>) -> &'a mut InFlightRequests { self.as_mut().project().in_flight_requests } @@ -335,17 +370,7 @@ where ) -> Poll>>> { self.transport_pin_mut() .poll_next(cx) - .map_err(|e| { - let e = Arc::new(e); - for span in self - .in_flight_requests() - .complete_all_requests(|| Err(RpcError::Receive(e.clone()))) - { - let _entered = span.enter(); - tracing::info!("ReceiveError"); - } - ChannelError::Read(e) - }) + .map_err(ChannelError::Read) .map_ok(|response| { self.complete(response); }) @@ -375,10 +400,7 @@ where // Receiving Poll::Ready(None) when polling expired requests never indicates "Closed", // because there can temporarily be zero in-flight rquests. Therefore, there is no need to // track the status like is done with pending and cancelled requests. - if let Poll::Ready(Some(_)) = self - .in_flight_requests() - .poll_expired(cx, || Err(RpcError::DeadlineExceeded)) - { + if let Poll::Ready(Some(_)) = self.in_flight_requests().poll_expired(cx) { // Expired requests are considered complete; there is no compelling reason to send a // cancellation message to the server, since it will have already exhausted its // allotted processing time. @@ -489,7 +511,7 @@ where Some(dispatch_request) => dispatch_request, None => return Poll::Ready(None), }; - let _entered = span.enter(); + let entered = span.enter(); // poll_next_request only returns Ready if there is room to buffer another request. // Therefore, we can call write_request without fear of erroring due to a full // buffer. @@ -499,17 +521,13 @@ where message: request, context: ctx.clone(), }); + self.start_send(request)?; + tracing::info!("SendRequest"); + drop(entered); self.in_flight_requests() - .insert_request(request_id, ctx, span.clone(), response_completion) + .insert_request(request_id, ctx, span, response_completion) .expect("Request IDs should be unique"); - match self.start_send(request) { - Ok(()) => tracing::info!("SendRequest"), - Err(e) => { - self.in_flight_requests() - .complete_request(request_id, Err(RpcError::Send(Box::new(e)))); - } - } Poll::Ready(Some(Ok(()))) } @@ -534,10 +552,7 @@ where /// Sends a server response to the client task that initiated the associated request. fn complete(mut self: Pin<&mut Self>, response: Response) -> bool { - self.in_flight_requests().complete_request( - response.request_id, - response.message.map_err(RpcError::Server), - ) + self.in_flight_requests().complete_request(response) } } @@ -586,33 +601,31 @@ struct DispatchRequest { pub span: Span, pub request_id: u64, pub request: Req, - pub response_completion: oneshot::Sender>, + pub response_completion: oneshot::Sender, DeadlineExceededError>>, } #[cfg(test)] mod tests { - use super::{ - cancellations, Channel, DispatchRequest, RequestDispatch, ResponseGuard, RpcError, - }; + use super::{cancellations, Channel, DispatchRequest, RequestDispatch, ResponseGuard}; use crate::{ - client::{in_flight_requests::InFlightRequests, Config}, - context::{self, current}, + client::{ + in_flight_requests::{DeadlineExceededError, InFlightRequests}, + Config, + }, + context, transport::{self, channel::UnboundedChannel}, - ChannelError, ClientMessage, Response, + ClientMessage, Response, }; use assert_matches::assert_matches; use futures::{prelude::*, task::*}; - use std::{fmt::Display, marker::PhantomData, pin::Pin, sync::{ - atomic::{AtomicUsize}, - Arc, - }}; - use thiserror::Error; - use tokio::sync::{ - mpsc::{self}, - oneshot, + use std::{ + convert::TryFrom, + pin::Pin, + sync::atomic::{AtomicUsize, Ordering}, + sync::Arc, }; + use tokio::sync::{mpsc, oneshot}; use tracing::Span; - use crate::client::DefaultSequencer; #[tokio::test] async fn response_completes_request_future() { @@ -632,7 +645,7 @@ mod tests { .await .unwrap(); assert_matches!(dispatch.as_mut().poll(cx), Poll::Pending); - assert_matches!(rx.try_recv(), Ok(Ok(resp)) if resp == "Resp"); + assert_matches!(rx.try_recv(), Ok(Ok(Response { request_id: 0, message: Ok(resp) })) if resp == "Resp"); } #[tokio::test] @@ -766,185 +779,6 @@ mod tests { assert!(dispatch.as_mut().poll_next_request(cx).is_pending()); } - #[tokio::test] - async fn test_shutdown_error() { - let _ = tracing_subscriber::fmt().with_test_writer().try_init(); - let (dispatch, mut channel, _) = set_up(); - let (tx, mut rx) = oneshot::channel(); - // send succeeds - let resp = send_request(&mut channel, "hi", tx, &mut rx).await; - drop(dispatch); - // error on receive - assert_matches!(resp.response().await, Err(RpcError::Shutdown)); - let (dispatch, channel, _) = set_up(); - drop(dispatch); - // error on send - let resp = channel - .call(current(), "test_request", "hi".to_string()) - .await; - assert_matches!(resp, Err(RpcError::Shutdown)); - } - - #[tokio::test] - async fn test_transport_error_write() { - let cause = TransportError::Write; - let (mut dispatch, mut channel, mut cx) = setup_always_err(cause); - let (tx, mut rx) = oneshot::channel(); - - let resp = send_request(&mut channel, "hi", tx, &mut rx).await; - assert!(dispatch.as_mut().poll(&mut cx).is_pending()); - let res = resp.response().await; - assert_matches!(res, Err(RpcError::Send(_))); - let client_error: anyhow::Error = res.unwrap_err().into(); - let mut chain = client_error.chain(); - chain.next(); // original RpcError - assert_eq!( - chain - .next() - .unwrap() - .downcast_ref::>(), - Some(&ChannelError::Write(cause)) - ); - assert_eq!( - client_error.root_cause().downcast_ref::(), - Some(&cause) - ); - } - - #[tokio::test] - async fn test_transport_error_read() { - let cause = TransportError::Read; - let (mut dispatch, mut channel, mut cx) = setup_always_err(cause); - let (tx, mut rx) = oneshot::channel(); - let resp = send_request(&mut channel, "hi", tx, &mut rx).await; - assert_eq!( - dispatch.as_mut().pump_write(&mut cx), - Poll::Ready(Some(Ok(()))) - ); - assert_eq!( - dispatch.as_mut().pump_read(&mut cx), - Poll::Ready(Some(Err(ChannelError::Read(Arc::new(cause))))) - ); - assert_matches!(resp.response().await, Err(RpcError::Receive(_))); - } - - #[tokio::test] - async fn test_transport_error_ready() { - let cause = TransportError::Ready; - let (mut dispatch, _, mut cx) = setup_always_err(cause); - assert_eq!( - dispatch.as_mut().poll(&mut cx), - Poll::Ready(Err(ChannelError::Ready(cause))) - ); - } - - #[tokio::test] - async fn test_transport_error_flush() { - let cause = TransportError::Flush; - let (mut dispatch, _, mut cx) = setup_always_err(cause); - assert_eq!( - dispatch.as_mut().poll(&mut cx), - Poll::Ready(Err(ChannelError::Flush(cause))) - ); - } - - #[tokio::test] - async fn test_transport_error_close() { - let cause = TransportError::Close; - let (mut dispatch, channel, mut cx) = setup_always_err(cause); - drop(channel); - assert_eq!( - dispatch.as_mut().poll(&mut cx), - Poll::Ready(Err(ChannelError::Close(cause))) - ); - } - - fn setup_always_err( - cause: TransportError, - ) -> ( - Pin>>>, - Channel, - Context<'static>, - ) { - let (to_dispatch, pending_requests) = mpsc::channel(1); - let (cancellation, canceled_requests) = cancellations(); - let transport: AlwaysErrorTransport = AlwaysErrorTransport(cause, PhantomData); - let dispatch = Box::pin(RequestDispatch:: { - transport: transport.fuse(), - pending_requests, - canceled_requests, - in_flight_requests: InFlightRequests::default(), - config: Config::default(), - }); - let channel = Channel { - to_dispatch, - cancellation, - request_sequencer: Arc::new(DefaultSequencer::default()) - }; - let cx = Context::from_waker(noop_waker_ref()); - (dispatch, channel, cx) - } - - struct AlwaysErrorTransport(TransportError, PhantomData); - - #[derive(Debug, Error, PartialEq, Eq, Clone, Copy)] - enum TransportError { - Read, - Ready, - Write, - Flush, - Close, - } - - impl Display for TransportError { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.write_str(&format!("{self:?}")) - } - } - - impl Sink for AlwaysErrorTransport { - type Error = TransportError; - fn poll_ready(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll> { - match self.0 { - TransportError::Ready => Poll::Ready(Err(self.0)), - TransportError::Flush => Poll::Pending, - _ => Poll::Ready(Ok(())), - } - } - fn start_send(self: Pin<&mut Self>, _: S) -> Result<(), Self::Error> { - if matches!(self.0, TransportError::Write) { - Err(self.0) - } else { - Ok(()) - } - } - fn poll_flush(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll> { - if matches!(self.0, TransportError::Flush) { - Poll::Ready(Err(self.0)) - } else { - Poll::Ready(Ok(())) - } - } - fn poll_close(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll> { - if matches!(self.0, TransportError::Close) { - Poll::Ready(Err(self.0)) - } else { - Poll::Ready(Ok(())) - } - } - } - - impl Stream for AlwaysErrorTransport { - type Item = Result, TransportError>; - fn poll_next(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll> { - if matches!(self.0, TransportError::Read) { - Poll::Ready(Some(Err(self.0))) - } else { - Poll::Pending - } - } - } - fn set_up() -> ( Pin< Box< @@ -975,7 +809,7 @@ mod tests { let channel = Channel { to_dispatch, cancellation, - request_sequencer: Arc::new(DefaultSequencer::default()) + next_request_id: Arc::new(AtomicUsize::new(0)), }; (Box::pin(dispatch), channel, server_channel) @@ -984,10 +818,11 @@ mod tests { async fn send_request<'a>( channel: &'a mut Channel, request: &str, - response_completion: oneshot::Sender>, - response: &'a mut oneshot::Receiver>, + response_completion: oneshot::Sender, DeadlineExceededError>>, + response: &'a mut oneshot::Receiver, DeadlineExceededError>>, ) -> ResponseGuard<'a, String> { - let request_id = channel.request_sequencer.next_id(); + let request_id = + u64::try_from(channel.next_request_id.fetch_add(1, Ordering::Relaxed)).unwrap(); let request = DispatchRequest { ctx: context::current(), span: Span::current(), @@ -1020,7 +855,7 @@ mod tests { impl PollTest for Poll>> where - E: Display, + E: ::std::fmt::Display, { type T = Option; diff --git a/tarpc/src/client/in_flight_requests.rs b/tarpc/src/client/in_flight_requests.rs index cb69f6809..a7e5fb53b 100644 --- a/tarpc/src/client/in_flight_requests.rs +++ b/tarpc/src/client/in_flight_requests.rs @@ -1,6 +1,7 @@ use crate::{ context, util::{Compact, TimeUntil}, + Response, }; use fnv::FnvHashMap; use std::{ @@ -27,11 +28,17 @@ impl Default for InFlightRequests { } } +/// The request exceeded its deadline. +#[derive(thiserror::Error, Debug)] +#[non_exhaustive] +#[error("the request exceeded its deadline")] +pub struct DeadlineExceededError; + #[derive(Debug)] -struct RequestData { +struct RequestData { ctx: context::Context, span: Span, - response_completion: oneshot::Sender, + response_completion: oneshot::Sender, DeadlineExceededError>>, /// The key to remove the timer for the request's deadline. deadline_key: delay_queue::Key, } @@ -41,7 +48,7 @@ struct RequestData { #[derive(Debug)] pub struct AlreadyExistsError; -impl InFlightRequests { +impl InFlightRequests { /// Returns the number of in-flight requests. pub fn len(&self) -> usize { self.request_data.len() @@ -58,7 +65,7 @@ impl InFlightRequests { request_id: u64, ctx: context::Context, span: Span, - response_completion: oneshot::Sender, + response_completion: oneshot::Sender, DeadlineExceededError>>, ) -> Result<(), AlreadyExistsError> { match self.request_data.entry(request_id) { hash_map::Entry::Vacant(vacant) => { @@ -77,35 +84,25 @@ impl InFlightRequests { } /// Removes a request without aborting. Returns true iff the request was found. - pub fn complete_request(&mut self, request_id: u64, result: Res) -> bool { - if let Some(request_data) = self.request_data.remove(&request_id) { + pub fn complete_request(&mut self, response: Response) -> bool { + if let Some(request_data) = self.request_data.remove(&response.request_id) { let _entered = request_data.span.enter(); tracing::info!("ReceiveResponse"); self.request_data.compact(0.1); self.deadlines.remove(&request_data.deadline_key); - let _ = request_data.response_completion.send(result); + let _ = request_data.response_completion.send(Ok(response)); return true; } - tracing::debug!("No in-flight request found for request_id = {request_id}."); + tracing::debug!( + "No in-flight request found for request_id = {}.", + response.request_id + ); // If the response completion was absent, then the request was already canceled. false } - /// Completes all requests using the provided function. - /// Returns Spans for all completes requests. - pub fn complete_all_requests<'a>( - &'a mut self, - mut result: impl FnMut() -> Res + 'a, - ) -> impl Iterator + 'a { - self.deadlines.clear(); - self.request_data.drain().map(move |(_, request_data)| { - let _ = request_data.response_completion.send(result()); - request_data.span - }) - } - /// Cancels a request without completing (typically used when a request handle was dropped /// before the request completed). pub fn cancel_request(&mut self, request_id: u64) -> Option<(context::Context, Span)> { @@ -120,18 +117,16 @@ impl InFlightRequests { /// Yields a request that has expired, completing it with a TimedOut error. /// The caller should send cancellation messages for any yielded request ID. - pub fn poll_expired( - &mut self, - cx: &mut Context, - expired_error: impl Fn() -> Res, - ) -> Poll> { + pub fn poll_expired(&mut self, cx: &mut Context) -> Poll> { self.deadlines.poll_expired(cx).map(|expired| { let request_id = expired?.into_inner(); if let Some(request_data) = self.request_data.remove(&request_id) { let _entered = request_data.span.enter(); tracing::error!("DeadlineExceeded"); self.request_data.compact(0.1); - let _ = request_data.response_completion.send(expired_error()); + let _ = request_data + .response_completion + .send(Err(DeadlineExceededError)); } Some(request_id) }) diff --git a/tarpc/src/client/stub.rs b/tarpc/src/client/stub.rs index 64b0df637..a8b72a20f 100644 --- a/tarpc/src/client/stub.rs +++ b/tarpc/src/client/stub.rs @@ -1,6 +1,9 @@ //! Provides a Stub trait, implemented by types that can call remote services. -use crate::{client::{Channel, RpcError}, context}; +use crate::{ + client::{Channel, RpcError}, + context, +}; use futures::prelude::*; pub mod load_balance; diff --git a/tarpc/src/context.rs b/tarpc/src/context.rs index 50634cc99..f302313ad 100644 --- a/tarpc/src/context.rs +++ b/tarpc/src/context.rs @@ -140,11 +140,11 @@ impl Deref for Deadline { /// Extensions associated with a request #[derive(Clone, Debug)] -pub struct Extensions(Arc>>); +pub struct Extensions(anymap::Map); impl Default for Extensions { fn default() -> Self { - Self(Arc::new(Mutex::new(anymap::Map::new()))) + Self(anymap::Map::new()) } } diff --git a/tarpc/src/lib.rs b/tarpc/src/lib.rs index a823c1629..c4a248195 100644 --- a/tarpc/src/lib.rs +++ b/tarpc/src/lib.rs @@ -258,61 +258,6 @@ pub use tarpc_plugins::derive_serde; /// * `fn new_stub` -- creates a new Client stub. pub use tarpc_plugins::service; -/// A utility macro that can be used for RPC server implementations. -/// -/// Syntactic sugar to make using async functions in the server implementation -/// easier. It does this by rewriting code like this, which would normally not -/// compile because async functions are disallowed in trait implementations: -/// -/// ```rust -/// # use tarpc::context; -/// # use std::net::SocketAddr; -/// #[tarpc::service] -/// trait World { -/// async fn hello(name: String) -> String; -/// } -/// -/// #[derive(Clone)] -/// struct HelloServer(SocketAddr); -/// -/// #[tarpc::server] -/// impl World for HelloServer { -/// async fn hello(self, _: context::Context, name: String) -> String { -/// format!("Hello, {name}! You are connected from {:?}.", self.0) -/// } -/// } -/// ``` -/// -/// Into code like this, which matches the service trait definition: -/// -/// ```rust -/// # use tarpc::context; -/// # use std::pin::Pin; -/// # use futures::Future; -/// # use std::net::SocketAddr; -/// #[derive(Clone)] -/// struct HelloServer(SocketAddr); -/// -/// #[tarpc::service] -/// trait World { -/// async fn hello(name: String) -> String; -/// } -/// -/// impl World for HelloServer { -/// type HelloFut = Pin + Send>>; -/// -/// fn hello(self, _: context::Context, name: String) -> Pin -/// + Send>> { -/// Box::pin(async move { -/// format!("Hello, {name}! You are connected from {:?}.", self.0) -/// }) -/// } -/// } -/// ``` -/// -/// Note that this won't touch functions unless they have been annotated with -/// `async`, meaning that this should not break existing code. - pub(crate) mod cancellations; pub mod client; pub mod context; @@ -366,7 +311,7 @@ pub struct Request { pub message: T, } -/// A response from a server to a client.c +/// A response from a server to a client. #[derive(Clone, Debug, PartialEq, Eq, Hash)] #[non_exhaustive] #[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))] @@ -407,29 +352,6 @@ impl ServerError { } } -/// Critical errors that result in a Channel disconnecting. -#[derive(thiserror::Error, Debug, PartialEq, Eq)] -pub enum ChannelError - where - E: Error + Send + Sync + 'static, -{ - /// Could not read from the transport. - #[error("could not read from the transport: {0}")] - Read(#[source] Arc), - /// Could not ready the transport for writes. - #[error("could not ready the transport for writes")] - Ready(#[source] E), - /// Could not write to the transport. - #[error("could not write to the transport")] - Write(#[source] E), - /// Could not flush the transport. - #[error("could not flush the transport")] - Flush(#[source] E), - /// Could not close the write end of the transport. - #[error("could not close the write end of the transport")] - Close(#[source] E), -} - impl Request { /// Returns the deadline for this request. pub fn deadline(&self) -> &SystemTime { diff --git a/tarpc/src/server.rs b/tarpc/src/server.rs index 7294dbb20..702f66387 100644 --- a/tarpc/src/server.rs +++ b/tarpc/src/server.rs @@ -6,7 +6,11 @@ //! Provides a server that concurrently handles many connections sending multiplexed requests. -use crate::{cancellations::{cancellations, CanceledRequests, RequestCancellation}, context::{self, SpanExt}, trace, ClientMessage, Request, Response, ServerError, Transport, ChannelError}; +use crate::{ + cancellations::{cancellations, CanceledRequests, RequestCancellation}, + context::{self, SpanExt}, + trace, ClientMessage, Request, Response, ServerError, Transport, +}; use ::tokio::sync::mpsc; use futures::{ future::{AbortRegistration, Abortable}, @@ -17,9 +21,9 @@ use futures::{ }; use in_flight_requests::{AlreadyExistsError, InFlightRequests}; use pin_project::pin_project; -use std::{convert::TryFrom, error::Error, fmt, marker::PhantomData, pin::Pin, sync::Arc}; +use std::{convert::TryFrom, error::Error, fmt, marker::PhantomData, pin::Pin}; +use std::sync::Arc; use tracing::{info_span, instrument::Instrument, Span}; -use tracing_opentelemetry::OpenTelemetrySpanExt; mod in_flight_requests; pub mod request_hook; @@ -28,9 +32,6 @@ mod testing; /// Provides functionality to apply server limits. pub mod limits; -mod base_channel; - -pub use base_channel::*; /// Provides helper methods for streams of Channels. pub mod incoming; @@ -294,7 +295,7 @@ pub struct BaseChannel { /// Notifies `canceled_requests` when a request is canceled. request_cancellation: RequestCancellation, /// Holds data necessary to clean up in-flight requests. - in_flight_requests: InFlightRequests<()>, + in_flight_requests: InFlightRequests, /// Types the request and response. ghost: PhantomData<(fn() -> Req, fn(Resp))>, } @@ -331,7 +332,7 @@ where self.project().transport.get_pin_mut() } - fn in_flight_requests_mut<'a>(self: &'a mut Pin<&mut Self>) -> &'a mut InFlightRequests<()> { + fn in_flight_requests_mut<'a>(self: &'a mut Pin<&mut Self>) -> &'a mut InFlightRequests { self.as_mut().project().in_flight_requests } @@ -369,7 +370,6 @@ where let start = self.in_flight_requests_mut().start_request( request.request_id, *request.context.deadline, - (), span.clone(), ); match start { @@ -566,6 +566,20 @@ where } } +/// Critical errors that result in a Channel disconnecting. +#[derive(thiserror::Error, Debug)] +pub enum ChannelError +where + E: Error + Send + Sync + 'static, +{ + /// An error occurred reading from, or writing to, the transport. + #[error("an error occurred in the transport")] + Transport(#[source] E), + /// An error occurred while polling expired requests. + #[error("an error occurred while polling expired requests")] + Timer(#[source] ::tokio::time::error::Error), +} + impl Stream for BaseChannel where T: Transport, ClientMessage>, @@ -596,7 +610,7 @@ where loop { let cancellation_status = match self.canceled_requests_pin_mut().poll_recv(cx) { Poll::Ready(Some(request_id)) => { - if let Some(((), span)) = self.in_flight_requests_mut().remove_request(request_id) { + if let Some(span) = self.in_flight_requests_mut().remove_request(request_id) { let _entered = span.enter(); tracing::info!("ResponseCancelled"); } @@ -622,7 +636,7 @@ where let request_status = match self .transport_pin_mut() .poll_next(cx) - .map_err(|e| ChannelError::Read(Arc::new(e)))? + .map_err(|e| ChannelError::Transport(e))? { Poll::Ready(Some(message)) => match message { ClientMessage::Request(request) => { @@ -684,11 +698,11 @@ where self.project() .transport .poll_ready(cx) - .map_err(ChannelError::Ready) + .map_err(ChannelError::Transport) } fn start_send(mut self: Pin<&mut Self>, response: Response) -> Result<(), Self::Error> { - if let Some(((), span)) = self + if let Some(span) = self .in_flight_requests_mut() .remove_request(response.request_id) { @@ -698,7 +712,7 @@ where self.project() .transport .start_send(response) - .map_err(ChannelError::Write) + .map_err(ChannelError::Transport) } else { // If the request isn't tracked anymore, there's no need to send the response. Ok(()) @@ -710,14 +724,14 @@ where self.project() .transport .poll_flush(cx) - .map_err(ChannelError::Flush) + .map_err(ChannelError::Transport) } fn poll_close(self: Pin<&mut Self>, cx: &mut Context) -> Poll> { self.project() .transport .poll_close(cx) - .map_err(ChannelError::Close) + .map_err(ChannelError::Transport) } } @@ -758,9 +772,9 @@ where #[pin] channel: C, /// Responses waiting to be written to the wire. - pending_responses: mpsc::Receiver<((), Response)>, + pending_responses: mpsc::Receiver>, /// Handed out to request handlers to fan in responses. - responses_tx: mpsc::Sender<((), Response)>, + responses_tx: mpsc::Sender>, } impl Requests @@ -780,14 +794,14 @@ where /// Returns the inner channel over which messages are sent and received. pub fn pending_responses_mut<'a>( self: &'a mut Pin<&mut Self>, - ) -> &'a mut mpsc::Receiver<((), Response)> { + ) -> &'a mut mpsc::Receiver> { self.as_mut().project().pending_responses } fn pump_read( mut self: Pin<&mut Self>, cx: &mut Context<'_>, - ) -> Poll, C::Error>>> { + ) -> Poll, C::Error>>> { self.channel_pin_mut().poll_next(cx).map_ok( |TrackedRequest { request, @@ -799,7 +813,6 @@ where response_guard.cancel = true; InFlightRequest { request, - transport_ctx: (), abort_registration, span, response_guard, @@ -853,7 +866,7 @@ where ready!(self.ensure_writeable(cx)?); match ready!(self.pending_responses_mut().poll_recv(cx)) { - Some(((), response)) => Poll::Ready(Some(Ok(response))), + Some(response) => Poll::Ready(Some(Ok(response))), None => { // This branch likely won't happen, since the Requests stream is holding a Sender. Poll::Ready(None) @@ -950,16 +963,15 @@ impl Drop for ResponseGuard { /// If dropped without calling [`execute`](InFlightRequest::execute), a cancellation message will /// be sent to the Channel to clean up associated request state. #[derive(Debug)] -pub struct InFlightRequest { +pub struct InFlightRequest { request: Request, - transport_ctx: C, abort_registration: AbortRegistration, response_guard: ResponseGuard, span: Span, - response_tx: mpsc::Sender<(C, Response)>, + response_tx: mpsc::Sender>, } -impl InFlightRequest { +impl InFlightRequest { /// Returns a reference to the request. pub fn get(&self) -> &Request { &self.request @@ -1015,7 +1027,6 @@ impl InFlightRequest { { let Self { response_tx, - transport_ctx, mut response_guard, abort_registration, span, @@ -1040,7 +1051,7 @@ impl InFlightRequest { request_id, message, }; - let _ = response_tx.send((transport_ctx, response)).await; + let _ = response_tx.send(response).await; tracing::info!("BufferResponse"); }, abort_registration, @@ -1065,7 +1076,7 @@ impl Stream for Requests where C: Channel, { - type Item = Result, C::Error>; + type Item = Result, C::Error>; fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { loop { @@ -1519,10 +1530,10 @@ mod tests { .as_mut() .project() .responses_tx - .send(((), Response { + .send(Response { request_id: 1, message: Ok(()), - })) + }) .await .unwrap(); @@ -1579,10 +1590,10 @@ mod tests { .as_mut() .project() .responses_tx - .send(((), Response { + .send(Response { request_id: 1, message: Ok(()), - })) + }) .await .unwrap(); diff --git a/tarpc/src/server/base_channel.rs b/tarpc/src/server/base_channel.rs deleted file mode 100644 index f5ae4c136..000000000 --- a/tarpc/src/server/base_channel.rs +++ /dev/null @@ -1,328 +0,0 @@ -use crate::{ - cancellations::{cancellations, CanceledRequests, RequestCancellation}, - context::{SpanExt}, - trace, ClientMessage, Request, Response, Transport, -}; -use futures::{ - prelude::*, - stream::Fuse, - task::*, -}; -use super::in_flight_requests::{AlreadyExistsError, InFlightRequests}; -use pin_project::pin_project; -use std::{convert::TryFrom, error::Error, fmt, marker::PhantomData, pin::Pin}; -use std::sync::Arc; -use opentelemetry::trace::TraceContextExt; -use tracing::{info_span}; -use tracing_opentelemetry::OpenTelemetrySpanExt; -use crate::server::{Channel, ChannelError, Config, ResponseGuard, TrackedRequest}; - -/// BaseChannel is the standard implementation of a [`Channel`]. -/// -/// BaseChannel manages a [`Transport`](Transport) of client [`messages`](ClientMessage) and -/// implements a [`Stream`] of [requests](TrackedRequest). See the [`Channel`] documentation for -/// how to use channels. -/// -/// Besides requests, the other type of client message handled by `BaseChannel` is [cancellation -/// mssages](ClientMessage::Cancel). `BaseChannel` does not allow direct access to cancellation -/// messages. Instead, it internally handles them by cancelling corresponding requests (removing -/// the corresponding in-flight requests and aborting their handlers). -#[pin_project] -pub struct BaseChannel { - config: Config, - /// Writes responses to the wire and reads requests off the wire. - #[pin] - transport: Fuse, - /// In-flight requests that were dropped by the server before completion. - #[pin] - pub(super) canceled_requests: CanceledRequests, - /// Notifies `canceled_requests` when a request is canceled. - request_cancellation: RequestCancellation, - /// Holds data necessary to clean up in-flight requests. - in_flight_requests: InFlightRequests<()>, - /// Types the request and response. - ghost: PhantomData<(fn() -> Req, fn(Resp))>, -} - -impl BaseChannel - where - T: Transport, ClientMessage>, -{ - /// Creates a new channel backed by `transport` and configured with `config`. - pub fn new(config: Config, transport: T) -> Self { - let (request_cancellation, canceled_requests) = cancellations(); - BaseChannel { - config, - transport: transport.fuse(), - canceled_requests, - request_cancellation, - in_flight_requests: InFlightRequests::default(), - ghost: PhantomData, - } - } - - /// Creates a new channel backed by `transport` and configured with the defaults. - pub fn with_defaults(transport: T) -> Self { - Self::new(Config::default(), transport) - } - - /// Returns the inner transport over which messages are sent and received. - pub fn get_ref(&self) -> &T { - self.transport.get_ref() - } - - /// Returns the inner transport over which messages are sent and received. - pub fn get_pin_ref(self: Pin<&mut Self>) -> Pin<&mut T> { - self.project().transport.get_pin_mut() - } - - fn in_flight_requests_mut<'a>(self: &'a mut Pin<&mut Self>) -> &'a mut InFlightRequests<()> { - self.as_mut().project().in_flight_requests - } - - fn canceled_requests_pin_mut<'a>( - self: &'a mut Pin<&mut Self>, - ) -> Pin<&'a mut CanceledRequests> { - self.as_mut().project().canceled_requests - } - - fn transport_pin_mut<'a>(self: &'a mut Pin<&mut Self>) -> Pin<&'a mut Fuse> { - self.as_mut().project().transport - } - - pub(super) fn start_request( - mut self: Pin<&mut Self>, - mut request: Request, - ) -> Result, AlreadyExistsError> { - let span = info_span!( - "RPC", - rpc.trace_id = %request.context.trace_id(), - rpc.deadline = %humantime::format_rfc3339(*request.context.deadline), - otel.kind = "server", - otel.name = tracing::field::Empty, - ); - span.set_context(&request.context); - request.context.trace_context = trace::Context::try_from(&span).unwrap_or_else(|_| { - tracing::trace!( - "OpenTelemetry subscriber not installed; making unsampled \ - child context." - ); - request.context.trace_context.new_child() - }); - let entered = span.enter(); - tracing::info!("ReceiveRequest"); - let start = self.in_flight_requests_mut().start_request( - request.request_id, - *request.context.deadline, - (), - span.clone(), - ); - match start { - Ok(abort_registration) => { - drop(entered); - Ok(TrackedRequest { - abort_registration, - span, - response_guard: ResponseGuard { - request_id: request.request_id, - request_cancellation: self.request_cancellation.clone(), - cancel: false, - }, - request, - }) - } - Err(AlreadyExistsError) => { - tracing::trace!("DuplicateRequest"); - Err(AlreadyExistsError) - } - } - } -} - -impl fmt::Debug for BaseChannel { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(f, "BaseChannel") - } -} - -impl Stream for BaseChannel - where - T: Transport, ClientMessage>, -{ - type Item = Result, ChannelError>; - - fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll> { - #[derive(Clone, Copy, Debug)] - enum ReceiverStatus { - Ready, - Pending, - Closed, - } - - impl ReceiverStatus { - fn combine(self, other: Self) -> Self { - use ReceiverStatus::*; - match (self, other) { - (Ready, _) | (_, Ready) => Ready, - (Closed, Closed) => Closed, - (Pending, Closed) | (Closed, Pending) | (Pending, Pending) => Pending, - } - } - } - - use ReceiverStatus::*; - - loop { - let cancellation_status = match self.canceled_requests_pin_mut().poll_recv(cx) { - Poll::Ready(Some(request_id)) => { - if let Some(((), span)) = self.in_flight_requests_mut().remove_request(request_id) { - let _entered = span.enter(); - tracing::info!("ResponseCancelled"); - } - Ready - } - // Pending cancellations don't block Channel closure, because all they do is ensure - // the Channel's internal state is cleaned up. But Channel closure also cleans up - // the Channel state, so there's no reason to wait on a cancellation before - // closing. - // - // Ready(None) can't happen, since `self` holds a Cancellation. - Poll::Pending | Poll::Ready(None) => Closed, - }; - - let expiration_status = match self.in_flight_requests_mut().poll_expired(cx) { - // No need to send a response, since the client wouldn't be waiting for one - // anymore. - Poll::Ready(Some(_)) => Ready, - Poll::Ready(None) => Closed, - Poll::Pending => Pending, - }; - - let request_status = match self - .transport_pin_mut() - .poll_next(cx) - .map_err(|e| ChannelError::Read(Arc::new(e)))? - { - Poll::Ready(Some(message)) => match message { - ClientMessage::Request(request) => { - match self.as_mut().start_request(request) { - Ok(request) => return Poll::Ready(Some(Ok(request))), - Err(AlreadyExistsError) => { - // Instead of closing the channel if a duplicate request is sent, - // just ignore it, since it's already being processed. Note that we - // cannot return Poll::Pending here, since nothing has scheduled a - // wakeup yet. - continue; - } - } - } - ClientMessage::Cancel { - context, - request_id, - } => { - if !self.in_flight_requests_mut().cancel_request(request_id) { - tracing::trace!( - rpc.trace_id = %context.trace_id, - "Received cancellation, but response handler is already complete.", - ); - } - Ready - } - }, - Poll::Ready(None) => Closed, - Poll::Pending => Pending, - }; - - let status = cancellation_status - .combine(expiration_status) - .combine(request_status); - - tracing::trace!( - "Cancellations: {cancellation_status:?}, \ - Expired requests: {expiration_status:?}, \ - Inbound: {request_status:?}, \ - Overall: {status:?}", - ); - match status { - Ready => continue, - Closed => return Poll::Ready(None), - Pending => return Poll::Pending, - } - } - } -} - -impl Sink> for BaseChannel - where - T: Transport, ClientMessage>, - T::Error: Error, -{ - type Error = ChannelError; - - fn poll_ready(self: Pin<&mut Self>, cx: &mut Context) -> Poll> { - self.project() - .transport - .poll_ready(cx) - .map_err(ChannelError::Ready) - } - - fn start_send(mut self: Pin<&mut Self>, response: Response) -> Result<(), Self::Error> { - if let Some(((), span)) = self - .in_flight_requests_mut() - .remove_request(response.request_id) - { - let _entered = span.enter(); - tracing::info!("SendResponse"); - self.project() - .transport - .start_send(response) - .map_err(ChannelError::Write) - } else { - // If the request isn't tracked anymore, there's no need to send the response. - Ok(()) - } - } - - fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll> { - tracing::trace!("poll_flush"); - self.project() - .transport - .poll_flush(cx) - .map_err(ChannelError::Flush) - } - - fn poll_close(self: Pin<&mut Self>, cx: &mut Context) -> Poll> { - self.project() - .transport - .poll_close(cx) - .map_err(ChannelError::Close) - } -} - -impl AsRef for BaseChannel { - fn as_ref(&self) -> &T { - self.transport.get_ref() - } -} - -implChannel for BaseChannel - where - T: Transport, ClientMessage>, -{ - type Req = Req; - type Resp = Resp; - type Transport = T; - - - fn config(&self) -> &Config { - &self.config - } - - fn in_flight_requests(&self) -> usize { - self.in_flight_requests.len() - } - - fn transport(&self) -> &Self::Transport { - self.get_ref() - } -} diff --git a/tarpc/src/server/in_flight_requests.rs b/tarpc/src/server/in_flight_requests.rs index ef535fd7f..1f8815f40 100644 --- a/tarpc/src/server/in_flight_requests.rs +++ b/tarpc/src/server/in_flight_requests.rs @@ -11,32 +11,21 @@ use tracing::Span; /// A data structure that tracks in-flight requests. It aborts requests, /// either on demand or when a request deadline expires. -#[derive(Debug)] -pub struct InFlightRequests { - request_data: FnvHashMap>, +#[derive(Debug, Default)] +pub struct InFlightRequests { + request_data: FnvHashMap, deadlines: DelayQueue, } -impl Default for InFlightRequests { - fn default() -> Self { - InFlightRequests { - request_data: FnvHashMap::with_hasher(Default::default()), - deadlines: Default::default(), - } - } -} - /// Data needed to clean up a single in-flight request. #[derive(Debug)] -struct RequestData { +struct RequestData { /// Aborts the response handler for the associated request. abort_handle: AbortHandle, /// The key to remove the timer for the request's deadline. deadline_key: delay_queue::Key, /// The client span. span: Span, - /// Optional server side context of kept for the lifecycle of the request - context: C } /// An error returned when a request attempted to start with the same ID as a request already @@ -44,7 +33,7 @@ struct RequestData { #[derive(Debug)] pub struct AlreadyExistsError; -impl InFlightRequests { +impl InFlightRequests { /// Returns the number of in-flight requests. pub fn len(&self) -> usize { self.request_data.len() @@ -55,7 +44,6 @@ impl InFlightRequests { &mut self, request_id: u64, deadline: SystemTime, - context: C, span: Span, ) -> Result { match self.request_data.entry(request_id) { @@ -67,7 +55,6 @@ impl InFlightRequests { abort_handle, deadline_key, span, - context, }); Ok(abort_registration) } @@ -81,7 +68,6 @@ impl InFlightRequests { span, abort_handle, deadline_key, - .. }) = self.request_data.remove(&request_id) { let _entered = span.enter(); @@ -97,11 +83,11 @@ impl InFlightRequests { /// Removes a request without aborting. Returns true iff the request was found. /// This method should be used when a response is being sent. - pub fn remove_request(&mut self, request_id: u64) -> Option<(C, Span)> { + pub fn remove_request(&mut self, request_id: u64) -> Option { if let Some(request_data) = self.request_data.remove(&request_id) { self.request_data.compact(0.1); self.deadlines.remove(&request_data.deadline_key); - Some((request_data.context, request_data.span)) + Some(request_data.span) } else { None } @@ -131,7 +117,7 @@ impl InFlightRequests { } /// When InFlightRequests is dropped, any outstanding requests are aborted. -impl Drop for InFlightRequests { +impl Drop for InFlightRequests { fn drop(&mut self) { self.request_data .values() @@ -155,7 +141,7 @@ mod tests { let mut in_flight_requests = InFlightRequests::default(); assert_eq!(in_flight_requests.len(), 0); in_flight_requests - .start_request(0, SystemTime::now(), (), Span::current()) + .start_request(0, SystemTime::now(), Span::current()) .unwrap(); assert_eq!(in_flight_requests.len(), 1); } @@ -164,7 +150,7 @@ mod tests { async fn polling_expired_aborts() { let mut in_flight_requests = InFlightRequests::default(); let abort_registration = in_flight_requests - .start_request(0, SystemTime::now(), (), Span::current()) + .start_request(0, SystemTime::now(), Span::current()) .unwrap(); let mut abortable_future = Box::new(Abortable::new(pending::<()>(), abort_registration)); @@ -186,7 +172,7 @@ mod tests { async fn cancel_request_aborts() { let mut in_flight_requests = InFlightRequests::default(); let abort_registration = in_flight_requests - .start_request(0, SystemTime::now(), (), Span::current()) + .start_request(0, SystemTime::now(), Span::current()) .unwrap(); let mut abortable_future = Box::new(Abortable::new(pending::<()>(), abort_registration)); @@ -207,7 +193,6 @@ mod tests { .start_request( 0, SystemTime::now() + std::time::Duration::from_secs(10), - (), Span::current(), ) .unwrap(); diff --git a/tarpc/src/server/limits/requests_per_channel.rs b/tarpc/src/server/limits/requests_per_channel.rs index 211ff9700..6e2df0c93 100644 --- a/tarpc/src/server/limits/requests_per_channel.rs +++ b/tarpc/src/server/limits/requests_per_channel.rs @@ -60,7 +60,6 @@ where match ready!(self.as_mut().project().inner.poll_next(cx)?) { Some(r) => { let _entered = r.span.enter(); - tracing::info!( in_flight_requests = self.as_mut().in_flight_requests(), "ThrottleRequest", @@ -206,7 +205,6 @@ mod tests { .start_request( i, SystemTime::now() + Duration::from_secs(1), - (), Span::current(), ) .unwrap(); @@ -330,7 +328,6 @@ mod tests { .start_request( 0, SystemTime::now() + Duration::from_secs(1), - (), Span::current(), ) .unwrap(); diff --git a/tarpc/src/server/testing.rs b/tarpc/src/server/testing.rs index 657d54414..9c0e2bb2b 100644 --- a/tarpc/src/server/testing.rs +++ b/tarpc/src/server/testing.rs @@ -22,7 +22,7 @@ pub(crate) struct FakeChannel { #[pin] pub sink: VecDeque, pub config: Config, - pub in_flight_requests: super::in_flight_requests::InFlightRequests<()>, + pub in_flight_requests: super::in_flight_requests::InFlightRequests, pub request_cancellation: RequestCancellation, pub canceled_requests: CanceledRequests, } From d66bf17161796b28dacb23b472f3fac506830830 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C3=81kos=20Vandra-Meyer?= Date: Sat, 6 May 2023 21:49:17 +0200 Subject: [PATCH 26/30] revert rename of context to trace_context --- tarpc/src/client.rs | 2 +- tarpc/src/lib.rs | 2 +- tarpc/src/server.rs | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/tarpc/src/client.rs b/tarpc/src/client.rs index dfc11a641..e3ab07ca2 100644 --- a/tarpc/src/client.rs +++ b/tarpc/src/client.rs @@ -542,7 +542,7 @@ where let _entered = span.enter(); let cancel = ClientMessage::Cancel { - context: context.trace_context, + trace_context: context.trace_context, request_id, }; self.start_send(cancel)?; diff --git a/tarpc/src/lib.rs b/tarpc/src/lib.rs index c4a248195..5c0ae95ca 100644 --- a/tarpc/src/lib.rs +++ b/tarpc/src/lib.rs @@ -292,7 +292,7 @@ pub enum ClientMessage { /// The trace context associates the message with a specific chain of causally-related actions, /// possibly orchestrated across many distributed systems. #[cfg_attr(feature = "serde1", serde(default))] - context: trace::Context, + trace_context: trace::Context, /// The ID of the request to cancel. request_id: u64, }, diff --git a/tarpc/src/server.rs b/tarpc/src/server.rs index 702f66387..22fe4f2cd 100644 --- a/tarpc/src/server.rs +++ b/tarpc/src/server.rs @@ -652,7 +652,7 @@ where } } ClientMessage::Cancel { - context, + trace_context: context, request_id, } => { if !self.in_flight_requests_mut().cancel_request(request_id) { From cc673118596ac18ca835b3893df2d308e4209774 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C3=81kos=20Vandra-Meyer?= Date: Sat, 6 May 2023 21:53:32 +0200 Subject: [PATCH 27/30] fixes --- tarpc/examples/tracing.rs | 2 +- tarpc/src/lib.rs | 1 - tarpc/src/server.rs | 10 ++++++---- 3 files changed, 7 insertions(+), 6 deletions(-) diff --git a/tarpc/examples/tracing.rs b/tarpc/examples/tracing.rs index d37fbabea..c0a1f8d33 100644 --- a/tarpc/examples/tracing.rs +++ b/tarpc/examples/tracing.rs @@ -80,7 +80,7 @@ where } fn init_tracing(service_name: &str) -> anyhow::Result<()> { - let tracer = opentelemetry_jaeger::new_agent_pipeline() + let tracer = opentelemetry_jaeger::new_pipeline() .with_service_name(service_name) .with_auto_split_batch(true) .with_max_packet_size(2usize.pow(13)) diff --git a/tarpc/src/lib.rs b/tarpc/src/lib.rs index 5c0ae95ca..130dd4fe2 100644 --- a/tarpc/src/lib.rs +++ b/tarpc/src/lib.rs @@ -270,7 +270,6 @@ pub use crate::transport::sealed::Transport; use anyhow::Context as _; use futures::task::*; use std::{error::Error, fmt::Display, io, time::SystemTime}; -use std::sync::Arc; /// A message from a client to a server. #[derive(Debug)] diff --git a/tarpc/src/server.rs b/tarpc/src/server.rs index 22fe4f2cd..961401b77 100644 --- a/tarpc/src/server.rs +++ b/tarpc/src/server.rs @@ -22,7 +22,6 @@ use futures::{ use in_flight_requests::{AlreadyExistsError, InFlightRequests}; use pin_project::pin_project; use std::{convert::TryFrom, error::Error, fmt, marker::PhantomData, pin::Pin}; -use std::sync::Arc; use tracing::{info_span, instrument::Instrument, Span}; mod in_flight_requests; @@ -636,7 +635,7 @@ where let request_status = match self .transport_pin_mut() .poll_next(cx) - .map_err(|e| ChannelError::Transport(e))? + .map_err(ChannelError::Transport)? { Poll::Ready(Some(message)) => match message { ClientMessage::Request(request) => { @@ -707,7 +706,6 @@ where .remove_request(response.request_id) { let _entered = span.enter(); - tracing::error!("RSPAN = {:?}", span.metadata()); tracing::info!("SendResponse"); self.project() .transport @@ -809,6 +807,10 @@ where span, mut response_guard, }| { + { + let _entered = span.enter(); + tracing::info!("BeginRequest"); + } // The response guard becomes active once in an InFlightRequest. response_guard.cancel = true; InFlightRequest { @@ -1034,7 +1036,7 @@ impl InFlightRequest { Request { mut context, message, - request_id: request_id, + request_id, }, } = self; let method = serve.method(&message); From 02c276a6d077beb283af1b7cdd6cc24609d51a86 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C3=81kos=20Vandra-Meyer?= Date: Sat, 6 May 2023 22:23:58 +0200 Subject: [PATCH 28/30] fix tests --- README.md | 2 +- plugins/tests/service.rs | 12 +-- tarpc/examples/compression.rs | 2 +- tarpc/examples/custom_transport.rs | 2 +- tarpc/examples/pubsub.rs | 6 +- tarpc/examples/readme.rs | 2 +- tarpc/examples/tracing.rs | 6 +- tarpc/src/client.rs | 9 +- tarpc/src/context.rs | 21 ++++- tarpc/src/lib.rs | 7 +- tarpc/src/server.rs | 91 ++++++++++--------- .../src/server/limits/requests_per_channel.rs | 9 +- tarpc/src/server/testing.rs | 4 +- tarpc/tests/dataservice.rs | 2 +- tarpc/tests/service_functional.rs | 10 +- 15 files changed, 108 insertions(+), 77 deletions(-) diff --git a/README.md b/README.md index ea363d2c8..419e4e8b0 100644 --- a/README.md +++ b/README.md @@ -127,7 +127,7 @@ impl World for HelloServer { type HelloFut = Ready; - fn hello(self, _: context::Context, name: String) -> Self::HelloFut { + fn hello(self, _: &mut context::Context, name: String) -> Self::HelloFut { future::ready(format!("Hello, {name}!")) } } diff --git a/plugins/tests/service.rs b/plugins/tests/service.rs index 38bd7f0dc..008393f69 100644 --- a/plugins/tests/service.rs +++ b/plugins/tests/service.rs @@ -13,15 +13,15 @@ fn att_service_trait() { } impl Foo for () { - async fn two_part(self, _: context::Context, s: String, i: i32) -> (String, i32) { + async fn two_part(self, _: &mut context::Context, s: String, i: i32) -> (String, i32) { (s, i) } - async fn bar(self, _: context::Context, s: String) -> String { + async fn bar(self, _: &mut context::Context, s: String) -> String { s } - async fn baz(self, _: context::Context) { + async fn baz(self, _: &mut context::Context) { () } } @@ -42,18 +42,18 @@ fn raw_idents() { impl r#trait for () { async fn r#await( self, - _: context::Context, + _: &mut context::Context, r#struct: r#yield, r#enum: i32, ) -> (r#yield, i32) { (r#struct, r#enum) } - async fn r#fn(self, _: context::Context, r#impl: r#yield) -> r#yield { + async fn r#fn(self, _: &mut context::Context, r#impl: r#yield) -> r#yield { r#impl } - async fn r#async(self, _: context::Context) { + async fn r#async(self, _: &mut context::Context) { () } } diff --git a/tarpc/examples/compression.rs b/tarpc/examples/compression.rs index cc993f0af..4212e351e 100644 --- a/tarpc/examples/compression.rs +++ b/tarpc/examples/compression.rs @@ -109,7 +109,7 @@ pub trait World { struct HelloServer; impl World for HelloServer { - async fn hello(self, _: context::Context, name: String) -> String { + async fn hello(self, _: &mut context::Context, name: String) -> String { format!("Hey, {name}!") } } diff --git a/tarpc/examples/custom_transport.rs b/tarpc/examples/custom_transport.rs index 2c5fd4dc4..40a39f845 100644 --- a/tarpc/examples/custom_transport.rs +++ b/tarpc/examples/custom_transport.rs @@ -24,7 +24,7 @@ pub trait PingService { struct Service; impl PingService for Service { - async fn ping(self, _: Context) {} + async fn ping(self, _: &mut Context) {} } #[tokio::main] diff --git a/tarpc/examples/pubsub.rs b/tarpc/examples/pubsub.rs index e254b294f..ec390c01f 100644 --- a/tarpc/examples/pubsub.rs +++ b/tarpc/examples/pubsub.rs @@ -83,11 +83,11 @@ struct Subscriber { } impl subscriber::Subscriber for Subscriber { - async fn topics(self, _: context::Context) -> Vec { + async fn topics(self, _: &mut context::Context) -> Vec { self.topics.clone() } - async fn receive(self, _: context::Context, topic: String, message: String) { + async fn receive(self, _: &mut context::Context, topic: String, message: String) { info!(local_addr = %self.local_addr, %topic, %message, "ReceivedMessage") } } @@ -266,7 +266,7 @@ impl Publisher { } impl publisher::Publisher for Publisher { - async fn publish(self, _: context::Context, topic: String, message: String) { + async fn publish(self, _: &mut context::Context, topic: String, message: String) { info!("received message to publish."); let mut subscribers = match self.subscriptions.read().unwrap().get(&topic) { None => return, diff --git a/tarpc/examples/readme.rs b/tarpc/examples/readme.rs index c6ef61eb4..50db51350 100644 --- a/tarpc/examples/readme.rs +++ b/tarpc/examples/readme.rs @@ -26,7 +26,7 @@ pub trait World { struct HelloServer; impl World for HelloServer { - async fn hello(self, _: context::Context, name: String) -> String { + async fn hello(self, _: &mut context::Context, name: String) -> String { format!("Hello, {name}!") } } diff --git a/tarpc/examples/tracing.rs b/tarpc/examples/tracing.rs index c0a1f8d33..c4bd66ad0 100644 --- a/tarpc/examples/tracing.rs +++ b/tarpc/examples/tracing.rs @@ -56,7 +56,7 @@ pub mod double { struct AddServer; impl AddService for AddServer { - async fn add(self, _: context::Context, x: i32, y: i32) -> i32 { + async fn add(self, _: &mut context::Context, x: i32, y: i32) -> i32 { x + y } } @@ -71,7 +71,7 @@ where Stub: AddStub + Clone + Send + Sync + 'static, for<'a> Stub::RespFut<'a>: Send, { - async fn double(self, _: context::Context, x: i32) -> Result { + async fn double(self, _: &mut context::Context, x: i32) -> Result { self.add_client .add(context::current(), x, x) .await @@ -182,7 +182,7 @@ async fn main() -> anyhow::Result<()> { let ctx = context::current(); for _ in 1..=5 { - tracing::info!("{:?}", double_client.double(ctx, 1).await?); + tracing::info!("{:?}", double_client.double(ctx.clone(), 1).await?); } opentelemetry::global::shutdown_tracer_provider(); diff --git a/tarpc/src/client.rs b/tarpc/src/client.rs index e3ab07ca2..cc80e1beb 100644 --- a/tarpc/src/client.rs +++ b/tarpc/src/client.rs @@ -633,19 +633,22 @@ mod tests { let cx = &mut Context::from_waker(noop_waker_ref()); let (tx, mut rx) = oneshot::channel(); + let ctx = context::current(); + dispatch .in_flight_requests - .insert_request(0, context::current(), Span::current(), tx) + .insert_request(0, ctx.clone(), Span::current(), tx) .unwrap(); server_channel .send(Response { request_id: 0, + context: ctx, message: Ok("Resp".into()), }) .await .unwrap(); assert_matches!(dispatch.as_mut().poll(cx), Poll::Pending); - assert_matches!(rx.try_recv(), Ok(Ok(Response { request_id: 0, message: Ok(resp) })) if resp == "Resp"); + assert_matches!(rx.try_recv(), Ok(Ok(Response { request_id: 0, message: Ok(resp), context: ctx })) if resp == "Resp"); } #[tokio::test] @@ -669,6 +672,7 @@ mod tests { let (tx, mut response) = oneshot::channel(); tx.send(Ok(Response { request_id: 0, + context: context::current(), message: Ok("well done"), })) .unwrap(); @@ -719,6 +723,7 @@ mod tests { &mut server_channel, Response { request_id: 0, + context: context::current(), message: Ok("hello".into()), }, ) diff --git a/tarpc/src/context.rs b/tarpc/src/context.rs index f302313ad..e8ddeeabd 100644 --- a/tarpc/src/context.rs +++ b/tarpc/src/context.rs @@ -85,13 +85,13 @@ mod absolute_to_relative_time { } #[cfg(test)] - #[derive(serde::Serialize, serde::Deserialize)] - struct AbsoluteToRelative(#[serde(with = "self")] SystemTime); + #[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))] + struct AbsoluteToRelative(#[serde(with = "self")] Deadline); #[test] fn test_serialize() { let now = SystemTime::now(); - let deadline = now + Duration::from_secs(10); + let deadline = Deadline(now + Duration::from_secs(10)); let serialized_deadline = bincode::serialize(&AbsoluteToRelative(deadline)).unwrap(); let deserialized_deadline: Duration = bincode::deserialize(&serialized_deadline).unwrap(); // TODO: how to avoid flakiness? @@ -105,7 +105,7 @@ mod absolute_to_relative_time { let AbsoluteToRelative(deserialized_deadline) = bincode::deserialize(&serialized_deadline).unwrap(); // TODO: how to avoid flakiness? - assert!(deserialized_deadline > SystemTime::now() + Duration::from_secs(9)); + assert!(*deserialized_deadline > SystemTime::now() + Duration::from_secs(9)); } } @@ -138,6 +138,19 @@ impl Deref for Deadline { } } +impl DerefMut for Deadline { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.0 + } +} + +impl Deadline { + /// Creates a new deadline + pub fn new(t: SystemTime) -> Deadline { + Deadline(t) + } +} + /// Extensions associated with a request #[derive(Clone, Debug)] pub struct Extensions(anymap::Map); diff --git a/tarpc/src/lib.rs b/tarpc/src/lib.rs index 130dd4fe2..923b4d02f 100644 --- a/tarpc/src/lib.rs +++ b/tarpc/src/lib.rs @@ -131,7 +131,7 @@ //! //! impl World for HelloServer { //! // Each defined rpc generates an async fn that serves the RPC -//! async fn hello(self, _: context::Context, name: String) -> String { +//! async fn hello(self, _: &mut context::Context, name: String) -> String { //! format!("Hello, {name}!") //! } //! } @@ -167,7 +167,7 @@ //! # struct HelloServer; //! # impl World for HelloServer { //! // Each defined rpc generates an async fn that serves the RPC -//! # async fn hello(self, _: context::Context, name: String) -> String { +//! # async fn hello(self, _: &mut context::Context, name: String) -> String { //! # format!("Hello, {name}!") //! # } //! # } @@ -213,6 +213,9 @@ async_fn_in_trait, return_position_impl_trait_in_trait )] + +#![cfg_attr(feature = "serde1", feature(async_closure))] + #![deny(missing_docs)] #![allow(clippy::type_complexity)] #![cfg_attr(docsrs, feature(doc_cfg))] diff --git a/tarpc/src/server.rs b/tarpc/src/server.rs index 961401b77..630475ebb 100644 --- a/tarpc/src/server.rs +++ b/tarpc/src/server.rs @@ -112,7 +112,8 @@ pub trait Serve { /// Ok(()) /// }) /// }); - /// let response = serve.serve(context::current(), 1); + /// let mut ctx = context::current(); + /// let response = serve.serve(&mut ctx, 1); /// assert!(block_on(response).is_err()); /// ``` fn before(self, hook: Hook) -> BeforeRequestHook @@ -155,7 +156,8 @@ pub trait Serve { /// future::ready(()) /// }); /// - /// let response = serve.serve(context::current(), 1); + /// let mut ctx = context::current(); + /// let response = serve.serve(&mut ctx, 1); /// assert!(block_on(response).is_err()); /// ``` fn after(self, hook: Hook) -> AfterRequestHook @@ -210,7 +212,8 @@ pub trait Serve { /// let serve = serve(|_ctx, i| async move { /// Ok(i + 1) /// }).before_and_after(PrintLatency(Instant::now())); - /// let response = serve.serve(context::current(), 1); + /// let mut ctx = context::current(); + /// let response = serve.serve(&mut ctx, 1); /// assert!(block_on(response).is_ok()); /// ``` fn before_and_after( @@ -250,7 +253,7 @@ impl Copy for ServeFn where F: Copy {} /// Result>`. pub fn serve(f: F) -> ServeFn where - F: FnOnce(context::Context, Req) -> Fut, + F: FnOnce(&mut context::Context, Req) -> Fut, Fut: Future>, { ServeFn { @@ -1140,6 +1143,7 @@ mod tests { task::Poll, time::{Duration, Instant, SystemTime}, }; + use crate::context::Deadline; fn test_channel() -> ( Pin, Response>>>>, @@ -1203,40 +1207,40 @@ mod tests { #[tokio::test] async fn test_serve() { let serve = serve(|_, i| async move { Ok(i) }); - assert_matches!(serve.serve(context::current(), 7).await, Ok(7)); - } - - #[tokio::test] - async fn serve_before_mutates_context() -> anyhow::Result<()> { - struct SetDeadline(SystemTime); - type SetDeadlineFut<'a, Req: 'a> = impl Future> + 'a; - impl BeforeRequest for SetDeadline { - type Fut<'a> = SetDeadlineFut<'a, Req> where Self: 'a, Req: 'a; - fn before<'a>( - &'a mut self, - ctx: &'a mut context::Context, - _: &'a Req, - ) -> Self::Fut<'a> { - async move { - ctx.deadline = self.0; - Ok(()) - } - } - } - - let some_time = SystemTime::UNIX_EPOCH + Duration::from_secs(37); - let some_other_time = SystemTime::UNIX_EPOCH + Duration::from_secs(83); - - let serve = serve(move |ctx: context::Context, i| async move { - assert_eq!(ctx.deadline, some_time); - Ok(i) - }); - let deadline_hook = serve.before(SetDeadline(some_time)); - let mut ctx = context::current(); - ctx.deadline = some_other_time; - deadline_hook.serve(ctx, 7).await?; - Ok(()) - } + assert_matches!(serve.serve(&mut context::current(), 7).await, Ok(7)); + } + + // #[tokio::test] + // async fn serve_before_mutates_context() -> anyhow::Result<()> { + // struct SetDeadline(SystemTime); + // type SetDeadlineFut<'a, Req: 'a> = impl Future> + 'a; + // impl BeforeRequest for SetDeadline { + // type Fut<'a> = SetDeadlineFut<'a, Req> where Self: 'a, Req: 'a; + // fn before<'a>( + // &'a mut self, + // ctx: &'a mut context::Context, + // _: &'a Req, + // ) -> Self::Fut<'a> { + // async move { + // *ctx.deadline = self.0; + // Ok(()) + // } + // } + // } + // + // let some_time = SystemTime::UNIX_EPOCH + Duration::from_secs(37); + // let some_other_time = SystemTime::UNIX_EPOCH + Duration::from_secs(83); + // + // let serve = serve(async move |ctx: &mut context::Context, i| { + // assert_eq!(*ctx.deadline, some_time); + // Ok(i) + // }); + // let deadline_hook = serve.before(SetDeadline(some_time)); + // let mut ctx = context::current(); + // *ctx.deadline = some_other_time; + // deadline_hook.serve(&mut ctx, 7).await?; + // Ok(()) + // } #[tokio::test] async fn serve_before_and_after() -> anyhow::Result<()> { @@ -1276,10 +1280,10 @@ mod tests { } } - let serve = serve(move |_: context::Context, i| async move { Ok(i) }); + let serve = serve(move |_: &mut context::Context, i| async move { Ok(i) }); serve .before_and_after(PrintLatency::new()) - .serve(context::current(), 7) + .serve(&mut context::current(), 7) .await?; Ok(()) } @@ -1290,7 +1294,7 @@ mod tests { let deadline_hook = serve.before(|_: &mut context::Context, _: &i32| async { Err(ServerError::new(io::ErrorKind::Other, "oops".into())) }); - let resp: Result = deadline_hook.serve(context::current(), 7).await; + let resp: Result = deadline_hook.serve(&mut context::current(), 7).await; assert_matches!(resp, Err(_)); Ok(()) } @@ -1459,6 +1463,7 @@ mod tests { channel .as_mut() .start_send(Response { + context: context::current(), request_id: 0, message: Ok(()), }) @@ -1522,6 +1527,7 @@ mod tests { .as_mut() .channel_pin_mut() .start_send(Response { + context: context::current(), request_id: 0, message: Ok(()), }) @@ -1533,6 +1539,7 @@ mod tests { .project() .responses_tx .send(Response { + context: context::current(), request_id: 1, message: Ok(()), }) @@ -1573,6 +1580,7 @@ mod tests { .as_mut() .channel_pin_mut() .start_send(Response { + context: context::current(), request_id: 0, message: Ok(()), }) @@ -1593,6 +1601,7 @@ mod tests { .project() .responses_tx .send(Response { + context: context::current(), request_id: 1, message: Ok(()), }) diff --git a/tarpc/src/server/limits/requests_per_channel.rs b/tarpc/src/server/limits/requests_per_channel.rs index 6e2df0c93..a6d5e3443 100644 --- a/tarpc/src/server/limits/requests_per_channel.rs +++ b/tarpc/src/server/limits/requests_per_channel.rs @@ -189,6 +189,7 @@ mod tests { time::{Duration, SystemTime}, }; use tracing::Span; + use crate::context; #[tokio::test] async fn throttler_in_flight_requests() { @@ -336,15 +337,13 @@ mod tests { .start_send(Response { request_id: 0, message: Ok(1), + context: context::current() }) .unwrap(); assert_eq!(throttler.inner.in_flight_requests.len(), 0); assert_eq!( - throttler.inner.sink.get(0), - Some(&Response { - request_id: 0, - message: Ok(1), - }) + throttler.inner.sink.get(0).map(|resp| (resp.request_id, &resp.message)), + Some((0, &Ok(1))), ); } } diff --git a/tarpc/src/server/testing.rs b/tarpc/src/server/testing.rs index 9c0e2bb2b..82bfb71db 100644 --- a/tarpc/src/server/testing.rs +++ b/tarpc/src/server/testing.rs @@ -14,6 +14,7 @@ use futures::{task::*, Sink, Stream}; use pin_project::pin_project; use std::{collections::VecDeque, io, pin::Pin, time::SystemTime}; use tracing::Span; +use crate::context::Deadline; #[pin_project] pub(crate) struct FakeChannel { @@ -93,8 +94,9 @@ impl FakeChannel>, Response> { self.stream.push_back(Ok(TrackedRequest { request: Request { context: context::Context { - deadline: SystemTime::UNIX_EPOCH, + deadline: Deadline::new(SystemTime::UNIX_EPOCH), trace_context: Default::default(), + extensions: Default::default() }, request_id: id, message, diff --git a/tarpc/tests/dataservice.rs b/tarpc/tests/dataservice.rs index 7cd3cb8c7..2473fdd85 100644 --- a/tarpc/tests/dataservice.rs +++ b/tarpc/tests/dataservice.rs @@ -25,7 +25,7 @@ pub trait ColorProtocol { struct ColorServer; impl ColorProtocol for ColorServer { - async fn get_opposite_color(self, _: context::Context, color: TestData) -> TestData { + async fn get_opposite_color(self, _: &mut context::Context, color: TestData) -> TestData { match color { TestData::White => TestData::Black, TestData::Black => TestData::White, diff --git a/tarpc/tests/service_functional.rs b/tarpc/tests/service_functional.rs index 9041aae73..598f28945 100644 --- a/tarpc/tests/service_functional.rs +++ b/tarpc/tests/service_functional.rs @@ -25,11 +25,11 @@ trait Service { struct Server; impl Service for Server { - async fn add(self, _: context::Context, x: i32, y: i32) -> i32 { + async fn add(self, _: &mut context::Context, x: i32, y: i32) -> i32 { x + y } - async fn hey(self, _: context::Context, name: String) -> String { + async fn hey(self, _: &mut context::Context, name: String) -> String { format!("Hey, {name}.") } } @@ -64,7 +64,7 @@ async fn dropped_channel_aborts_in_flight_requests() -> anyhow::Result<()> { struct AllHandlersComplete; impl Loop for LoopServer { - async fn r#loop(self, _: context::Context) { + async fn r#loop(self, _: &mut context::Context) { loop { futures::pending!(); } @@ -81,7 +81,7 @@ async fn dropped_channel_aborts_in_flight_requests() -> anyhow::Result<()> { let client = LoopClient::new(client::Config::default(), tx).spawn(); let mut ctx = context::current(); - ctx.deadline = SystemTime::now() + Duration::from_secs(60 * 60); + *ctx.deadline = SystemTime::now() + Duration::from_secs(60 * 60); let _ = client.r#loop(ctx).await; }); @@ -254,7 +254,7 @@ async fn counter() -> anyhow::Result<()> { struct CountService(u32); impl Counter for &mut CountService { - async fn count(self, _: context::Context) -> u32 { + async fn count(self, _: &mut context::Context) -> u32 { self.0 += 1; self.0 } From aa65d58df4d747c637b4696fd39180b152ec7985 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C3=81kos=20Vandra-Meyer?= Date: Mon, 6 Nov 2023 20:59:27 +0100 Subject: [PATCH 29/30] fix compile errors --- tarpc/src/client/stub.rs | 3 ++- tarpc/src/client/stub/load_balance.rs | 4 ++-- tarpc/src/client/stub/retry.rs | 5 +++-- tarpc/src/context.rs | 1 - tarpc/src/lib.rs | 3 --- tarpc/src/server.rs | 2 +- 6 files changed, 8 insertions(+), 10 deletions(-) diff --git a/tarpc/src/client/stub.rs b/tarpc/src/client/stub.rs index a8b72a20f..894e2efb2 100644 --- a/tarpc/src/client/stub.rs +++ b/tarpc/src/client/stub.rs @@ -53,4 +53,5 @@ impl Stub for Channel { } } -type RespFut<'a, Req: 'a, Resp: 'a> = impl Future> + 'a; +/// A type alias for a response future +pub type RespFut<'a, Req: 'a, Resp: 'a> = impl Future> + 'a; diff --git a/tarpc/src/client/stub/load_balance.rs b/tarpc/src/client/stub/load_balance.rs index c9005a423..2fa67ec82 100644 --- a/tarpc/src/client/stub/load_balance.rs +++ b/tarpc/src/client/stub/load_balance.rs @@ -31,7 +31,7 @@ mod round_robin { } } - type RespFut<'a, Stub: stub::Stub + 'a> = + pub type RespFut<'a, Stub: stub::Stub + 'a> = impl Future> + 'a; /// A Stub that load-balances across backing stubs by round robin. @@ -145,7 +145,7 @@ mod consistent_hash { } } - type RespFut<'a, Stub: stub::Stub + 'a> = + pub type RespFut<'a, Stub: stub::Stub + 'a> = impl Future> + 'a; /// A Stub that load-balances across backing stubs by round robin. diff --git a/tarpc/src/client/stub/retry.rs b/tarpc/src/client/stub/retry.rs index 23cc41a69..138a19227 100644 --- a/tarpc/src/client/stub/retry.rs +++ b/tarpc/src/client/stub/retry.rs @@ -14,7 +14,7 @@ where { type Req = Req; type Resp = Stub::Resp; - type RespFut<'a> = RespFut<'a, Stub, Self::Req, F> + type RespFut<'a> = RespFut<'a, Stub, Req, F> where Self: 'a, Self::Req: 'a; @@ -28,7 +28,8 @@ where } } -type RespFut<'a, Stub: stub::Stub + 'a, Req: 'a, F: 'a> = +/// A type alias for a response future +pub type RespFut<'a, Stub: stub::Stub + 'a, Req: 'a, F: 'a> = impl Future> + 'a; /// A Stub that retries requests based on response contents. diff --git a/tarpc/src/context.rs b/tarpc/src/context.rs index e8ddeeabd..2d90b33c4 100644 --- a/tarpc/src/context.rs +++ b/tarpc/src/context.rs @@ -16,7 +16,6 @@ use std::{ }; use std::hash::{Hash, Hasher}; use std::ops::{Deref, DerefMut}; -use std::sync::{Arc, Mutex}; use anymap::any::CloneAny; use tracing_opentelemetry::OpenTelemetrySpanExt; diff --git a/tarpc/src/lib.rs b/tarpc/src/lib.rs index 923b4d02f..af344ee47 100644 --- a/tarpc/src/lib.rs +++ b/tarpc/src/lib.rs @@ -210,10 +210,7 @@ #![feature( iter_intersperse, type_alias_impl_trait, - async_fn_in_trait, - return_position_impl_trait_in_trait )] - #![cfg_attr(feature = "serde1", feature(async_closure))] #![deny(missing_docs)] diff --git a/tarpc/src/server.rs b/tarpc/src/server.rs index 630475ebb..18b70f583 100644 --- a/tarpc/src/server.rs +++ b/tarpc/src/server.rs @@ -75,7 +75,7 @@ pub trait Serve { type Resp; /// Responds to a single request. - async fn serve(self, ctx: &mut context::Context, req: Self::Req) -> Result; + fn serve(self, ctx: &mut context::Context, req: Self::Req) -> impl Future>; /// Extracts a method name from the request. fn method(&self, _request: &Self::Req) -> Option<&'static str> { From a3874ddee62e0731ca79c34c4c2f5b6d72b5d835 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C3=81kos=20Vandra-Meyer?= Date: Mon, 6 Nov 2023 21:21:14 +0100 Subject: [PATCH 30/30] fix tests --- example-service/src/lib.rs | 3 - example-service/src/server.rs | 4 -- plugins/tests/server.rs | 3 - plugins/tests/service.rs | 3 - tarpc/examples/compression.rs | 3 - tarpc/examples/custom_transport.rs | 3 - tarpc/examples/pubsub.rs | 32 --------- tarpc/examples/readme.rs | 4 -- tarpc/examples/tracing.rs | 4 -- tarpc/src/client.rs | 2 +- tarpc/src/lib.rs | 9 --- tarpc/src/server.rs | 68 ++++++++++--------- .../compile_fail/must_use_request_dispatch.rs | 3 - .../must_use_request_dispatch.stderr | 12 ++-- .../must_use_tcp_connect.stderr | 4 ++ tarpc/tests/dataservice.rs | 3 - tarpc/tests/service_functional.rs | 3 - 17 files changed, 49 insertions(+), 114 deletions(-) diff --git a/example-service/src/lib.rs b/example-service/src/lib.rs index 822d8217b..bc38fe93e 100644 --- a/example-service/src/lib.rs +++ b/example-service/src/lib.rs @@ -4,9 +4,6 @@ // license that can be found in the LICENSE file or at // https://opensource.org/licenses/MIT. -#![allow(incomplete_features)] -#![feature(async_fn_in_trait)] - use std::env; use tracing_subscriber::{fmt::format::FmtSpan, prelude::*}; diff --git a/example-service/src/server.rs b/example-service/src/server.rs index d3f1a2600..3cd4b43fe 100644 --- a/example-service/src/server.rs +++ b/example-service/src/server.rs @@ -3,10 +3,6 @@ // Use of this source code is governed by an MIT-style // license that can be found in the LICENSE file or at // https://opensource.org/licenses/MIT. - -#![allow(incomplete_features)] -#![feature(async_fn_in_trait)] - use clap::Parser; use futures::{future, prelude::*}; use rand::{ diff --git a/plugins/tests/server.rs b/plugins/tests/server.rs index 7fcec793e..9d4129693 100644 --- a/plugins/tests/server.rs +++ b/plugins/tests/server.rs @@ -1,6 +1,3 @@ -#![allow(incomplete_features)] -#![feature(async_fn_in_trait)] - // these need to be out here rather than inside the function so that the // assert_type_eq macro can pick them up. #[tarpc::service] diff --git a/plugins/tests/service.rs b/plugins/tests/service.rs index 008393f69..afb62ce83 100644 --- a/plugins/tests/service.rs +++ b/plugins/tests/service.rs @@ -1,6 +1,3 @@ -#![allow(incomplete_features)] -#![feature(async_fn_in_trait)] - use tarpc::context; #[test] diff --git a/tarpc/examples/compression.rs b/tarpc/examples/compression.rs index 4212e351e..400849782 100644 --- a/tarpc/examples/compression.rs +++ b/tarpc/examples/compression.rs @@ -4,9 +4,6 @@ // license that can be found in the LICENSE file or at // https://opensource.org/licenses/MIT. -#![allow(incomplete_features)] -#![feature(async_fn_in_trait)] - use flate2::{read::DeflateDecoder, write::DeflateEncoder, Compression}; use futures::{prelude::*, Sink, SinkExt, Stream, StreamExt, TryStreamExt}; use serde::{Deserialize, Serialize}; diff --git a/tarpc/examples/custom_transport.rs b/tarpc/examples/custom_transport.rs index 40a39f845..076866116 100644 --- a/tarpc/examples/custom_transport.rs +++ b/tarpc/examples/custom_transport.rs @@ -4,9 +4,6 @@ // license that can be found in the LICENSE file or at // https://opensource.org/licenses/MIT. -#![allow(incomplete_features)] -#![feature(async_fn_in_trait)] - use futures::prelude::*; use tarpc::context::Context; use tarpc::serde_transport as transport; diff --git a/tarpc/examples/pubsub.rs b/tarpc/examples/pubsub.rs index ec390c01f..8e80a9126 100644 --- a/tarpc/examples/pubsub.rs +++ b/tarpc/examples/pubsub.rs @@ -3,38 +3,6 @@ // Use of this source code is governed by an MIT-style // license that can be found in the LICENSE file or at // https://opensource.org/licenses/MIT. - -#![allow(incomplete_features)] -#![feature(async_fn_in_trait)] - -/// - The PubSub server sets up TCP listeners on 2 ports, the "subscriber" port and the "publisher" -/// port. Because both publishers and subscribers initiate their connections to the PubSub -/// server, the server requires no prior knowledge of either publishers or subscribers. -/// -/// - Subscribers connect to the server on the server's "subscriber" port. Once a connection is -/// established, the server acts as the client of the Subscriber service, initially requesting -/// the topics the subscriber is interested in, and subsequently sending topical messages to the -/// subscriber. -/// -/// - Publishers connect to the server on the "publisher" port and, once connected, they send -/// topical messages via Publisher service to the server. The server then broadcasts each -/// messages to all clients subscribed to the topic of that message. -/// -/// Subscriber Publisher PubSub Server -/// T1 | | | -/// T2 |-----Connect------------------------------------------------------>| -/// T3 | | | -/// T2 |<-------------------------------------------------------Topics-----| -/// T2 |-----(OK) Topics-------------------------------------------------->| -/// T3 | | | -/// T4 | |-----Connect-------------------->| -/// T5 | | | -/// T6 | |-----Publish-------------------->| -/// T7 | | | -/// T8 |<------------------------------------------------------Receive-----| -/// T9 |-----(OK) Receive------------------------------------------------->| -/// T10 | | | -/// T11 | |<--------------(OK) Publish------| use anyhow::anyhow; use futures::{ channel::oneshot, diff --git a/tarpc/examples/readme.rs b/tarpc/examples/readme.rs index 50db51350..dcae39438 100644 --- a/tarpc/examples/readme.rs +++ b/tarpc/examples/readme.rs @@ -3,10 +3,6 @@ // Use of this source code is governed by an MIT-style // license that can be found in the LICENSE file or at // https://opensource.org/licenses/MIT. - -#![allow(incomplete_features)] -#![feature(async_fn_in_trait)] - use futures::prelude::*; use tarpc::{ client, context, diff --git a/tarpc/examples/tracing.rs b/tarpc/examples/tracing.rs index c4bd66ad0..78846e19d 100644 --- a/tarpc/examples/tracing.rs +++ b/tarpc/examples/tracing.rs @@ -3,10 +3,6 @@ // Use of this source code is governed by an MIT-style // license that can be found in the LICENSE file or at // https://opensource.org/licenses/MIT. - -#![allow(incomplete_features)] -#![feature(async_fn_in_trait, type_alias_impl_trait)] - use crate::{ add::{Add as AddService, AddStub}, double::Double as DoubleService, diff --git a/tarpc/src/client.rs b/tarpc/src/client.rs index cc80e1beb..dba5833ad 100644 --- a/tarpc/src/client.rs +++ b/tarpc/src/client.rs @@ -648,7 +648,7 @@ mod tests { .await .unwrap(); assert_matches!(dispatch.as_mut().poll(cx), Poll::Pending); - assert_matches!(rx.try_recv(), Ok(Ok(Response { request_id: 0, message: Ok(resp), context: ctx })) if resp == "Resp"); + assert_matches!(rx.try_recv(), Ok(Ok(Response { request_id: 0, message: Ok(resp), context: _ctx })) if resp == "Resp"); } #[tokio::test] diff --git a/tarpc/src/lib.rs b/tarpc/src/lib.rs index af344ee47..91ac63c9a 100644 --- a/tarpc/src/lib.rs +++ b/tarpc/src/lib.rs @@ -80,8 +80,6 @@ //! First, let's set up the dependencies and service definition. //! //! ```rust -//! #![allow(incomplete_features)] -//! #![feature(async_fn_in_trait)] //! # extern crate futures; //! //! use futures::{ @@ -106,8 +104,6 @@ //! implement it for our Server struct. //! //! ```rust -//! # #![allow(incomplete_features)] -//! # #![feature(async_fn_in_trait)] //! # extern crate futures; //! # use futures::{ //! # future::{self, Ready}, @@ -143,8 +139,6 @@ //! available behind the `tcp` feature. //! //! ```rust -//! # #![allow(incomplete_features)] -//! # #![feature(async_fn_in_trait)] //! # extern crate futures; //! # use futures::{ //! # future::{self, Ready}, @@ -205,8 +199,6 @@ //! //! Use `cargo doc` as you normally would to see the documentation created for all //! items expanded by a `service!` invocation. -// For async_fn_in_trait -#![allow(incomplete_features)] #![feature( iter_intersperse, type_alias_impl_trait, @@ -238,7 +230,6 @@ pub use tarpc_plugins::derive_serde; /// Rpc methods are specified, mirroring trait syntax: /// /// ``` -/// #![feature(async_fn_in_trait)] /// #[tarpc::service] /// trait Service { /// /// Say hello diff --git a/tarpc/src/server.rs b/tarpc/src/server.rs index 18b70f583..1691a574a 100644 --- a/tarpc/src/server.rs +++ b/tarpc/src/server.rs @@ -1143,7 +1143,7 @@ mod tests { task::Poll, time::{Duration, Instant, SystemTime}, }; - use crate::context::Deadline; + use std::ops::Deref; fn test_channel() -> ( Pin, Response>>>>, @@ -1210,37 +1210,41 @@ mod tests { assert_matches!(serve.serve(&mut context::current(), 7).await, Ok(7)); } - // #[tokio::test] - // async fn serve_before_mutates_context() -> anyhow::Result<()> { - // struct SetDeadline(SystemTime); - // type SetDeadlineFut<'a, Req: 'a> = impl Future> + 'a; - // impl BeforeRequest for SetDeadline { - // type Fut<'a> = SetDeadlineFut<'a, Req> where Self: 'a, Req: 'a; - // fn before<'a>( - // &'a mut self, - // ctx: &'a mut context::Context, - // _: &'a Req, - // ) -> Self::Fut<'a> { - // async move { - // *ctx.deadline = self.0; - // Ok(()) - // } - // } - // } - // - // let some_time = SystemTime::UNIX_EPOCH + Duration::from_secs(37); - // let some_other_time = SystemTime::UNIX_EPOCH + Duration::from_secs(83); - // - // let serve = serve(async move |ctx: &mut context::Context, i| { - // assert_eq!(*ctx.deadline, some_time); - // Ok(i) - // }); - // let deadline_hook = serve.before(SetDeadline(some_time)); - // let mut ctx = context::current(); - // *ctx.deadline = some_other_time; - // deadline_hook.serve(&mut ctx, 7).await?; - // Ok(()) - // } + #[tokio::test] + async fn serve_before_mutates_context() -> anyhow::Result<()> { + struct SetDeadline(SystemTime); + type SetDeadlineFut<'a, Req: 'a> = impl Future> + 'a; + impl BeforeRequest for SetDeadline { + type Fut<'a> = SetDeadlineFut<'a, Req> where Self: 'a, Req: 'a; + fn before<'a>( + &'a mut self, + ctx: &'a mut context::Context, + _: &'a Req, + ) -> Self::Fut<'a> { + async move { + *ctx.deadline = self.0; + Ok(()) + } + } + } + + let some_time = SystemTime::UNIX_EPOCH + Duration::from_secs(37); + let some_other_time = SystemTime::UNIX_EPOCH + Duration::from_secs(83); + + let serve = serve(|ctx: &mut context::Context, i| { + let deadline = ctx.deadline.deref().clone(); + + async move { + assert_eq!(deadline, some_time); + Ok(i) + } + }); + let deadline_hook = serve.before(SetDeadline(some_time)); + let mut ctx = context::current(); + *ctx.deadline = some_other_time; + deadline_hook.serve(&mut ctx, 7).await?; + Ok(()) + } #[tokio::test] async fn serve_before_and_after() -> anyhow::Result<()> { diff --git a/tarpc/tests/compile_fail/must_use_request_dispatch.rs b/tarpc/tests/compile_fail/must_use_request_dispatch.rs index 18cda0d90..2915d3237 100644 --- a/tarpc/tests/compile_fail/must_use_request_dispatch.rs +++ b/tarpc/tests/compile_fail/must_use_request_dispatch.rs @@ -1,6 +1,3 @@ -#![allow(incomplete_features)] -#![feature(async_fn_in_trait)] - use tarpc::client; #[tarpc::service] diff --git a/tarpc/tests/compile_fail/must_use_request_dispatch.stderr b/tarpc/tests/compile_fail/must_use_request_dispatch.stderr index d12912a86..e652cc8e8 100644 --- a/tarpc/tests/compile_fail/must_use_request_dispatch.stderr +++ b/tarpc/tests/compile_fail/must_use_request_dispatch.stderr @@ -1,11 +1,15 @@ error: unused `RequestDispatch` that must be used - --> tests/compile_fail/must_use_request_dispatch.rs:16:9 + --> tests/compile_fail/must_use_request_dispatch.rs:13:9 | -16 | WorldClient::new(client::Config::default(), client_transport).dispatch; +13 | WorldClient::new(client::Config::default(), client_transport).dispatch; | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ | note: the lint level is defined here - --> tests/compile_fail/must_use_request_dispatch.rs:14:12 + --> tests/compile_fail/must_use_request_dispatch.rs:11:12 | -14 | #[deny(unused_must_use)] +11 | #[deny(unused_must_use)] | ^^^^^^^^^^^^^^^ +help: use `let _ = ...` to ignore the resulting value + | +13 | let _ = WorldClient::new(client::Config::default(), client_transport).dispatch; + | +++++++ diff --git a/tarpc/tests/compile_fail/serde_transport/must_use_tcp_connect.stderr b/tarpc/tests/compile_fail/serde_transport/must_use_tcp_connect.stderr index d3f4eb62a..b6e9bdeff 100644 --- a/tarpc/tests/compile_fail/serde_transport/must_use_tcp_connect.stderr +++ b/tarpc/tests/compile_fail/serde_transport/must_use_tcp_connect.stderr @@ -9,3 +9,7 @@ note: the lint level is defined here | 5 | #[deny(unused_must_use)] | ^^^^^^^^^^^^^^^ +help: use `let _ = ...` to ignore the resulting value + | +7 | let _ = serde_transport::tcp::connect::<_, (), (), _, _>("0.0.0.0:0", Json::default); + | +++++++ diff --git a/tarpc/tests/dataservice.rs b/tarpc/tests/dataservice.rs index 2473fdd85..78e28cd77 100644 --- a/tarpc/tests/dataservice.rs +++ b/tarpc/tests/dataservice.rs @@ -1,6 +1,3 @@ -#![allow(incomplete_features)] -#![feature(async_fn_in_trait)] - use futures::prelude::*; use tarpc::serde_transport; use tarpc::{ diff --git a/tarpc/tests/service_functional.rs b/tarpc/tests/service_functional.rs index 598f28945..58066b411 100644 --- a/tarpc/tests/service_functional.rs +++ b/tarpc/tests/service_functional.rs @@ -1,6 +1,3 @@ -#![allow(incomplete_features)] -#![feature(async_fn_in_trait)] - use assert_matches::assert_matches; use futures::{ future::{join_all, ready},