diff --git a/plugins/src/lib.rs b/plugins/src/lib.rs index 583d43d0e..f54b73e2c 100644 --- a/plugins/src/lib.rs +++ b/plugins/src/lib.rs @@ -736,7 +736,12 @@ impl ServiceGenerator<'_> { #vis fn new(config: ::tarpc::client::Config, transport: T) -> ::tarpc::client::NewClient< Self, - ::tarpc::client::RequestDispatch<#request_ident, #response_ident, T> + ::tarpc::client::RequestDispatch< + #request_ident, + #response_ident, + T, + ::tarpc::util::delay_queue::DelayQueue + > > where T: ::tarpc::Transport<::tarpc::ClientMessage<#request_ident>, ::tarpc::Response<#response_ident>> diff --git a/tarpc/src/client.rs b/tarpc/src/client.rs index 14b5035c0..b84085aa8 100644 --- a/tarpc/src/client.rs +++ b/tarpc/src/client.rs @@ -12,7 +12,10 @@ pub mod stub; use crate::{ cancellations::{cancellations, CanceledRequests, RequestCancellation}, context, trace, - util::TimeUntil, + util::{ + delay_queue::{DelayQueue, DelayQueueLike}, + TimeUntil, + }, ChannelError, ClientMessage, Request, RequestName, Response, ServerError, Transport, }; use futures::{prelude::*, ready, stream::Fuse, task::*}; @@ -237,9 +240,27 @@ impl Drop for ResponseGuard<'_, Resp> { pub fn new( config: Config, transport: C, -) -> NewClient, RequestDispatch> +) -> NewClient, RequestDispatch>> where C: Transport, Response>, +{ + with_in_flight_requests( + config, + transport, + InFlightRequests::<_, DelayQueue>::default(), + ) +} + +/// Returns a channel and dispatcher that manages the lifecycle of requests initiated by the +/// channel. +pub fn with_in_flight_requests( + config: Config, + transport: C, + in_flight_requests: InFlightRequests, Deadline>, +) -> NewClient, RequestDispatch> +where + C: Transport, Response>, + Deadline: DelayQueueLike, { let (to_dispatch, pending_requests) = mpsc::channel(config.pending_request_buffer); let (cancellation, canceled_requests) = cancellations(); @@ -254,7 +275,7 @@ where config, canceled_requests, transport: transport.fuse(), - in_flight_requests: InFlightRequests::default(), + in_flight_requests, pending_requests, terminal_error: None, }, @@ -266,7 +287,10 @@ where #[must_use] #[pin_project()] #[derive(Debug)] -pub struct RequestDispatch { +pub struct RequestDispatch +where + Deadline: DelayQueueLike, +{ /// Writes requests to the wire and reads responses off the wire. #[pin] transport: Fuse, @@ -275,7 +299,7 @@ pub struct RequestDispatch { /// Requests that were dropped. canceled_requests: CanceledRequests, /// Requests already written to the wire that haven't yet received responses. - in_flight_requests: InFlightRequests>, + in_flight_requests: InFlightRequests, Deadline>, /// Configures limits to prevent unlimited resource usage. config: Config, /// Produces errors that can be sent in response to any unprocessed requests at the time @@ -285,13 +309,14 @@ pub struct RequestDispatch { terminal_error: Option>, } -impl RequestDispatch +impl RequestDispatch where C: Transport, Response>, + Deadline: DelayQueueLike, { fn in_flight_requests<'a>( self: &'a mut Pin<&mut Self>, - ) -> &'a mut InFlightRequests> { + ) -> &'a mut InFlightRequests, Deadline> { self.as_mut().project().in_flight_requests } @@ -636,9 +661,10 @@ where } } -impl Future for RequestDispatch +impl Future for RequestDispatch where C: Transport, Response>, + Deadline: DelayQueueLike, { type Output = Result<(), ChannelError>; @@ -685,6 +711,7 @@ mod tests { client::{in_flight_requests::InFlightRequests, Config}, context::{self, current}, transport::{self, channel::UnboundedChannel}, + util::delay_queue::DelayQueue, ChannelError, ClientMessage, Response, }; use assert_matches::assert_matches; @@ -960,14 +987,14 @@ mod tests { fn set_up_always_err( cause: TransportError, ) -> ( - Pin>>>, + Pin, DelayQueue>>>, Channel, Context<'static>, ) { let (to_dispatch, pending_requests) = mpsc::channel(1); let (cancellation, canceled_requests) = cancellations(); let transport: AlwaysErrorTransport = AlwaysErrorTransport(cause, PhantomData); - let dispatch = Box::pin(RequestDispatch:: { + let dispatch = Box::pin(RequestDispatch:: { transport: transport.fuse(), pending_requests, canceled_requests, @@ -1051,6 +1078,7 @@ mod tests { String, String, UnboundedChannel, ClientMessage>, + DelayQueue, >, >, >, @@ -1063,7 +1091,7 @@ mod tests { let (cancellation, canceled_requests) = cancellations(); let (client_channel, server_channel) = transport::channel::unbounded(); - let dispatch = RequestDispatch:: { + let dispatch = RequestDispatch:: { transport: client_channel.fuse(), pending_requests, canceled_requests, diff --git a/tarpc/src/client/in_flight_requests.rs b/tarpc/src/client/in_flight_requests.rs index 1776a74a0..43658438a 100644 --- a/tarpc/src/client/in_flight_requests.rs +++ b/tarpc/src/client/in_flight_requests.rs @@ -1,24 +1,30 @@ use crate::{ context, - util::{Compact, TimeUntil}, + util::{delay_queue::DelayQueueLike, Compact, TimeUntil}, }; use fnv::FnvHashMap; use std::{ collections::hash_map, + fmt::Debug, task::{Context, Poll}, }; use tokio::sync::oneshot; -use tokio_util::time::delay_queue::{self, DelayQueue}; use tracing::Span; /// Requests already written to the wire that haven't yet received responses. #[derive(Debug)] -pub struct InFlightRequests { - request_data: FnvHashMap>, - deadlines: DelayQueue, +pub struct InFlightRequests +where + Deadline: DelayQueueLike, +{ + request_data: FnvHashMap>, + deadlines: Deadline, } -impl Default for InFlightRequests { +impl Default for InFlightRequests +where + Deadline: DelayQueueLike + Default, +{ fn default() -> Self { Self { request_data: Default::default(), @@ -28,12 +34,12 @@ impl Default for InFlightRequests { } #[derive(Debug)] -struct RequestData { +struct RequestData { ctx: context::Context, span: Span, response_completion: oneshot::Sender, /// The key to remove the timer for the request's deadline. - deadline_key: delay_queue::Key, + deadline_key: Key, } /// An error returned when an attempt is made to insert a request with an ID that is already in @@ -41,7 +47,10 @@ struct RequestData { #[derive(Debug)] pub struct AlreadyExistsError; -impl InFlightRequests { +impl InFlightRequests +where + Deadline: DelayQueueLike, +{ /// Returns the number of in-flight requests. pub fn len(&self) -> usize { self.request_data.len() @@ -124,7 +133,7 @@ impl InFlightRequests { expired_error: impl Fn() -> Res, ) -> Poll> { self.deadlines.poll_expired(cx).map(|expired| { - let request_id = expired?.into_inner(); + let request_id = expired?; if let Some(request_data) = self.request_data.remove(&request_id) { let _entered = request_data.span.enter(); tracing::error!("DeadlineExceeded"); diff --git a/tarpc/src/lib.rs b/tarpc/src/lib.rs index f98e5791b..f3ff6b466 100644 --- a/tarpc/src/lib.rs +++ b/tarpc/src/lib.rs @@ -246,7 +246,9 @@ pub mod client; pub mod context; pub mod server; pub mod transport; -pub(crate) mod util; + +/// Utilities +pub mod util; pub use crate::transport::sealed::Transport; diff --git a/tarpc/src/server.rs b/tarpc/src/server.rs index d79d45c2c..1208eacca 100644 --- a/tarpc/src/server.rs +++ b/tarpc/src/server.rs @@ -6,6 +6,7 @@ //! Provides a server that concurrently handles many connections sending multiplexed requests. +use crate::util::delay_queue::{DelayQueue, DelayQueueLike}; use crate::{ cancellations::{cancellations, CanceledRequests, RequestCancellation}, context::{self, SpanExt}, @@ -58,11 +59,15 @@ impl Default for Config { impl Config { /// Returns a channel backed by `transport` and configured with `self`. - pub fn channel(self, transport: T) -> BaseChannel + pub fn channel( + self, + transport: T, + ) -> BaseChannelImpl where T: Transport, ClientMessage>, + Deadline: DelayQueueLike + Default, { - BaseChannel::new(self, transport) + BaseChannelImpl::new(self, transport) } } @@ -138,7 +143,10 @@ where /// messages. Instead, it internally handles them by cancelling corresponding requests (removing /// the corresponding in-flight requests and aborting their handlers). #[pin_project] -pub struct BaseChannel { +pub struct BaseChannelImpl +where + Deadline: DelayQueueLike, +{ config: Config, /// Writes responses to the wire and reads requests off the wire. #[pin] @@ -149,19 +157,23 @@ pub struct BaseChannel { /// Notifies `canceled_requests` when a request is canceled. request_cancellation: RequestCancellation, /// Holds data necessary to clean up in-flight requests. - in_flight_requests: InFlightRequests, + in_flight_requests: InFlightRequests, /// Types the request and response. ghost: PhantomData<(fn() -> Req, fn(Resp))>, } -impl BaseChannel +/// +pub type BaseChannel = BaseChannelImpl>; + +impl BaseChannelImpl where T: Transport, ClientMessage>, + Deadline: DelayQueueLike + Default, { /// Creates a new channel backed by `transport` and configured with `config`. pub fn new(config: Config, transport: T) -> Self { let (request_cancellation, canceled_requests) = cancellations(); - BaseChannel { + BaseChannelImpl { config, transport: transport.fuse(), canceled_requests, @@ -175,7 +187,13 @@ where pub fn with_defaults(transport: T) -> Self { Self::new(Config::default(), transport) } +} +impl BaseChannelImpl +where + T: Transport, ClientMessage>, + Deadline: DelayQueueLike, +{ /// Returns the inner transport over which messages are sent and received. pub fn get_ref(&self) -> &T { self.transport.get_ref() @@ -186,7 +204,9 @@ where self.project().transport.get_pin_mut() } - fn in_flight_requests_mut<'a>(self: &'a mut Pin<&mut Self>) -> &'a mut InFlightRequests { + fn in_flight_requests_mut<'a>( + self: &'a mut Pin<&mut Self>, + ) -> &'a mut InFlightRequests { self.as_mut().project().in_flight_requests } @@ -248,7 +268,10 @@ where } } -impl fmt::Debug for BaseChannel { +impl fmt::Debug for BaseChannelImpl +where + Deadline: DelayQueueLike, +{ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { write!(f, "BaseChannel") } @@ -418,9 +441,10 @@ where } } -impl Stream for BaseChannel +impl Stream for BaseChannelImpl where T: Transport, ClientMessage>, + Deadline: DelayQueueLike, { type Item = Result, ChannelError>; @@ -525,10 +549,11 @@ where } } -impl Sink> for BaseChannel +impl Sink> for BaseChannelImpl where T: Transport, ClientMessage>, T::Error: Error, + Deadline: DelayQueueLike, { type Error = ChannelError; @@ -572,15 +597,19 @@ where } } -impl AsRef for BaseChannel { +impl AsRef for BaseChannelImpl +where + Deadline: DelayQueueLike, +{ fn as_ref(&self) -> &T { self.transport.get_ref() } } -impl Channel for BaseChannel +impl Channel for BaseChannelImpl where T: Transport, ClientMessage>, + Deadline: DelayQueueLike, { type Req = Req; type Resp = Resp; diff --git a/tarpc/src/server/in_flight_requests.rs b/tarpc/src/server/in_flight_requests.rs index ec0c2ee9c..ce45c451e 100644 --- a/tarpc/src/server/in_flight_requests.rs +++ b/tarpc/src/server/in_flight_requests.rs @@ -1,4 +1,4 @@ -use crate::util::{Compact, TimeUntil}; +use crate::util::{delay_queue::DelayQueueLike, Compact, TimeUntil}; use fnv::FnvHashMap; use futures::future::{AbortHandle, AbortRegistration}; use std::{ @@ -6,24 +6,38 @@ use std::{ task::{Context, Poll}, time::Instant, }; -use tokio_util::time::delay_queue::{self, DelayQueue}; use tracing::Span; /// A data structure that tracks in-flight requests. It aborts requests, /// either on demand or when a request deadline expires. -#[derive(Debug, Default)] -pub struct InFlightRequests { - request_data: FnvHashMap, - deadlines: DelayQueue, +#[derive(Debug)] +pub struct InFlightRequests +where + Deadline: DelayQueueLike, +{ + request_data: FnvHashMap>, + deadlines: Deadline, +} + +impl Default for InFlightRequests +where + Deadline: DelayQueueLike + Default, +{ + fn default() -> Self { + Self { + request_data: Default::default(), + deadlines: Default::default(), + } + } } /// Data needed to clean up a single in-flight request. #[derive(Debug)] -struct RequestData { +struct RequestData { /// Aborts the response handler for the associated request. abort_handle: AbortHandle, /// The key to remove the timer for the request's deadline. - deadline_key: delay_queue::Key, + deadline_key: Key, /// The client span. span: Span, } @@ -33,7 +47,10 @@ struct RequestData { #[derive(Debug)] pub struct AlreadyExistsError; -impl InFlightRequests { +impl InFlightRequests +where + Deadline: DelayQueueLike, +{ /// Returns the number of in-flight requests. pub fn len(&self) -> usize { self.request_data.len() @@ -104,20 +121,23 @@ impl InFlightRequests { let expired = expired?; if let Some(RequestData { abort_handle, span, .. - }) = self.request_data.remove(expired.get_ref()) + }) = self.request_data.remove(&expired) { let _entered = span.enter(); self.request_data.compact(0.1); abort_handle.abort(); tracing::error!("DeadlineExceeded"); } - Some(expired.into_inner()) + Some(expired) }) } } /// When InFlightRequests is dropped, any outstanding requests are aborted. -impl Drop for InFlightRequests { +impl Drop for InFlightRequests +where + Deadline: DelayQueueLike, +{ fn drop(&mut self) { self.request_data .values() @@ -129,6 +149,7 @@ impl Drop for InFlightRequests { mod tests { use super::*; + use crate::util::delay_queue::DelayQueue; use assert_matches::assert_matches; use futures::{ future::{pending, Abortable}, @@ -138,7 +159,7 @@ mod tests { #[tokio::test] async fn start_request_increases_len() { - let mut in_flight_requests = InFlightRequests::default(); + let mut in_flight_requests = InFlightRequests::>::default(); assert_eq!(in_flight_requests.len(), 0); in_flight_requests .start_request(0, Instant::now(), Span::current()) @@ -148,7 +169,7 @@ mod tests { #[tokio::test] async fn polling_expired_aborts() { - let mut in_flight_requests = InFlightRequests::default(); + let mut in_flight_requests = InFlightRequests::>::default(); let abort_registration = in_flight_requests .start_request(0, Instant::now(), Span::current()) .unwrap(); @@ -170,7 +191,7 @@ mod tests { #[tokio::test] async fn cancel_request_aborts() { - let mut in_flight_requests = InFlightRequests::default(); + let mut in_flight_requests = InFlightRequests::>::default(); let abort_registration = in_flight_requests .start_request(0, Instant::now(), Span::current()) .unwrap(); @@ -186,7 +207,7 @@ mod tests { #[tokio::test] async fn remove_request_doesnt_abort() { - let mut in_flight_requests = InFlightRequests::default(); + let mut in_flight_requests = InFlightRequests::>::default(); assert!(in_flight_requests.deadlines.is_empty()); let abort_registration = in_flight_requests diff --git a/tarpc/src/server/testing.rs b/tarpc/src/server/testing.rs index 2ec96da3b..04dec8716 100644 --- a/tarpc/src/server/testing.rs +++ b/tarpc/src/server/testing.rs @@ -8,6 +8,7 @@ use crate::{ cancellations::{cancellations, CanceledRequests, RequestCancellation}, context, server::{Channel, Config, ResponseGuard, TrackedRequest}, + util::delay_queue::{DelayQueue, DelayQueueLike}, Request, Response, }; use futures::{task::*, Sink, Stream}; @@ -16,20 +17,24 @@ use std::{collections::VecDeque, io, pin::Pin, time::Instant}; use tracing::Span; #[pin_project] -pub(crate) struct FakeChannel { +pub(crate) struct FakeChannel +where + Deadline: DelayQueueLike, +{ #[pin] pub stream: VecDeque, #[pin] pub sink: VecDeque, pub config: Config, - pub in_flight_requests: super::in_flight_requests::InFlightRequests, + pub in_flight_requests: super::in_flight_requests::InFlightRequests, pub request_cancellation: RequestCancellation, pub canceled_requests: CanceledRequests, } -impl Stream for FakeChannel +impl Stream for FakeChannel where In: Unpin, + Deadline: DelayQueueLike, { type Item = In; @@ -38,7 +43,10 @@ where } } -impl Sink> for FakeChannel> { +impl Sink> for FakeChannel, Deadline> +where + Deadline: DelayQueueLike, +{ type Error = io::Error; fn poll_ready(self: Pin<&mut Self>, cx: &mut Context) -> Poll> { @@ -65,9 +73,11 @@ impl Sink> for FakeChannel> { } } -impl Channel for FakeChannel>, Response> +impl Channel + for FakeChannel>, Response, Deadline> where Req: Unpin, + Deadline: DelayQueueLike, { type Req = Req; type Resp = Resp; @@ -86,7 +96,10 @@ where } } -impl FakeChannel>, Response> { +impl FakeChannel>, Response, Deadline> +where + Deadline: DelayQueueLike, +{ pub fn push_req(&mut self, id: u64, message: Req) { let (_, abort_registration) = futures::future::AbortHandle::new_pair(); let (request_cancellation, _) = cancellations(); @@ -110,8 +123,9 @@ impl FakeChannel>, Response> { } } -impl FakeChannel<(), ()> { - pub fn default() -> FakeChannel>, Response> { +impl FakeChannel<(), (), DelayQueue> { + pub fn default( + ) -> FakeChannel>, Response, DelayQueue> { let (request_cancellation, canceled_requests) = cancellations(); FakeChannel { stream: Default::default(), diff --git a/tarpc/src/util.rs b/tarpc/src/util.rs index 101813c0e..8dede3e7b 100644 --- a/tarpc/src/util.rs +++ b/tarpc/src/util.rs @@ -12,7 +12,7 @@ use std::{ #[cfg(feature = "serde1")] #[cfg_attr(docsrs, doc(cfg(feature = "serde1")))] -pub mod serde; +pub(crate) mod serde; /// Extension trait for [Instants](Instant) in the future, i.e. deadlines. pub trait TimeUntil { @@ -26,6 +26,9 @@ impl TimeUntil for Instant { } } +/// Utility for delay queue +pub mod delay_queue; + /// Collection compaction; configurable `shrink_to_fit`. pub trait Compact { /// Compacts space if the ratio of length : capacity is less than `usage_ratio_threshold`. diff --git a/tarpc/src/util/delay_queue.rs b/tarpc/src/util/delay_queue.rs new file mode 100644 index 000000000..6372c0b4d --- /dev/null +++ b/tarpc/src/util/delay_queue.rs @@ -0,0 +1,78 @@ +use std::{ + fmt::Debug, + task::{Context, Poll}, + time::Duration, +}; + +pub use tokio_util::time::DelayQueue; + +/// A trait that mocks [`DelayQueue`] with the minimal set of APIs that are needed for tarpc to run +/// This is needed so that user can supply their own implementation that satisfy the behavior of [`DelayQueue`] +/// So that the user can go runtime-agnostic on timer-related stuff (for example to run on smol or embassy) +pub trait DelayQueueLike: Debug { + /// The key returned from queue insertion, as a token to control the delay + type Key: Debug; + /// Inserts `value` into the queue set to expire after the requested duration + /// elapses. + /// + /// `value` is stored in the queue until `timeout` duration has + /// elapsed after `insert` was called. At that point, `value` will + /// be returned from [`poll_expired`]. If `timeout` is a `Duration` of + /// zero, then `value` is immediately made available to poll. + /// + /// The return value represents the insertion and is used as an + /// argument to [`remove`] and [`reset`]. Note that [`Key`] is a + /// token and is reused once `value` is removed from the queue + /// either by calling [`poll_expired`] after `timeout` has elapsed + /// or by calling [`remove`]. At this point, the caller must not + /// use the returned [`Key`] again as it may reference a different + /// item in the queue. + fn insert(&mut self, value: T, timeout: Duration) -> Self::Key; + /// Removes the item associated with `key` from the queue. + /// + /// There must be an item associated with `key`. The function returns the + /// removed item as well as the `Instant` at which it will the delay will + /// have expired. + fn remove(&mut self, key: &Self::Key); + /// Clears the queue, removing all items. + /// + /// After calling `clear`, [`poll_expired`] will return `Ok(Ready(None))`. + /// + /// Note that this method has no effect on the allocated capacity. + fn clear(&mut self); + /// Attempts to pull out the next value of the delay queue, registering the + /// current task for wakeup if the value is not yet available, and returning + /// `None` if the queue is exhausted. + fn poll_expired(&mut self, cx: &mut Context<'_>) -> Poll>; + /// Returns `true` if there are no items in the queue. + /// + /// Note that this function returns `false` even if all items have not yet + /// expired and a call to `poll` will return `Poll::Pending`. + fn is_empty(&self) -> bool; +} + +impl DelayQueueLike for DelayQueue { + type Key = tokio_util::time::delay_queue::Key; + + fn insert(&mut self, value: T, timeout: Duration) -> Self::Key { + (self as &mut DelayQueue).insert(value, timeout) + } + + fn remove(&mut self, key: &Self::Key) { + (self as &mut DelayQueue).remove(key); + } + + fn clear(&mut self) { + (self as &mut DelayQueue).clear(); + } + + fn poll_expired(&mut self, cx: &mut Context<'_>) -> Poll> { + (self as &mut DelayQueue) + .poll_expired(cx) + .map(|f| f.map(|x| x.into_inner())) + } + + fn is_empty(&self) -> bool { + (self as &DelayQueue).is_empty() + } +}