diff --git a/Cargo.lock b/Cargo.lock index 60f8b544..93463b93 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2999,6 +2999,7 @@ dependencies = [ "rstest 0.26.1", "serde", "serial_test", + "socket2 0.6.0", "static_assertions", "thiserror 2.0.16", "tokio", diff --git a/Cargo.toml b/Cargo.toml index 4a0eaf5b..0a6569c2 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -44,6 +44,7 @@ metrics = { version = "0.24.2", optional = true } thiserror = "2.0.16" static_assertions = "1.1.0" derive_more = { version = "2.0.1", features = ["display", "from"] } +socket2 = "0.6.0" [dev-dependencies] rstest = "0.26.1" diff --git a/docs/roadmap.md b/docs/roadmap.md index 60f78e13..d34415fc 100644 --- a/docs/roadmap.md +++ b/docs/roadmap.md @@ -287,12 +287,12 @@ logic. ## Phase 8: Wireframe client library foundation This phase delivers a first-class client runtime that mirrors the server's -framing, serialisation, and lifecycle layers so both sides share the same +framing, serialisation, and lifecycle layers, so both sides share the same behavioural guarantees. - [ ] **Connection runtime:** - - [ ] Implement `WireframeClient` and its builder so callers can configure + - [x] Implement `WireframeClient` and its builder so callers can configure serializers, codec settings (including `max_frame_length` parity), and socket options before connecting. diff --git a/docs/users-guide.md b/docs/users-guide.md index f9463a54..490c7a4c 100644 --- a/docs/users-guide.md +++ b/docs/users-guide.md @@ -12,7 +12,7 @@ as an `Arc` pointing to an async function that receives a packet reference and returns `()`. The builder caches these registrations until `handle_connection` constructs the middleware chain for an accepted stream.[^2] -```rust +```no_run use std::sync::Arc; use wireframe::app::{Envelope, Handler, WireframeApp}; @@ -99,12 +99,12 @@ async fn main() -> Result<(), ServerError> { ``` Route identifiers must be unique; the builder returns -`WireframeError::DuplicateRoute` when you try to register a handler twice, +`WireframeError::DuplicateRoute` when a handler is registered twice, keeping the dispatch table unambiguous.[^2][^5] New applications default to the bundled bincode serializer, a 1024-byte frame buffer, and a 100 ms read timeout. Clamp these limits with `buffer_capacity` and `read_timeout_ms`, or -swap the serializer with `with_serializer` when you need a different encoding -strategy.[^3][^4] +swap the serializer with `with_serializer` when a different encoding strategy +is required.[^3][^4] Once a stream is accepted—either from a manual accept loop or via `WireframeServer`—`handle_connection(stream)` builds (or reuses) the middleware @@ -157,7 +157,7 @@ layer evolves. The helper is fallible—`FragmentationError` surfaces encoding failures or index overflows—so production code should bubble the error up or log it rather than unwrapping. -```rust +```no_run use std::num::NonZeroUsize; use wireframe::fragment::Fragmenter; @@ -378,6 +378,52 @@ the failure callback path.[^20] worker tasks.[^20][^37][^38] `ServerError` surfaces bind and accept failures as typed errors so callers can react appropriately.[^21] +## Client runtime + +`WireframeClient` provides a first-class client runtime that mirrors the +server's framing and serialization layers, with a builder that configures the +serializer, codec settings, and socket options before connecting.[^44] Use +`ClientCodecConfig` to align `max_frame_length` with the server's +`buffer_capacity`, and apply `SocketOptions` when TCP tuning is required, such +as `TCP_NODELAY` or buffer size adjustments. + +```rust +use std::{net::SocketAddr, time::Duration}; + +use wireframe::{ + client::{ClientCodecConfig, SocketOptions}, + WireframeClient, +}; + +#[derive(bincode::Encode, bincode::BorrowDecode)] +struct Login { + username: String, +} + +#[derive(bincode::Encode, bincode::BorrowDecode, Debug, PartialEq)] +struct LoginAck { + ok: bool, +} + +let addr: SocketAddr = "127.0.0.1:7878".parse().expect("valid socket address"); +let codec = ClientCodecConfig::default().max_frame_length(2048); +let socket = SocketOptions::default() + .nodelay(true) + .keepalive(Some(Duration::from_secs(30))); + +let mut client = WireframeClient::builder() + .codec_config(codec) + .socket_options(socket) + .connect(addr) + .await?; + +let login = Login { + username: "guest".to_string(), +}; +let ack: LoginAck = client.call(&login).await?; +assert!(ack.ok); +``` + ## Push queues and connection actors Background work interacts with connections through `PushQueues`. The fluent @@ -623,3 +669,5 @@ call these helpers to maintain consistent telemetry.[^6][^7][^31][^20] [^41]: Implemented in `src/fragment/mod.rs` and supporting submodules. [^42]: Exercised in `tests/features/fragment.feature`. [^43]: Step definitions in `tests/steps/fragment_steps.rs`. +[^44]: Implemented in `src/client/runtime.rs`, `src/client/builder.rs`, + `src/client/config.rs`, and `src/client/error.rs`. diff --git a/docs/wireframe-client-design.md b/docs/wireframe-client-design.md index 34f0d981..265c2511 100644 --- a/docs/wireframe-client-design.md +++ b/docs/wireframe-client-design.md @@ -21,13 +21,13 @@ implementation of a lightweight client without duplicating protocol code. A new `WireframeClient` type manages a single connection to a server. It mirrors `WireframeServer` but operates in the opposite direction: -- Connect to a `TcpStream`. +- Connect to a `TcpStream`, applying `SocketOptions` before the handshake. - Optionally, send a preamble using the existing `Preamble` helpers. - Encode outgoing messages using the selected `Serializer` and `tokio_util::codec::LengthDelimitedCodec` (4‑byte big‑endian prefix by default; configurable). Configure the codec’s `max_frame_length` on both the inbound (decode) and outbound (encode) paths to match the server’s frame - capacity; otherwise, frames larger than the default 8 MiB will fail. + capacity; otherwise, frames larger than the configured limit will fail. - Decode incoming frames into typed responses. - Expose async `send` and `receive` operations. @@ -36,9 +36,15 @@ mirrors `WireframeServer` but operates in the opposite direction: A `WireframeClient::builder()` method configures the client: ```rust +use std::net::SocketAddr; + +use wireframe::{BincodeSerializer, WireframeClient}; + +let addr: SocketAddr = "127.0.0.1:7878".parse()?; let client = WireframeClient::builder() .serializer(BincodeSerializer) - .connect("127.0.0.1:7878") + .max_frame_length(1024) + .connect(addr) .await?; ``` @@ -53,13 +59,22 @@ message implementing `Message` and waits for the next response frame: ```rust let request = Login { username: "guest".into() }; -let response: LoginAck = client.call(request).await?; +let response: LoginAck = client.call(&request).await?; ``` Internally, this uses the `Serializer` to encode the request, sends it through the length‑delimited codec, then waits for a frame, decodes it, and deserializes the response type. +### Implementation decisions + +- `connect` accepts a `SocketAddr` so the client can create a `TcpSocket` and + apply socket options before connecting. +- `ClientCodecConfig` captures the length prefix format and maximum frame + length, clamping the frame length to match server bounds (64 bytes to 16 MiB). +- The default `max_frame_length` is 1024 bytes to mirror the server builder’s + default buffer capacity. + ### Connection lifecycle Like the server, the client should expose hooks for connection setup and @@ -71,13 +86,16 @@ share initialization logic. ```rust #[tokio::main] async fn main() -> std::io::Result<()> { + use std::net::SocketAddr; + let mut client = WireframeClient::builder() .serializer(BincodeSerializer) - .connect("127.0.0.1:7878") + .max_frame_length(1024) + .connect("127.0.0.1:7878".parse::()?) .await?; let login = Login { username: "guest".into() }; - let ack: LoginAck = client.call(login).await?; + let ack: LoginAck = client.call(&login).await?; println!("logged in: {:?}", ack); Ok(()) } diff --git a/src/client/builder.rs b/src/client/builder.rs new file mode 100644 index 00000000..97d203d9 --- /dev/null +++ b/src/client/builder.rs @@ -0,0 +1,310 @@ +//! Builder for configuring and connecting a wireframe client. + +use std::{net::SocketAddr, time::Duration}; + +use tokio::net::TcpSocket; +use tokio_util::codec::Framed; + +use super::{ClientCodecConfig, ClientError, SocketOptions, WireframeClient}; +use crate::{ + frame::LengthFormat, + serializer::{BincodeSerializer, Serializer}, +}; + +/// Builder for [`WireframeClient`]. +/// +/// # Examples +/// +/// ``` +/// use wireframe::client::WireframeClientBuilder; +/// +/// let builder = WireframeClientBuilder::new(); +/// let _ = builder; +/// ``` +pub struct WireframeClientBuilder { + serializer: S, + codec_config: ClientCodecConfig, + socket_options: SocketOptions, +} + +impl WireframeClientBuilder { + /// Create a new builder with default settings. + /// + /// # Examples + /// + /// ``` + /// use wireframe::client::WireframeClientBuilder; + /// + /// let builder = WireframeClientBuilder::new(); + /// let _ = builder; + /// ``` + #[must_use] + pub fn new() -> Self { + Self { + serializer: BincodeSerializer, + codec_config: ClientCodecConfig::default(), + socket_options: SocketOptions::default(), + } + } +} + +impl Default for WireframeClientBuilder { + fn default() -> Self { Self::new() } +} + +impl WireframeClientBuilder +where + S: Serializer + Send + Sync, +{ + /// Replace the serializer used for encoding and decoding messages. + /// + /// # Examples + /// + /// ``` + /// use wireframe::{BincodeSerializer, client::WireframeClientBuilder}; + /// + /// let builder = WireframeClientBuilder::new().serializer(BincodeSerializer); + /// let _ = builder; + /// ``` + #[must_use] + pub fn serializer(self, serializer: Ser) -> WireframeClientBuilder + where + Ser: Serializer + Send + Sync, + { + WireframeClientBuilder { + serializer, + codec_config: self.codec_config, + socket_options: self.socket_options, + } + } + + /// Configure codec settings for the connection. + /// + /// # Examples + /// + /// ``` + /// use wireframe::client::{ClientCodecConfig, WireframeClientBuilder}; + /// + /// let codec = ClientCodecConfig::default().max_frame_length(2048); + /// let builder = WireframeClientBuilder::new().codec_config(codec); + /// let _ = builder; + /// ``` + #[must_use] + pub fn codec_config(mut self, codec_config: ClientCodecConfig) -> Self { + self.codec_config = codec_config; + self + } + + /// Configure the maximum frame length for the connection. + /// + /// # Examples + /// + /// ``` + /// use wireframe::client::WireframeClientBuilder; + /// + /// let builder = WireframeClientBuilder::new().max_frame_length(2048); + /// let _ = builder; + /// ``` + #[must_use] + pub fn max_frame_length(mut self, max_frame_length: usize) -> Self { + self.codec_config = self.codec_config.max_frame_length(max_frame_length); + self + } + + /// Configure the length prefix format for the connection. + /// + /// # Examples + /// + /// ``` + /// use wireframe::{client::WireframeClientBuilder, frame::LengthFormat}; + /// + /// let builder = WireframeClientBuilder::new().length_format(LengthFormat::u16_be()); + /// let _ = builder; + /// ``` + #[must_use] + pub fn length_format(mut self, length_format: LengthFormat) -> Self { + self.codec_config = self.codec_config.length_format(length_format); + self + } + + /// Replace the socket options applied before connecting. + /// + /// # Examples + /// + /// ``` + /// use wireframe::client::{SocketOptions, WireframeClientBuilder}; + /// + /// let options = SocketOptions::default().nodelay(true); + /// let builder = WireframeClientBuilder::new().socket_options(options); + /// let _ = builder; + /// ``` + #[must_use] + pub fn socket_options(mut self, socket_options: SocketOptions) -> Self { + self.socket_options = socket_options; + self + } + + /// Configure `TCP_NODELAY` for the connection. + /// + /// # Examples + /// + /// ``` + /// use wireframe::client::WireframeClientBuilder; + /// + /// let builder = WireframeClientBuilder::new().nodelay(true); + /// let _ = builder; + /// ``` + #[must_use] + pub fn nodelay(mut self, enabled: bool) -> Self { + self.socket_options = self.socket_options.nodelay(enabled); + self + } + + /// Configure `SO_KEEPALIVE` for the connection. + /// + /// # Examples + /// + /// ``` + /// use std::time::Duration; + /// + /// use wireframe::client::WireframeClientBuilder; + /// + /// let builder = WireframeClientBuilder::new().keepalive(Some(Duration::from_secs(30))); + /// let _ = builder; + /// ``` + #[must_use] + pub fn keepalive(mut self, duration: Option) -> Self { + self.socket_options = self.socket_options.keepalive(duration); + self + } + + /// Configure TCP linger behaviour for the connection. + /// + /// # Examples + /// + /// ``` + /// use std::time::Duration; + /// + /// use wireframe::client::WireframeClientBuilder; + /// + /// let builder = WireframeClientBuilder::new().linger(Some(Duration::from_secs(1))); + /// let _ = builder; + /// ``` + #[must_use] + pub fn linger(mut self, duration: Option) -> Self { + self.socket_options = self.socket_options.linger(duration); + self + } + + /// Configure the socket send buffer size. + /// + /// # Examples + /// + /// ``` + /// use wireframe::client::WireframeClientBuilder; + /// + /// let builder = WireframeClientBuilder::new().send_buffer_size(4096); + /// let _ = builder; + /// ``` + #[must_use] + pub fn send_buffer_size(mut self, size: u32) -> Self { + self.socket_options = self.socket_options.send_buffer_size(size); + self + } + + /// Configure the socket receive buffer size. + /// + /// # Examples + /// + /// ``` + /// use wireframe::client::WireframeClientBuilder; + /// + /// let builder = WireframeClientBuilder::new().recv_buffer_size(4096); + /// let _ = builder; + /// ``` + #[must_use] + pub fn recv_buffer_size(mut self, size: u32) -> Self { + self.socket_options = self.socket_options.recv_buffer_size(size); + self + } + + /// Configure `SO_REUSEADDR` for the connection. + /// + /// # Examples + /// + /// ``` + /// use wireframe::client::WireframeClientBuilder; + /// + /// let builder = WireframeClientBuilder::new().reuseaddr(true); + /// let _ = builder; + /// ``` + #[must_use] + pub fn reuseaddr(mut self, enabled: bool) -> Self { + self.socket_options = self.socket_options.reuseaddr(enabled); + self + } + + /// Configure `SO_REUSEPORT` for the connection on supported platforms. + /// + /// # Examples + /// + /// ``` + /// use wireframe::client::WireframeClientBuilder; + /// + /// let builder = WireframeClientBuilder::new().reuseport(true); + /// let _ = builder; + /// ``` + #[cfg(all( + unix, + not(target_os = "solaris"), + not(target_os = "illumos"), + not(target_os = "cygwin"), + ))] + #[must_use] + pub fn reuseport(mut self, enabled: bool) -> Self { + self.socket_options = self.socket_options.reuseport(enabled); + self + } + + /// Establish a connection and return a configured client. + /// + /// # Errors + /// Returns [`ClientError`] if socket configuration or connection fails. + /// + /// # Examples + /// + /// ```no_run + /// use std::net::SocketAddr; + /// + /// use wireframe::{ClientError, WireframeClient}; + /// + /// # #[tokio::main] + /// # async fn main() -> Result<(), ClientError> { + /// let addr: SocketAddr = "127.0.0.1:9000".parse().expect("valid socket address"); + /// let _client = WireframeClient::builder().connect(addr).await?; + /// # Ok(()) + /// # } + /// ``` + pub async fn connect(self, addr: SocketAddr) -> Result, ClientError> { + let socket = if addr.is_ipv4() { + TcpSocket::new_v4()? + } else { + TcpSocket::new_v6()? + }; + self.socket_options.apply(&socket)?; + let stream = socket.connect(addr).await?; + let codec_config = self.codec_config; + let codec = codec_config.build_codec(); + let mut framed = Framed::new(stream, codec); + let initial_read_buffer_capacity = + core::cmp::min(64 * 1024, codec_config.max_frame_length_value()); + framed + .read_buffer_mut() + .reserve(initial_read_buffer_capacity); + Ok(WireframeClient { + framed, + serializer: self.serializer, + codec_config, + }) + } +} diff --git a/src/client/codec_config.rs b/src/client/codec_config.rs new file mode 100644 index 00000000..548fb92f --- /dev/null +++ b/src/client/codec_config.rs @@ -0,0 +1,116 @@ +//! Codec configuration for wireframe clients. + +use tokio_util::codec::LengthDelimitedCodec; + +use crate::frame::{Endianness, LengthFormat}; + +const MIN_FRAME_LENGTH: usize = 64; +const MAX_FRAME_LENGTH: usize = 16 * 1024 * 1024; +const DEFAULT_MAX_FRAME_LENGTH: usize = 1024; + +/// Codec configuration for the wireframe client. +/// +/// # Examples +/// +/// ``` +/// use wireframe::client::ClientCodecConfig; +/// +/// let codec = ClientCodecConfig::default().max_frame_length(2048); +/// assert_eq!(codec.max_frame_length_value(), 2048); +/// ``` +#[derive(Clone, Copy, Debug)] +pub struct ClientCodecConfig { + length_format: LengthFormat, + max_frame_length: usize, +} + +impl Default for ClientCodecConfig { + fn default() -> Self { + Self { + length_format: LengthFormat::default(), + max_frame_length: DEFAULT_MAX_FRAME_LENGTH, + } + } +} + +impl ClientCodecConfig { + /// Set the maximum frame length for encoding and decoding. + /// + /// The value is clamped between 64 bytes and 16 MiB. + /// + /// # Examples + /// + /// ``` + /// use wireframe::client::ClientCodecConfig; + /// + /// let codec = ClientCodecConfig::default().max_frame_length(2048); + /// assert_eq!(codec.max_frame_length_value(), 2048); + /// ``` + #[must_use] + pub fn max_frame_length(mut self, max_frame_length: usize) -> Self { + self.max_frame_length = max_frame_length.clamp(MIN_FRAME_LENGTH, MAX_FRAME_LENGTH); + self + } + + /// Set the length prefix format used by the codec. + /// + /// # Examples + /// + /// ``` + /// use wireframe::{ + /// client::ClientCodecConfig, + /// frame::{Endianness, LengthFormat}, + /// }; + /// + /// let codec = ClientCodecConfig::default().length_format(LengthFormat::u16_le()); + /// assert_eq!(codec.length_format_value().bytes(), 2); + /// assert_eq!(codec.length_format_value().endianness(), Endianness::Little); + /// ``` + #[must_use] + pub fn length_format(mut self, length_format: LengthFormat) -> Self { + self.length_format = length_format; + self + } + + /// Return the configured maximum frame length. + /// + /// # Examples + /// + /// ``` + /// use wireframe::client::ClientCodecConfig; + /// + /// let codec = ClientCodecConfig::default(); + /// assert_eq!(codec.max_frame_length_value(), 1024); + /// ``` + #[must_use] + pub const fn max_frame_length_value(&self) -> usize { self.max_frame_length } + + /// Return the configured length prefix format. + /// + /// # Examples + /// + /// ``` + /// use wireframe::{client::ClientCodecConfig, frame::Endianness}; + /// + /// let codec = ClientCodecConfig::default(); + /// assert_eq!(codec.length_format_value().bytes(), 4); + /// assert_eq!(codec.length_format_value().endianness(), Endianness::Big); + /// ``` + #[must_use] + pub const fn length_format_value(&self) -> LengthFormat { self.length_format } + + pub(crate) fn build_codec(&self) -> LengthDelimitedCodec { + let mut builder = LengthDelimitedCodec::builder(); + builder.length_field_length(self.length_format.bytes()); + match self.length_format.endianness() { + Endianness::Big => { + builder.big_endian(); + } + Endianness::Little => { + builder.little_endian(); + } + } + builder.max_frame_length(self.max_frame_length); + builder.new_codec() + } +} diff --git a/src/client/config.rs b/src/client/config.rs new file mode 100644 index 00000000..a66f264c --- /dev/null +++ b/src/client/config.rs @@ -0,0 +1,291 @@ +//! Socket options for wireframe clients. + +use std::{io, time::Duration}; + +use socket2::{SockRef, TcpKeepalive}; +use tokio::net::TcpSocket; + +/// Socket options applied before connecting a client. +/// +/// # Examples +/// +/// ``` +/// use std::time::Duration; +/// +/// use wireframe::client::SocketOptions; +/// +/// let options = SocketOptions::default() +/// .nodelay(true) +/// .keepalive(Some(Duration::from_secs(30))); +/// let expected = SocketOptions::default() +/// .nodelay(true) +/// .keepalive(Some(Duration::from_secs(30))); +/// assert_eq!(options, expected); +/// ``` +#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)] +pub struct SocketOptions { + nodelay: Option, + keepalive: Option, + linger: Option, + send_buffer_size: Option, + recv_buffer_size: Option, + reuseaddr: Option, + #[cfg(all( + unix, + not(target_os = "solaris"), + not(target_os = "illumos"), + not(target_os = "cygwin"), + ))] + reuseport: Option, +} + +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +enum LingerSetting { + Disabled, + Duration(Duration), +} + +impl LingerSetting { + const fn to_option(self) -> Option { + match self { + Self::Disabled => None, + Self::Duration(value) => Some(value), + } + } +} + +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +enum KeepAliveSetting { + Disabled, + Duration(Duration), +} + +impl KeepAliveSetting { + const fn to_option(self) -> Option { + match self { + Self::Disabled => None, + Self::Duration(value) => Some(value), + } + } +} + +impl SocketOptions { + /// Configure `TCP_NODELAY` behaviour on the socket. + /// + /// # Examples + /// + /// ``` + /// use wireframe::client::SocketOptions; + /// + /// let options = SocketOptions::default().nodelay(true); + /// let expected = SocketOptions::default().nodelay(true); + /// assert_eq!(options, expected); + /// ``` + #[must_use] + pub fn nodelay(mut self, enabled: bool) -> Self { + self.nodelay = Some(enabled); + self + } + + /// Configure `SO_KEEPALIVE` behaviour on the socket. + /// + /// # Examples + /// + /// ``` + /// use std::time::Duration; + /// + /// use wireframe::client::SocketOptions; + /// + /// let options = SocketOptions::default().keepalive(Some(Duration::from_secs(30))); + /// let expected = SocketOptions::default().keepalive(Some(Duration::from_secs(30))); + /// assert_eq!(options, expected); + /// ``` + #[must_use] + pub fn keepalive(mut self, duration: Option) -> Self { + self.keepalive = Some(match duration { + Some(value) => KeepAliveSetting::Duration(value), + None => KeepAliveSetting::Disabled, + }); + self + } + + /// Configure TCP linger settings on the socket. + /// + /// # Examples + /// + /// ``` + /// use std::time::Duration; + /// + /// use wireframe::client::SocketOptions; + /// + /// let options = SocketOptions::default().linger(Some(Duration::from_secs(1))); + /// let expected = SocketOptions::default().linger(Some(Duration::from_secs(1))); + /// assert_eq!(options, expected); + /// ``` + #[must_use] + pub fn linger(mut self, duration: Option) -> Self { + self.linger = Some(match duration { + Some(value) => LingerSetting::Duration(value), + None => LingerSetting::Disabled, + }); + self + } + + /// Configure the socket send buffer size. + /// + /// # Examples + /// + /// ``` + /// use wireframe::client::SocketOptions; + /// + /// let options = SocketOptions::default().send_buffer_size(4096); + /// let expected = SocketOptions::default().send_buffer_size(4096); + /// assert_eq!(options, expected); + /// ``` + #[must_use] + pub fn send_buffer_size(mut self, size: u32) -> Self { + self.send_buffer_size = Some(size); + self + } + + /// Configure the socket receive buffer size. + /// + /// # Examples + /// + /// ``` + /// use wireframe::client::SocketOptions; + /// + /// let options = SocketOptions::default().recv_buffer_size(4096); + /// let expected = SocketOptions::default().recv_buffer_size(4096); + /// assert_eq!(options, expected); + /// ``` + #[must_use] + pub fn recv_buffer_size(mut self, size: u32) -> Self { + self.recv_buffer_size = Some(size); + self + } + + /// Configure `SO_REUSEADDR` behaviour on the socket. + /// + /// # Examples + /// + /// ``` + /// use wireframe::client::SocketOptions; + /// + /// let options = SocketOptions::default().reuseaddr(true); + /// let expected = SocketOptions::default().reuseaddr(true); + /// assert_eq!(options, expected); + /// ``` + #[must_use] + pub fn reuseaddr(mut self, enabled: bool) -> Self { + self.reuseaddr = Some(enabled); + self + } + + /// Configure `SO_REUSEPORT` behaviour on supported platforms. + /// + /// # Examples + /// + /// ``` + /// use wireframe::client::SocketOptions; + /// + /// let options = SocketOptions::default().reuseport(true); + /// let expected = SocketOptions::default().reuseport(true); + /// assert_eq!(options, expected); + /// ``` + #[cfg(all( + unix, + not(target_os = "solaris"), + not(target_os = "illumos"), + not(target_os = "cygwin"), + ))] + #[must_use] + pub fn reuseport(mut self, enabled: bool) -> Self { + self.reuseport = Some(enabled); + self + } + + pub(crate) fn apply(&self, socket: &TcpSocket) -> io::Result<()> { + self.apply_nodelay(socket)?; + self.apply_keepalive(socket)?; + self.apply_linger(socket)?; + self.apply_send_buffer_size(socket)?; + self.apply_recv_buffer_size(socket)?; + self.apply_reuseaddr(socket)?; + self.apply_reuseport(socket)?; + Ok(()) + } + + fn apply_nodelay(&self, socket: &TcpSocket) -> io::Result<()> { + if let Some(enabled) = self.nodelay { + socket.set_nodelay(enabled)?; + } + Ok(()) + } + + fn apply_keepalive(&self, socket: &TcpSocket) -> io::Result<()> { + if let Some(keepalive) = self.keepalive { + match keepalive.to_option() { + Some(duration) => { + socket.set_keepalive(true)?; + let sock_ref = SockRef::from(socket); + let config = TcpKeepalive::new().with_time(duration); + sock_ref.set_tcp_keepalive(&config)?; + } + None => { + socket.set_keepalive(false)?; + } + } + } + Ok(()) + } + + fn apply_linger(&self, socket: &TcpSocket) -> io::Result<()> { + if let Some(linger) = self.linger { + socket.set_linger(linger.to_option())?; + } + Ok(()) + } + + fn apply_send_buffer_size(&self, socket: &TcpSocket) -> io::Result<()> { + if let Some(size) = self.send_buffer_size { + socket.set_send_buffer_size(size)?; + } + Ok(()) + } + + fn apply_recv_buffer_size(&self, socket: &TcpSocket) -> io::Result<()> { + if let Some(size) = self.recv_buffer_size { + socket.set_recv_buffer_size(size)?; + } + Ok(()) + } + + fn apply_reuseaddr(&self, socket: &TcpSocket) -> io::Result<()> { + if let Some(enabled) = self.reuseaddr { + socket.set_reuseaddr(enabled)?; + } + Ok(()) + } + + #[cfg(all( + unix, + not(target_os = "solaris"), + not(target_os = "illumos"), + not(target_os = "cygwin"), + ))] + fn apply_reuseport(&self, socket: &TcpSocket) -> io::Result<()> { + if let Some(enabled) = self.reuseport { + socket.set_reuseport(enabled)?; + } + Ok(()) + } + + #[cfg(not(all( + unix, + not(target_os = "solaris"), + not(target_os = "illumos"), + not(target_os = "cygwin"), + )))] + fn apply_reuseport(&self, _socket: &TcpSocket) -> io::Result<()> { Ok(()) } +} diff --git a/src/client/error.rs b/src/client/error.rs new file mode 100644 index 00000000..6058a948 --- /dev/null +++ b/src/client/error.rs @@ -0,0 +1,20 @@ +//! Error types for wireframe client operations. + +use std::io; + +/// Errors emitted by [`crate::WireframeClient`]. +#[derive(Debug, thiserror::Error)] +pub enum ClientError { + /// Transport or codec error. + #[error("transport error: {0}")] + Io(#[from] io::Error), + /// Failed to serialize an outbound message. + #[error("failed to serialize message")] + Serialize(#[source] Box), + /// Failed to deserialize an inbound message. + #[error("failed to deserialize message")] + Deserialize(#[source] Box), + /// The peer closed the connection before a response arrived. + #[error("connection closed by peer")] + Disconnected, +} diff --git a/src/client/mod.rs b/src/client/mod.rs new file mode 100644 index 00000000..267d61a2 --- /dev/null +++ b/src/client/mod.rs @@ -0,0 +1,19 @@ +//! Client runtime for wireframe connections. +//! +//! This module provides a configurable client runtime that mirrors the +//! server's framing and serialization layers. + +mod builder; +mod codec_config; +mod config; +mod error; +mod runtime; + +pub use builder::WireframeClientBuilder; +pub use codec_config::ClientCodecConfig; +pub use config::SocketOptions; +pub use error::ClientError; +pub use runtime::WireframeClient; + +#[cfg(test)] +mod tests; diff --git a/src/client/runtime.rs b/src/client/runtime.rs new file mode 100644 index 00000000..60e25a0d --- /dev/null +++ b/src/client/runtime.rs @@ -0,0 +1,221 @@ +//! Wireframe client runtime implementation. + +use std::fmt; + +use bytes::Bytes; +use futures::{SinkExt, StreamExt}; +use tokio::net::TcpStream; +use tokio_util::codec::{Framed, LengthDelimitedCodec}; + +use super::{ClientCodecConfig, ClientError, WireframeClientBuilder}; +use crate::{ + message::Message, + serializer::{BincodeSerializer, Serializer}, +}; + +/// Client runtime for wireframe connections. +/// +/// # Examples +/// +/// ```no_run +/// use std::net::SocketAddr; +/// +/// use wireframe::WireframeClient; +/// +/// # #[tokio::main] +/// # async fn main() -> Result<(), wireframe::ClientError> { +/// let addr: SocketAddr = "127.0.0.1:9000".parse().expect("valid socket address"); +/// let _client = WireframeClient::builder().connect(addr).await?; +/// # Ok(()) +/// # } +/// ``` +pub struct WireframeClient { + pub(crate) framed: Framed, + pub(crate) serializer: S, + pub(crate) codec_config: ClientCodecConfig, +} + +impl fmt::Debug for WireframeClient { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("WireframeClient") + .field("codec_config", &self.codec_config) + .finish_non_exhaustive() + } +} + +impl WireframeClient { + /// Start building a new client with the default serializer and codec. + /// + /// # Examples + /// + /// ```no_run + /// use std::net::SocketAddr; + /// + /// use wireframe::WireframeClient; + /// + /// # #[tokio::main] + /// # async fn main() -> Result<(), wireframe::ClientError> { + /// let addr: SocketAddr = "127.0.0.1:9000".parse().expect("valid socket address"); + /// let _client = WireframeClient::builder().connect(addr).await?; + /// # Ok(()) + /// # } + /// ``` + #[must_use] + pub fn builder() -> WireframeClientBuilder { WireframeClientBuilder::new() } +} + +impl WireframeClient +where + S: Serializer + Send + Sync, +{ + /// Send a message to the peer using the configured serializer. + /// + /// # Errors + /// Returns [`ClientError`] if serialization or I/O fails. + /// + /// # Examples + /// + /// ```no_run + /// use std::net::SocketAddr; + /// + /// use wireframe::{ClientError, WireframeClient}; + /// + /// #[derive(bincode::Encode, bincode::BorrowDecode)] + /// struct Ping(u8); + /// + /// # #[tokio::main] + /// # async fn main() -> Result<(), ClientError> { + /// let addr: SocketAddr = "127.0.0.1:9000".parse().expect("valid socket address"); + /// let mut client = WireframeClient::builder().connect(addr).await?; + /// client.send(&Ping(1)).await?; + /// # Ok(()) + /// # } + /// ``` + pub async fn send(&mut self, message: &M) -> Result<(), ClientError> { + let bytes = self + .serializer + .serialize(message) + .map_err(ClientError::Serialize)?; + self.framed.send(Bytes::from(bytes)).await?; + Ok(()) + } + + /// Receive the next message from the peer. + /// + /// # Errors + /// Returns [`ClientError`] if the connection closes, decoding fails, or I/O + /// errors occur. + /// + /// # Examples + /// + /// ```no_run + /// use std::net::SocketAddr; + /// + /// use wireframe::{ClientError, WireframeClient}; + /// + /// #[derive(bincode::Encode, bincode::BorrowDecode, Debug, PartialEq)] + /// struct Pong(u8); + /// + /// # #[tokio::main] + /// # async fn main() -> Result<(), ClientError> { + /// let addr: SocketAddr = "127.0.0.1:9000".parse().expect("valid socket address"); + /// let mut client = WireframeClient::builder().connect(addr).await?; + /// let _pong: Pong = client.receive().await?; + /// # Ok(()) + /// # } + /// ``` + pub async fn receive(&mut self) -> Result { + let Some(frame) = self.framed.next().await else { + return Err(ClientError::Disconnected); + }; + let bytes = frame?; + let (message, _consumed) = self + .serializer + .deserialize(&bytes) + .map_err(ClientError::Deserialize)?; + Ok(message) + } + + /// Send a message and await the next response. + /// + /// # Errors + /// Returns [`ClientError`] if the request cannot be sent or the response + /// cannot be decoded. + /// + /// # Examples + /// + /// ```no_run + /// use std::net::SocketAddr; + /// + /// use wireframe::{ClientError, WireframeClient}; + /// + /// #[derive(bincode::Encode, bincode::BorrowDecode)] + /// struct Login { + /// username: String, + /// } + /// + /// #[derive(bincode::Encode, bincode::BorrowDecode, Debug, PartialEq)] + /// struct LoginAck { + /// ok: bool, + /// } + /// + /// # #[tokio::main] + /// # async fn main() -> Result<(), ClientError> { + /// let addr: SocketAddr = "127.0.0.1:9000".parse().expect("valid socket address"); + /// let mut client = WireframeClient::builder().connect(addr).await?; + /// let login = Login { + /// username: "guest".to_string(), + /// }; + /// let _ack: LoginAck = client.call(&login).await?; + /// # Ok(()) + /// # } + /// ``` + pub async fn call( + &mut self, + request: &Req, + ) -> Result { + self.send(request).await?; + self.receive().await + } + + /// Inspect the configured codec settings. + /// + /// # Examples + /// + /// ```no_run + /// use std::net::SocketAddr; + /// + /// use wireframe::{ClientError, WireframeClient}; + /// + /// # #[tokio::main] + /// # async fn main() -> Result<(), ClientError> { + /// let addr: SocketAddr = "127.0.0.1:9000".parse().expect("valid socket address"); + /// let client = WireframeClient::builder().connect(addr).await?; + /// let codec = client.codec_config(); + /// assert_eq!(codec.max_frame_length_value(), 1024); + /// # Ok(()) + /// # } + /// ``` + #[must_use] + pub const fn codec_config(&self) -> &ClientCodecConfig { &self.codec_config } + + /// Access the underlying [`TcpStream`]. + /// + /// # Examples + /// + /// ```no_run + /// use std::net::SocketAddr; + /// + /// use wireframe::{ClientError, WireframeClient}; + /// + /// # #[tokio::main] + /// # async fn main() -> Result<(), ClientError> { + /// let addr: SocketAddr = "127.0.0.1:9000".parse().expect("valid socket address"); + /// let client = WireframeClient::builder().connect(addr).await?; + /// let _stream = client.tcp_stream(); + /// # Ok(()) + /// # } + /// ``` + #[must_use] + pub fn tcp_stream(&self) -> &TcpStream { self.framed.get_ref() } +} diff --git a/src/client/tests.rs b/src/client/tests.rs new file mode 100644 index 00000000..7f1db4ee --- /dev/null +++ b/src/client/tests.rs @@ -0,0 +1,204 @@ +//! Unit tests for the wireframe client runtime. + +use std::{net::SocketAddr, time::Duration}; + +use bytes::{Bytes, BytesMut}; +use rstest::rstest; +use socket2::SockRef; +use tokio::net::{TcpListener, TcpStream}; +use tokio_util::codec::{Decoder, Encoder}; + +use super::*; +use crate::frame::{Endianness, LengthFormat}; + +const MIN_FRAME_LENGTH: usize = 64; +const MAX_FRAME_LENGTH: usize = 16 * 1024 * 1024; +const DEFAULT_MAX_FRAME_LENGTH: usize = 1024; +const KEEPALIVE_DURATION: Duration = Duration::from_secs(30); +const LINGER_DURATION: Duration = Duration::from_secs(1); +const BUFFER_SIZE_U32: u32 = 256 * 1024; +const BUFFER_SIZE_USIZE: usize = 256 * 1024; + +async fn spawn_listener() -> (SocketAddr, tokio::task::JoinHandle) { + let listener = TcpListener::bind("127.0.0.1:0") + .await + .expect("bind listener"); + let addr = listener.local_addr().expect("listener addr"); + let accept = tokio::spawn(async move { + let (stream, _) = listener.accept().await.expect("accept client"); + stream + }); + (addr, accept) +} + +/// Helper function to test that a builder option is correctly applied to the TCP socket. +async fn assert_builder_option(configure_builder: F, assert_option: A) +where + F: FnOnce(WireframeClientBuilder) -> WireframeClientBuilder, + A: FnOnce(&WireframeClient), +{ + let (addr, accept) = spawn_listener().await; + + let client = configure_builder(WireframeClient::builder()) + .connect(addr) + .await + .expect("connect client"); + + assert_option(&client); + + let _server_stream = accept.await.expect("join accept task"); +} + +macro_rules! socket_option_test { + ($name:ident, $configure:expr, $assert:expr $(,)?) => { + #[tokio::test] + async fn $name() { assert_builder_option($configure, $assert).await; } + }; +} + +#[rstest] +#[case(1, MIN_FRAME_LENGTH)] +#[case(MIN_FRAME_LENGTH, MIN_FRAME_LENGTH)] +#[case(MAX_FRAME_LENGTH + 1, MAX_FRAME_LENGTH)] +fn codec_config_clamps_max_frame_length(#[case] input: usize, #[case] expected: usize) { + let config = ClientCodecConfig::default().max_frame_length(input); + assert_eq!(config.max_frame_length_value(), expected); +} + +#[test] +fn codec_config_defaults_match_server_buffer_capacity() { + let config = ClientCodecConfig::default(); + assert_eq!(config.max_frame_length_value(), DEFAULT_MAX_FRAME_LENGTH); + assert_eq!(config.length_format_value().bytes(), 4); + assert_eq!(config.length_format_value().endianness(), Endianness::Big); +} + +#[test] +fn build_codec_configures_length_delimited_codec() { + let config = ClientCodecConfig::default(); + let mut codec = config.build_codec(); + + let payload = Bytes::from_static(b"hello"); + let mut buf = BytesMut::new(); + + codec + .encode(payload.clone(), &mut buf) + .expect("encoding frame should succeed"); + + assert!( + buf.len() >= 4, + "encoded frame must at least contain the 4-byte length prefix" + ); + + let bytes = Bytes::from(buf.clone()); + let (len_prefix, data) = bytes.split_at(4); + let mut expected_prefix = BytesMut::new(); + LengthFormat::u32_be() + .write_len(payload.len(), &mut expected_prefix) + .expect("write length prefix"); + let expected_len_prefix = expected_prefix.freeze(); + assert_eq!( + len_prefix, expected_len_prefix, + "length prefix should be 4-byte big-endian" + ); + assert_eq!( + data, payload, + "payload bytes after the length prefix should be unchanged" + ); + + let mut decode_buf = buf; + let decoded = codec + .decode(&mut decode_buf) + .expect("decoding frame should succeed") + .expect("a frame should be produced"); + + assert_eq!(decoded, payload, "decoded payload should match original"); +} + +socket_option_test!( + builder_applies_nodelay_option, + |builder| builder.nodelay(true), + |client| { + let stream = client.tcp_stream().nodelay().expect("read nodelay"); + assert!(stream, "expected TCP_NODELAY to be enabled"); + }, +); + +socket_option_test!( + builder_applies_keepalive_option, + |builder| builder.keepalive(Some(KEEPALIVE_DURATION)), + |client| { + let sock_ref = SockRef::from(client.tcp_stream()); + assert!( + sock_ref.keepalive().expect("query SO_KEEPALIVE"), + "SO_KEEPALIVE should be enabled when configured via builder" + ); + }, +); + +socket_option_test!( + builder_applies_linger_option, + |builder| builder.linger(Some(LINGER_DURATION)), + |client| { + let sock_ref = SockRef::from(client.tcp_stream()); + assert_eq!( + sock_ref.linger().expect("query SO_LINGER"), + Some(LINGER_DURATION), + "SO_LINGER should match builder configuration" + ); + }, +); + +socket_option_test!( + builder_applies_send_buffer_size_option, + |builder| builder.send_buffer_size(BUFFER_SIZE_U32), + |client| { + let sock_ref = SockRef::from(client.tcp_stream()); + assert!( + sock_ref.send_buffer_size().expect("query SO_SNDBUF") >= BUFFER_SIZE_USIZE, + "SO_SNDBUF should be at least the requested builder value" + ); + }, +); + +socket_option_test!( + builder_applies_recv_buffer_size_option, + |builder| builder.recv_buffer_size(BUFFER_SIZE_U32), + |client| { + let sock_ref = SockRef::from(client.tcp_stream()); + assert!( + sock_ref.recv_buffer_size().expect("query SO_RCVBUF") >= BUFFER_SIZE_USIZE, + "SO_RCVBUF should be at least the requested builder value" + ); + }, +); + +socket_option_test!( + builder_applies_reuseaddr_option, + |builder| builder.reuseaddr(true), + |client| { + let sock_ref = SockRef::from(client.tcp_stream()); + assert!( + sock_ref.reuse_address().expect("query SO_REUSEADDR"), + "SO_REUSEADDR should be enabled when configured via builder" + ); + }, +); + +#[cfg(all( + unix, + not(target_os = "solaris"), + not(target_os = "illumos"), + not(target_os = "cygwin"), +))] +socket_option_test!( + builder_applies_reuseport_option, + |builder| builder.reuseport(true), + |client| { + let sock_ref = SockRef::from(client.tcp_stream()); + assert!( + sock_ref.reuse_port().expect("query SO_REUSEPORT"), + "SO_REUSEPORT should be enabled when configured via builder" + ); + }, +); diff --git a/src/lib.rs b/src/lib.rs index b748c0b5..c7cbed5b 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -9,6 +9,7 @@ pub mod byte_order; /// Result type alias re-exported for convenience when working with the /// application builder. pub use app::error::Result; +pub mod client; pub mod serializer; pub use serializer::{BincodeSerializer, Serializer}; pub mod connection; @@ -30,6 +31,7 @@ pub mod rewind_stream; pub mod server; pub mod session; +pub use client::{ClientCodecConfig, ClientError, SocketOptions, WireframeClient}; pub use connection::ConnectionActor; pub use correlation::CorrelatableFrame; pub use fragment::{ diff --git a/tests/client_runtime.rs b/tests/client_runtime.rs new file mode 100644 index 00000000..0d5703ed --- /dev/null +++ b/tests/client_runtime.rs @@ -0,0 +1,55 @@ +//! Integration tests for the wireframe client runtime. + +use futures::StreamExt; +use tokio::net::TcpListener; +use tokio_util::codec::{Framed, LengthDelimitedCodec}; +use wireframe::client::{ClientCodecConfig, ClientError, WireframeClient}; + +#[derive(bincode::Encode, bincode::BorrowDecode, Debug, PartialEq, Eq)] +struct ClientPayload { + data: Vec, +} + +#[tokio::test] +async fn client_surfaces_error_when_frame_exceeds_server_max_length() { + let server_max_frame_length = 64usize; + let client_max_frame_length = 1024usize; + let oversized_payload_len = 128usize; + + let listener = TcpListener::bind("127.0.0.1:0") + .await + .expect("bind test listener"); + let addr = listener + .local_addr() + .expect("read local address for test listener"); + + let server = tokio::spawn(async move { + let (stream, _) = listener.accept().await.expect("server accepts connection"); + let codec = LengthDelimitedCodec::builder() + .max_frame_length(server_max_frame_length) + .new_codec(); + let mut framed = Framed::new(stream, codec); + let result = framed.next().await; + assert!( + matches!(result, Some(Err(_))), + "server should reject oversized frame" + ); + }); + + let mut client = WireframeClient::builder() + .codec_config(ClientCodecConfig::default().max_frame_length(client_max_frame_length)) + .connect(addr) + .await + .expect("connect client"); + let payload = ClientPayload { + data: vec![7_u8; oversized_payload_len], + }; + + let result: Result = client.call(&payload).await; + assert!( + matches!(result, Err(ClientError::Disconnected | ClientError::Io(_))), + "client should surface transport or disconnect error" + ); + + server.await.expect("join server task"); +} diff --git a/tests/connection.rs b/tests/connection.rs index 4426acaa..99598782 100644 --- a/tests/connection.rs +++ b/tests/connection.rs @@ -212,12 +212,16 @@ fn assert_reason_logged( expected_reason: &str, expected_correlation: Option, ) { + let expected_correlation = format!("correlation_id={expected_correlation:?}"); let mut found = false; while let Some(record) = logger.pop() { let message = record.args().to_string(); if !message.contains("multi-packet stream closed") { continue; } + if !message.contains(&expected_correlation) { + continue; + } assert_eq!( record.level(), expected_level, @@ -227,10 +231,6 @@ fn assert_reason_logged( message.contains(&format!("reason={expected_reason}")), "closure log missing reason: message={message}", ); - assert!( - message.contains(&format!("correlation_id={expected_correlation:?}")), - "closure log missing correlation: message={message}", - ); found = true; break; } diff --git a/tests/cucumber.rs b/tests/cucumber.rs index 223ea483..2128be52 100644 --- a/tests/cucumber.rs +++ b/tests/cucumber.rs @@ -1,12 +1,13 @@ #![cfg(not(loom))] //! Cucumber test runner for integration tests. //! -//! Orchestrates five distinct test suites: +//! Orchestrates six distinct test suites: //! - `PanicWorld`: Tests server resilience during connection panics //! - `CorrelationWorld`: Tests correlation ID propagation in multi-frame responses //! - `StreamEndWorld`: Verifies end-of-stream signalling //! - `MultiPacketWorld`: Tests channel-backed multi-packet response delivery //! - `FragmentWorld`: Tests fragment metadata enforcement and reassembly primitives +//! - `ClientRuntimeWorld`: Tests client runtime configuration and framing behaviour //! //! # Example //! @@ -17,6 +18,7 @@ //! tests/features/stream_end.feature -> StreamEndWorld context //! tests/features/multi_packet.feature -> MultiPacketWorld context //! tests/features/fragment.feature -> FragmentWorld context +//! tests/features/client_runtime.feature -> ClientRuntimeWorld context //! ``` //! //! Each context provides specialised step definitions and state management @@ -26,7 +28,14 @@ mod steps; mod world; use cucumber::World; -use world::{CorrelationWorld, FragmentWorld, MultiPacketWorld, PanicWorld, StreamEndWorld}; +use world::{ + ClientRuntimeWorld, + CorrelationWorld, + FragmentWorld, + MultiPacketWorld, + PanicWorld, + StreamEndWorld, +}; #[tokio::main] async fn main() { @@ -35,4 +44,5 @@ async fn main() { StreamEndWorld::run("tests/features/stream_end.feature").await; MultiPacketWorld::run("tests/features/multi_packet.feature").await; FragmentWorld::run("tests/features/fragment.feature").await; + ClientRuntimeWorld::run("tests/features/client_runtime.feature").await; } diff --git a/tests/features/client_runtime.feature b/tests/features/client_runtime.feature new file mode 100644 index 00000000..b4aa4ff0 --- /dev/null +++ b/tests/features/client_runtime.feature @@ -0,0 +1,12 @@ +Feature: Wireframe client runtime + Scenario: Client sends and receives with configured frame length + Given a wireframe echo server allowing frames up to 2048 bytes + And a wireframe client configured with max frame length 2048 + When the client sends a payload of 1500 bytes + Then the client receives the echoed payload + + Scenario: Client reports errors when server frame limit is exceeded + Given a wireframe echo server allowing frames up to 64 bytes + And a wireframe client configured with max frame length 1024 + When the client sends an oversized payload of 128 bytes + Then the client reports a framing error diff --git a/tests/steps/client_steps.rs b/tests/steps/client_steps.rs new file mode 100644 index 00000000..dc9684ae --- /dev/null +++ b/tests/steps/client_steps.rs @@ -0,0 +1,35 @@ +//! Steps for wireframe client runtime behavioural tests. + +use cucumber::{given, then, when}; + +use crate::world::{ClientRuntimeWorld, TestResult}; + +#[given(expr = "a wireframe echo server allowing frames up to {int} bytes")] +async fn given_server(world: &mut ClientRuntimeWorld, max_frame_length: usize) -> TestResult { + world.start_server(max_frame_length).await +} + +#[given(expr = "a wireframe client configured with max frame length {int}")] +async fn given_client(world: &mut ClientRuntimeWorld, max_frame_length: usize) -> TestResult { + world.connect_client(max_frame_length).await +} + +#[when(expr = "the client sends a payload of {int} bytes")] +async fn when_send_payload(world: &mut ClientRuntimeWorld, size: usize) -> TestResult { + world.send_payload(size).await +} + +#[when(expr = "the client sends an oversized payload of {int} bytes")] +async fn when_send_oversized_payload(world: &mut ClientRuntimeWorld, size: usize) -> TestResult { + world.send_payload_expect_error(size).await +} + +#[then("the client receives the echoed payload")] +async fn then_receives_echo(world: &mut ClientRuntimeWorld) -> TestResult { + world.verify_echo().await +} + +#[then("the client reports a framing error")] +async fn then_reports_error(world: &mut ClientRuntimeWorld) -> TestResult { + world.verify_error().await +} diff --git a/tests/steps/mod.rs b/tests/steps/mod.rs index 62c5b092..6664d728 100644 --- a/tests/steps/mod.rs +++ b/tests/steps/mod.rs @@ -3,6 +3,7 @@ //! This module exposes all Given-When-Then steps used by the //! behaviour-driven tests under `tests/features`. +mod client_steps; mod correlation_steps; mod fragment_steps; mod multi_packet_steps; diff --git a/tests/world.rs b/tests/world.rs index 5ce08561..21b1b441 100644 --- a/tests/world.rs +++ b/tests/world.rs @@ -5,6 +5,7 @@ mod worlds; pub use worlds::{ + client_runtime::ClientRuntimeWorld, common::TestResult, correlation::CorrelationWorld, fragment::FragmentWorld, diff --git a/tests/worlds/client_runtime.rs b/tests/worlds/client_runtime.rs new file mode 100644 index 00000000..a7ff817d --- /dev/null +++ b/tests/worlds/client_runtime.rs @@ -0,0 +1,152 @@ +//! Test world for client runtime scenarios. +#![cfg(not(loom))] + +use std::net::SocketAddr; + +use cucumber::World; +use futures::{SinkExt, StreamExt}; +use log::warn; +use tokio::{net::TcpListener, task::JoinHandle}; +use tokio_util::codec::{Framed, LengthDelimitedCodec}; +use wireframe::client::{ClientCodecConfig, ClientError, WireframeClient}; + +use super::TestResult; + +/// Test world exercising the wireframe client runtime. +#[derive(Debug, Default, World)] +pub struct ClientRuntimeWorld { + addr: Option, + server: Option>, + client: Option, + payload: Option, + response: Option, + last_error: Option, +} + +#[derive(bincode::Encode, bincode::BorrowDecode, Debug, PartialEq, Eq, Clone)] +struct ClientPayload { + data: Vec, +} + +impl ClientRuntimeWorld { + /// Start an echo server with the specified maximum frame length. + /// + /// # Errors + /// Returns an error if binding or spawning the server fails. + pub async fn start_server(&mut self, max_frame_length: usize) -> TestResult { + let listener = TcpListener::bind("127.0.0.1:0").await?; + let addr = listener.local_addr()?; + let handle = tokio::spawn(async move { + let Ok((stream, _)) = listener.accept().await else { + warn!("client runtime server failed to accept connection"); + return; + }; + let codec = LengthDelimitedCodec::builder() + .max_frame_length(max_frame_length) + .new_codec(); + let mut framed = Framed::new(stream, codec); + let Some(result) = framed.next().await else { + warn!("client runtime server closed before receiving a frame"); + return; + }; + let Ok(frame) = result else { + warn!("client runtime server failed to decode frame"); + return; + }; + if let Err(err) = framed.send(frame.freeze()).await { + warn!("client runtime server failed to send response: {err:?}"); + } + }); + + self.addr = Some(addr); + self.server = Some(handle); + Ok(()) + } + + /// Connect a client using the specified maximum frame length. + /// + /// # Errors + /// Returns an error if the server has not started or the client fails to connect. + pub async fn connect_client(&mut self, max_frame_length: usize) -> TestResult { + let addr = self.addr.ok_or("server address missing")?; + let codec_config = ClientCodecConfig::default().max_frame_length(max_frame_length); + let client = WireframeClient::builder() + .codec_config(codec_config) + .connect(addr) + .await?; + self.client = Some(client); + Ok(()) + } + + /// Send a payload of the specified size and capture the response. + /// + /// # Errors + /// Returns an error if the client is missing or communication fails. + pub async fn send_payload(&mut self, size: usize) -> TestResult { + let payload = ClientPayload { + data: vec![7_u8; size], + }; + let client = self.client.as_mut().ok_or("client not connected")?; + let response: ClientPayload = client.call(&payload).await?; + self.payload = Some(payload); + self.response = Some(response); + self.last_error = None; + Ok(()) + } + + /// Send a payload that should exceed the peer's frame limit. + /// + /// # Errors + /// Returns an error if the client is missing or if no failure is observed. + pub async fn send_payload_expect_error(&mut self, size: usize) -> TestResult { + let payload = ClientPayload { + data: vec![7_u8; size], + }; + let client = self.client.as_mut().ok_or("client not connected")?; + let result: Result = client.call(&payload).await; + match result { + Ok(_) => return Err("expected client error for oversized payload".into()), + Err(err) => self.last_error = Some(err), + } + Ok(()) + } + + /// Verify that the client received the echoed payload. + /// + /// # Errors + /// Returns an error if the response is missing or mismatched. + pub async fn verify_echo(&mut self) -> TestResult { + let payload = self.payload.as_ref().ok_or("payload missing")?; + let response = self.response.as_ref().ok_or("response missing")?; + if payload != response { + return Err("response did not match payload".into()); + } + self.await_server().await?; + Ok(()) + } + + /// Verify that a client error was captured. + /// + /// # Errors + /// Returns an error if no failure was observed. + pub async fn verify_error(&mut self) -> TestResult { + let err = self + .last_error + .as_ref() + .ok_or("expected client error was not captured")?; + if !matches!(err, ClientError::Disconnected | ClientError::Io(_)) { + return Err("unexpected client error variant".into()); + } + self.await_server().await?; + Ok(()) + } + + async fn await_server(&mut self) -> TestResult { + if let Some(handle) = self.server.take() { + handle + .await + .map_err(|err| format!("server task failed: {err}"))?; + } + Ok(()) + } +} diff --git a/tests/worlds/mod.rs b/tests/worlds/mod.rs index 99a2ccaa..e1403ba2 100644 --- a/tests/worlds/mod.rs +++ b/tests/worlds/mod.rs @@ -26,6 +26,7 @@ pub(crate) fn build_small_queues() support::builder::().unlimited().build() } +pub mod client_runtime; pub mod correlation; pub mod fragment; pub mod multi_packet;