diff --git a/README.md b/README.md index e277bcfa..a9820cc1 100644 --- a/README.md +++ b/README.md @@ -95,17 +95,22 @@ payload bytes. Applications can supply their own envelope type by calling `Packet` trait: ```rust -use wireframe::app::{Packet, WireframeApp}; +use wireframe::app::{Packet, PacketParts, WireframeApp}; #[derive(bincode::Encode, bincode::BorrowDecode)] -struct MyEnv { id: u32, correlation_id: u64, data: Vec } +struct MyEnv { id: u32, correlation_id: Option, payload: Vec } impl Packet for MyEnv { fn id(&self) -> u32 { self.id } - fn correlation_id(&self) -> u64 { self.correlation_id } - fn into_parts(self) -> (u32, u64, Vec) { (self.id, self.correlation_id, self.data) } - fn from_parts(id: u32, correlation_id: u64, data: Vec) -> Self { - Self { id, correlation_id, data } + fn correlation_id(&self) -> Option { self.correlation_id } + fn into_parts(self) -> PacketParts { + PacketParts::new(self.id, self.correlation_id, self.payload) + } + fn from_parts(parts: PacketParts) -> Self { + let id = parts.id(); + let correlation_id = parts.correlation_id(); + let payload = parts.payload(); + Self { id, correlation_id, payload } } } @@ -115,6 +120,10 @@ let app = WireframeApp::<_, _, MyEnv>::new() .unwrap(); ``` +A `None` correlation ID denotes an unsolicited event or server-initiated push. +Use `None` rather than `Some(0)` when a frame lacks a correlation ID. See +[PacketParts](docs/api.md#packetparts) for field details. + This allows integration with existing packet formats without modifying `handle_frame`. @@ -281,8 +290,9 @@ Example programs are available in the `examples/` directory: - `ping_pong.rs` — showcases serialization and middleware in a ping/pong protocol. See [examples/ping_pong.md](examples/ping_pong.md) for a detailed overview. -- `packet_enum.rs` – shows packet type discrimination with a bincode enum and a - frame containing container types like `HashMap` and `Vec`. +- [`packet_enum.rs`](examples/packet_enum.rs) — shows packet type discrimination + with a bincode enum and a frame containing container types like `HashMap` and + `Vec`. Run an example with Cargo: diff --git a/docs/api.md b/docs/api.md new file mode 100644 index 00000000..120e05d0 --- /dev/null +++ b/docs/api.md @@ -0,0 +1,25 @@ +# API overview + +## PacketParts + +A `PacketParts` struct decomposes a packet into its components: + +```rust +let parts = PacketParts::new(id, correlation_id, payload); +``` + +- `id: u32` — frame identifier +- `correlation_id: Option` — `None` marks an unsolicited event or + server‑initiated push +- `payload: Vec` — raw message bytes + +Custom packet types can convert to and from `PacketParts` to avoid manual +mapping: + +```rust +let parts = PacketParts::new(id, None, data); +let env = Envelope::from(parts); +``` + +`None` propagation ensures packets that originate on the server carry no +accidental correlation identifier. diff --git a/docs/roadmap.md b/docs/roadmap.md index 9168bfe5..729a2c36 100644 --- a/docs/roadmap.md +++ b/docs/roadmap.md @@ -155,9 +155,15 @@ production environments. - [x] Expose key operational metrics (e.g., active connections, messages per second, error rates). - - [x] Provide an integration guide for popular monitoring systems (e.g., +- [x] Provide an integration guide for popular monitoring systems (e.g., Prometheus). +- [x] **Packet decomposition:** + + - [x] Introduce `PacketParts` to replace tuple-based packet handling. + - [x] Treat `correlation_id` as `Option` so `None` denotes an + unsolicited event or server-initiated push. + - [x] **Advanced Error Handling:** - [x] Implement panic handlers in connection tasks to prevent a single diff --git a/docs/rust-binary-router-library-design.md b/docs/rust-binary-router-library-design.md index fa9d4ab6..15df8665 100644 --- a/docs/rust-binary-router-library-design.md +++ b/docs/rust-binary-router-library-design.md @@ -498,6 +498,53 @@ frame processing, akin to how `tokio-util::codec` operates, endows "wireframe" with the necessary flexibility to adapt to this diversity without embedding assumptions about any single framing strategy into its core. +#### 4.3.1 Packet abstraction + +The library defines a `Packet` trait to represent transport frames. Frames can +be decomposed into `PacketParts` for efficient handling and reassembly. +`Envelope` is the default implementation used by `wireframe`. The following +diagram depicts the `Packet` trait, `PacketParts`, and `Envelope`. + +```mermaid +classDiagram + class Packet { + <> + +id() u32 + +correlation_id() Option + +into_parts() PacketParts + +from_parts(parts: PacketParts) Self + } + class PacketParts { + -id: u32 + -correlation_id: Option + -payload: Vec + +new(id: u32, correlation_id: Option, payload: Vec) PacketParts + +id() u32 + +correlation_id() Option + +payload() Vec + +inherit_correlation(source: Option) PacketParts + } + class Envelope { + -id: u32 + -correlation_id: Option + -payload: Vec + +new(id: u32, correlation_id: Option, payload: Vec) + +from_parts(parts: PacketParts) Envelope + +into_parts() PacketParts + } + Packet <|.. Envelope + PacketParts <.. Envelope : uses + PacketParts <.. Packet : uses +``` + +`Envelope` implements `Packet`, carrying payload and metadata through the +system. `PacketParts` avoids repetitive tuple unpacking when frames are split +into constituent pieces. A `None` correlation ID denotes an unsolicited event +or server-initiated push. In multi-packet streaming responses, the optional +`correlation_id` links all packets in the stream to the originating request, +and protocols should define an explicit end-of-stream indicator alongside the +shared correlation identifier. + ### 4.4. Message Serialization and Deserialization The conversion of frame payloads to and from Rust types is a critical source of diff --git a/examples/metadata_routing.rs b/examples/metadata_routing.rs index b2f77eca..aae264aa 100644 --- a/examples/metadata_routing.rs +++ b/examples/metadata_routing.rs @@ -52,7 +52,7 @@ impl FrameMetadata for HeaderSerializer { // `parse` receives the complete frame because `LengthPrefixedProcessor` // ensures `src` contains exactly one message. Returning `src.len()` is // therefore correct for this demo. - Ok((Envelope::new(id, 0, payload), src.len())) + Ok((Envelope::new(id, None, payload), src.len())) } } @@ -62,7 +62,7 @@ struct Ping; #[tokio::main] async fn main() -> io::Result<()> { let app = WireframeApp::new() - .unwrap() + .expect("failed to create app") .frame_processor(LengthPrefixedProcessor::default()) .serializer(HeaderSerializer) .route( @@ -73,7 +73,7 @@ async fn main() -> io::Result<()> { }) }), ) - .unwrap() + .expect("failed to add ping route") .route( 2, Arc::new(|_env: &Envelope| { @@ -82,14 +82,14 @@ async fn main() -> io::Result<()> { }) }), ) - .unwrap(); + .expect("failed to add pong route"); let (mut client, server) = duplex(1024); let server_task = tokio::spawn(async move { app.handle_connection(server).await; }); - let payload = Ping.to_bytes().unwrap(); + let payload = Ping.to_bytes().expect("failed to serialize Ping message"); let mut frame = Vec::new(); frame.extend_from_slice(&1u16.to_be_bytes()); frame.push(0); @@ -98,11 +98,11 @@ async fn main() -> io::Result<()> { let mut bytes = BytesMut::new(); LengthPrefixedProcessor::default() .encode(&frame, &mut bytes) - .unwrap(); + .expect("failed to encode frame"); client.write_all(&bytes).await?; client.shutdown().await?; - server_task.await.unwrap(); + server_task.await.expect("server task failed"); Ok(()) } diff --git a/examples/ping_pong.rs b/examples/ping_pong.rs index ed39b217..61763d94 100644 --- a/examples/ping_pong.rs +++ b/examples/ping_pong.rs @@ -57,13 +57,15 @@ where type Error = std::convert::Infallible; async fn call(&self, req: ServiceRequest) -> Result { + let cid = req.correlation_id(); let (ping_req, _) = match Ping::from_bytes(req.frame()) { Ok(val) => val, Err(e) => { eprintln!("failed to decode ping: {e:?}"); - return Ok(ServiceResponse::new(encode_error(format!( - "decode error: {e:?}" - )))); + return Ok(ServiceResponse::new( + encode_error(format!("decode error: {e:?}")), + cid, + )); } }; let mut response = self.inner.call(req).await?; @@ -71,15 +73,16 @@ where Pong(v) } else { eprintln!("ping overflowed at {}", ping_req.0); - return Ok(ServiceResponse::new(encode_error("overflow"))); + return Ok(ServiceResponse::new(encode_error("overflow"), cid)); }; match pong_resp.to_bytes() { Ok(bytes) => *response.frame_mut() = bytes, Err(e) => { eprintln!("failed to encode pong: {e:?}"); - return Ok(ServiceResponse::new(encode_error(format!( - "encode error: {e:?}" - )))); + return Ok(ServiceResponse::new( + encode_error(format!("encode error: {e:?}")), + cid, + )); } } Ok(response) diff --git a/src/app.rs b/src/app.rs index 7d23028c..02e0b9b2 100644 --- a/src/app.rs +++ b/src/app.rs @@ -151,7 +151,10 @@ impl From for SendError { /// # Example /// /// ``` -/// use wireframe::{app::Packet, message::Message}; +/// use wireframe::{ +/// app::{Packet, PacketParts}, +/// message::Message, +/// }; /// /// #[derive(bincode::Decode, bincode::Encode)] /// struct CustomEnvelope { @@ -163,15 +166,14 @@ impl From for SendError { /// impl Packet for CustomEnvelope { /// fn id(&self) -> u32 { self.id } /// -/// fn correlation_id(&self) -> u64 { 0 } +/// fn correlation_id(&self) -> Option { None } /// -/// fn into_parts(self) -> (u32, u64, Vec) { (self.id, 0, self.payload) } +/// fn into_parts(self) -> PacketParts { PacketParts::new(self.id, None, self.payload) } /// -/// fn from_parts(id: u32, correlation_id: u64, msg: Vec) -> Self { -/// let _ = correlation_id; +/// fn from_parts(parts: PacketParts) -> Self { /// Self { -/// id, -/// payload: msg, +/// id: parts.id(), +/// payload: parts.payload(), /// timestamp: 0, /// } /// } @@ -182,58 +184,122 @@ pub trait Packet: Message + Send + Sync + 'static { fn id(&self) -> u32; /// Return the correlation identifier tying this frame to a request. - fn correlation_id(&self) -> u64; + fn correlation_id(&self) -> Option; /// Consume the packet and return its identifier, correlation id and payload bytes. - fn into_parts(self) -> (u32, u64, Vec); + fn into_parts(self) -> PacketParts; - /// Construct a new packet from id, correlation id and raw payload bytes. - fn from_parts(id: u32, correlation_id: u64, msg: Vec) -> Self; + /// Construct a new packet from raw parts. + fn from_parts(parts: PacketParts) -> Self; +} + +/// Component values extracted from or used to build a [`Packet`]. +#[derive(Debug)] +pub struct PacketParts { + id: u32, + correlation_id: Option, + payload: Vec, } /// Basic envelope type used by [`handle_connection`]. /// /// Incoming frames are deserialized into an `Envelope` containing the /// message identifier and raw payload bytes. -#[derive(bincode::Decode, bincode::Encode, Copy, Clone, Debug)] -pub struct PacketHeader { - pub(crate) id: u32, - pub(crate) correlation_id: u64, -} - #[derive(bincode::Decode, bincode::Encode, Debug)] pub struct Envelope { - pub(crate) header: PacketHeader, - pub(crate) msg: Vec, + pub(crate) id: u32, + pub(crate) correlation_id: Option, + pub(crate) payload: Vec, } impl Envelope { /// Create a new [`Envelope`] with the provided identifiers and payload. #[must_use] - pub fn new(id: u32, correlation_id: u64, msg: Vec) -> Self { + pub fn new(id: u32, correlation_id: Option, payload: Vec) -> Self { Self { - header: PacketHeader { id, correlation_id }, - msg, + id, + correlation_id, + payload, } } - - /// Consume the envelope, returning its header and payload bytes. - #[must_use] - pub fn into_parts(self) -> (PacketHeader, Vec) { (self.header, self.msg) } } impl Packet for Envelope { - fn id(&self) -> u32 { self.header.id } + #[inline] + fn id(&self) -> u32 { self.id } - fn correlation_id(&self) -> u64 { self.header.correlation_id } + #[inline] + fn correlation_id(&self) -> Option { self.correlation_id } - fn into_parts(self) -> (u32, u64, Vec) { - let (header, msg) = Envelope::into_parts(self); - (header.id, header.correlation_id, msg) + fn into_parts(self) -> PacketParts { self.into() } + + fn from_parts(parts: PacketParts) -> Self { parts.into() } +} + +impl PacketParts { + /// Construct a new set of packet parts. + #[must_use] + pub fn new(id: u32, correlation_id: Option, payload: Vec) -> Self { + Self { + id, + correlation_id, + payload, + } } - fn from_parts(id: u32, correlation_id: u64, msg: Vec) -> Self { - Envelope::new(id, correlation_id, msg) + #[must_use] + pub const fn id(&self) -> u32 { self.id } + + #[must_use] + pub const fn correlation_id(&self) -> Option { self.correlation_id } + + #[must_use] + pub fn payload(self) -> Vec { self.payload } + + /// Ensure a correlation identifier is present, inheriting from `source` if missing. + /// + /// # Examples + /// ``` + /// use wireframe::app::PacketParts; + /// // Inherit when missing + /// let parts = PacketParts::new(1, None, vec![]).inherit_correlation(Some(42)); + /// assert_eq!(parts.correlation_id(), Some(42)); + /// + /// // Overwrite mismatched value + /// let parts = PacketParts::new(1, Some(7), vec![]).inherit_correlation(Some(8)); + /// assert_eq!(parts.correlation_id(), Some(8)); + /// ``` + #[must_use] + pub fn inherit_correlation(mut self, source: Option) -> Self { + match (self.correlation_id, source) { + (None, cid) => self.correlation_id = cid, + (Some(cid), Some(src)) if cid != src => { + tracing::warn!( + id = self.id, + expected = src, + found = cid, + "mismatched correlation id in response", + ); + // Overwrite with the source correlation ID to ensure downstream + // consistency. + self.correlation_id = Some(src); + } + _ => {} + } + self + } +} + +impl From for PacketParts { + fn from(e: Envelope) -> Self { PacketParts::new(e.id, e.correlation_id, e.payload) } +} + +impl From for Envelope { + fn from(p: PacketParts) -> Self { + let id = p.id(); + let correlation_id = p.correlation_id(); + let payload = p.payload(); + Envelope::new(id, correlation_id, payload) } } @@ -252,7 +318,7 @@ where E: Packet, { /// - /// Initialises empty routes, services, middleware, and application data. + /// Initializes empty routes, services, middleware, and application data. /// Sets the default frame processor and serializer, with no connection /// lifecycle hooks. fn default() -> Self { @@ -290,19 +356,21 @@ where /// #[derive(bincode::Encode, bincode::BorrowDecode)] /// struct MyEnv { /// id: u32, - /// correlation_id: u64, + /// correlation_id: Option, /// data: Vec, /// } /// /// impl Packet for MyEnv { /// fn id(&self) -> u32 { self.id } - /// fn correlation_id(&self) -> u64 { self.correlation_id } - /// fn into_parts(self) -> (u32, u64, Vec) { (self.id, self.correlation_id, self.data) } - /// fn from_parts(id: u32, correlation_id: u64, data: Vec) -> Self { + /// fn correlation_id(&self) -> Option { self.correlation_id } + /// fn into_parts(self) -> PacketParts { + /// PacketParts::new(self.id, self.correlation_id, self.data) + /// } + /// fn from_parts(parts: PacketParts) -> Self { /// Self { - /// id, - /// correlation_id, - /// data, + /// id: parts.id(), + /// correlation_id: parts.correlation_id(), + /// data: parts.payload(), /// } /// } /// } @@ -585,7 +653,7 @@ where let routes = self.build_chains().await; if let Err(e) = self.process_stream(&mut stream, &routes).await { - tracing::warn!(error = ?e, "connection terminated with error"); + tracing::warn!(correlation_id = ?None::, error = ?e, "connection terminated with error"); } if let (Some(teardown), Some(state)) = (&self.on_disconnect, state) { @@ -713,7 +781,7 @@ where } Err(e) => { *deser_failures += 1; - tracing::warn!(error = ?e, "failed to deserialize message"); + tracing::warn!(correlation_id = ?None::, error = ?e, "failed to deserialize message"); crate::metrics::inc_deser_errors(); if *deser_failures >= MAX_DESER_FAILURES { return Err(io::Error::new( @@ -725,24 +793,31 @@ where } }; - if let Some(service) = routes.get(&env.header.id) { - let request = ServiceRequest::new(env.msg, env.header.correlation_id); + if let Some(service) = routes.get(&env.id) { + let request = ServiceRequest::new(env.payload, env.correlation_id); match service.call(request).await { Ok(resp) => { - let response = - Envelope::new(env.header.id, env.header.correlation_id, resp.into_inner()); + let parts = PacketParts::new(env.id, resp.correlation_id(), resp.into_inner()) + .inherit_correlation(env.correlation_id); + let correlation_id = parts.correlation_id(); + let response = Envelope::from_parts(parts); if let Err(e) = self.send_response(stream, &response).await { - tracing::warn!(error = %e, "failed to send response"); + tracing::warn!( + id = env.id, + correlation_id = ?correlation_id, + error = ?e, + "failed to send response", + ); crate::metrics::inc_handler_errors(); } } Err(e) => { - tracing::warn!(id = env.header.id, error = ?e, "handler error"); + tracing::warn!(id = env.id, correlation_id = ?env.correlation_id, error = ?e, "handler error"); crate::metrics::inc_handler_errors(); } } } else { - tracing::warn!("no handler for message id {}", env.header.id); + tracing::warn!(id = env.id, correlation_id = ?env.correlation_id, "no handler for message id"); crate::metrics::inc_handler_errors(); } diff --git a/src/middleware.rs b/src/middleware.rs index de33f7ef..f5eca6f7 100644 --- a/src/middleware.rs +++ b/src/middleware.rs @@ -8,6 +8,8 @@ use std::{convert::Infallible, sync::Arc}; use async_trait::async_trait; +use crate::app::PacketParts; + /// Generic container for request and response frame data. #[derive(Debug, Default)] pub struct FrameContainer { @@ -36,13 +38,13 @@ impl FrameContainer { #[derive(Debug)] pub struct ServiceRequest { inner: FrameContainer>, - correlation_id: u64, + correlation_id: Option, } impl ServiceRequest { /// Create a new [`ServiceRequest`] from raw frame bytes. #[must_use] - pub fn new(frame: Vec, correlation_id: u64) -> Self { + pub fn new(frame: Vec, correlation_id: Option) -> Self { Self { inner: FrameContainer::new(frame), correlation_id, @@ -53,9 +55,16 @@ impl ServiceRequest { #[must_use] pub fn frame(&self) -> &[u8] { self.inner.frame().as_slice() } - /// Return the correlation identifier associated with this request. + /// Return the correlation identifier associated with this request, if any. #[must_use] - pub fn correlation_id(&self) -> u64 { self.correlation_id } + pub fn correlation_id(&self) -> Option { self.correlation_id } + + /// Set or clear the correlation identifier on the request. + #[must_use] + pub fn set_correlation_id(&mut self, correlation_id: Option) -> &mut Self { + self.correlation_id = correlation_id; + self + } /// Mutable access to the inner frame bytes. #[must_use] @@ -70,14 +79,16 @@ impl ServiceRequest { #[derive(Debug, Default)] pub struct ServiceResponse { inner: FrameContainer>, + correlation_id: Option, } impl ServiceResponse { /// Create a new [`ServiceResponse`] containing the given frame bytes. #[must_use] - pub fn new(frame: Vec) -> Self { + pub fn new(frame: Vec, correlation_id: Option) -> Self { Self { inner: FrameContainer::new(frame), + correlation_id, } } @@ -89,6 +100,17 @@ impl ServiceResponse { #[must_use] pub fn frame_mut(&mut self) -> &mut Vec { self.inner.frame_mut() } + /// Return the correlation identifier associated with this response, if any. + #[must_use] + pub fn correlation_id(&self) -> Option { self.correlation_id } + + /// Set or clear the correlation identifier. + #[must_use] + pub fn set_correlation_id(&mut self, correlation_id: Option) -> &mut Self { + self.correlation_id = correlation_id; + self + } + /// Consume the response, yielding the raw frame bytes. #[must_use] pub fn into_inner(self) -> Vec { self.inner.into_inner() } @@ -303,10 +325,16 @@ impl Service for RouteService { async fn call(&self, req: ServiceRequest) -> Result { // The handler only borrows the envelope, allowing us to consume it // afterwards to extract the response payload. - let env = E::from_parts(self.id, req.correlation_id(), req.into_inner()); + let env = E::from_parts(PacketParts::new( + self.id, + req.correlation_id(), + req.into_inner(), + )); (self.handler.as_ref())(&env).await; - let (_, _, bytes) = env.into_parts(); - Ok(ServiceResponse::new(bytes)) + let parts = env.into_parts(); + let correlation_id = parts.correlation_id(); + let payload = parts.payload(); + Ok(ServiceResponse::new(payload, correlation_id)) } } diff --git a/tests/correlation_id.rs b/tests/correlation_id.rs index 25316648..e515473b 100644 --- a/tests/correlation_id.rs +++ b/tests/correlation_id.rs @@ -12,13 +12,13 @@ use wireframe::{ async fn stream_frames_carry_request_correlation_id() { let cid = 42u64; let stream: FrameStream = Box::pin(try_stream! { - yield Envelope::new(1, cid, vec![1]); - yield Envelope::new(1, cid, vec![2]); + yield Envelope::new(1, Some(cid), vec![1]); + yield Envelope::new(1, Some(cid), vec![2]); }); let (queues, handle) = PushQueues::bounded(1, 1); let shutdown = CancellationToken::new(); let mut actor = ConnectionActor::new(queues, handle, Some(stream), shutdown); let mut out = Vec::new(); actor.run(&mut out).await.expect("actor run failed"); - assert!(out.iter().all(|e| e.correlation_id() == cid)); + assert!(out.iter().all(|e| e.correlation_id() == Some(cid))); } diff --git a/tests/lifecycle.rs b/tests/lifecycle.rs index c7dc00ba..6a00242f 100644 --- a/tests/lifecycle.rs +++ b/tests/lifecycle.rs @@ -13,8 +13,8 @@ use std::{ use bytes::BytesMut; use wireframe::{ - app::{Envelope, Packet, WireframeApp}, - frame::{FrameProcessor, LengthPrefixedProcessor}, + app::{Envelope, Packet, PacketParts, WireframeApp}, + frame::FrameProcessor, serializer::{BincodeSerializer, Serializer}, }; use wireframe_testing::{processor, run_app, run_with_duplex_server}; @@ -102,22 +102,27 @@ async fn teardown_without_setup_does_not_run() { #[derive(bincode::Encode, bincode::BorrowDecode, PartialEq, Debug)] struct StateEnvelope { id: u32, - correlation_id: u64, - msg: Vec, + correlation_id: Option, + payload: Vec, } -impl wireframe::app::Packet for StateEnvelope { +impl Packet for StateEnvelope { fn id(&self) -> u32 { self.id } - fn correlation_id(&self) -> u64 { self.correlation_id } + fn correlation_id(&self) -> Option { self.correlation_id } - fn into_parts(self) -> (u32, u64, Vec) { (self.id, self.correlation_id, self.msg) } + fn into_parts(self) -> PacketParts { + PacketParts::new(self.id, self.correlation_id, self.payload) + } - fn from_parts(id: u32, correlation_id: u64, msg: Vec) -> Self { + fn from_parts(parts: PacketParts) -> Self { + let id = parts.id(); + let correlation_id = parts.correlation_id(); + let payload = parts.payload(); Self { id, correlation_id, - msg, + payload, } } } @@ -134,21 +139,33 @@ async fn helpers_propagate_connection_state() { let env = StateEnvelope { id: 1, - correlation_id: 0, - msg: vec![1], + correlation_id: Some(0), + payload: vec![1], }; let bytes = BincodeSerializer .serialize(&env) .expect("failed to serialise envelope"); let mut frame = BytesMut::new(); - LengthPrefixedProcessor::default() - .encode(&bytes, &mut frame) + let proc = processor(); + proc.encode(&bytes, &mut frame) .expect("encode should succeed"); let out = run_app(app, vec![frame.to_vec()], None) .await .expect("app run failed"); assert!(!out.is_empty()); + + let mut buf = BytesMut::from(&out[..]); + let processor = processor(); + let decoded = processor + .decode(&mut buf) + .expect("decode failed") + .expect("frame missing"); + let (resp, _) = BincodeSerializer + .deserialize::(&decoded) + .expect("deserialize failed"); + assert_eq!(resp.correlation_id, Some(0)); + assert_eq!(setup.load(Ordering::SeqCst), 1); assert_eq!(teardown.load(Ordering::SeqCst), 1); } diff --git a/tests/metadata.rs b/tests/metadata.rs index eaa8dc8b..45efcecb 100644 --- a/tests/metadata.rs +++ b/tests/metadata.rs @@ -60,7 +60,7 @@ async fn metadata_parser_invoked_before_deserialize() { let serializer = CountingSerializer(counter.clone()); let app = mock_wireframe_app_with_serializer(serializer); - let env = Envelope::new(1, 0, vec![42]); + let env = Envelope::new(1, Some(0), vec![42]); let out = drive_with_bincode(app, env) .await @@ -105,7 +105,7 @@ async fn falls_back_to_deserialize_after_parse_error() { let serializer = FallbackSerializer(parse_calls.clone(), deser_calls.clone()); let app = mock_wireframe_app_with_serializer(serializer); - let env = Envelope::new(1, 0, vec![7]); + let env = Envelope::new(1, Some(0), vec![7]); let out = drive_with_bincode(app, env) .await diff --git a/tests/middleware.rs b/tests/middleware.rs index 86e0653f..33ef5924 100644 --- a/tests/middleware.rs +++ b/tests/middleware.rs @@ -3,6 +3,7 @@ //! Confirm that a custom middleware can modify requests and responses. use async_trait::async_trait; +use rstest::rstest; use wireframe::middleware::{Next, Service, ServiceRequest, ServiceResponse, Transform}; struct EchoService; @@ -12,7 +13,8 @@ impl Service for EchoService { type Error = std::convert::Infallible; async fn call(&self, req: ServiceRequest) -> Result { - Ok(ServiceResponse::new(req.into_inner())) + let cid = req.correlation_id(); + Ok(ServiceResponse::new(req.into_inner(), cid)) } } @@ -48,13 +50,26 @@ where } } +#[rstest(correlation_id => [None, Some(0), Some(42)])] #[tokio::test] -async fn middleware_modifies_request_and_response() { +async fn middleware_modifies_request_and_response_preserves_cid(correlation_id: Option) { let service = EchoService; let mw = ModifyMiddleware; let wrapped = mw.transform(service).await; - let request = ServiceRequest::new(vec![1, 2, 3], 0); + let request = ServiceRequest::new(vec![1, 2, 3], correlation_id); let response = wrapped.call(request).await.expect("middleware call failed"); + assert_eq!(response.frame(), &[1, 2, 3, b'!', b'?']); + assert_eq!(response.correlation_id(), correlation_id); +} + +#[test] +fn service_request_setter_updates_correlation_id() { + let mut req = ServiceRequest::new(vec![], None); + let _ = req.set_correlation_id(Some(7)); + assert_eq!(req.correlation_id(), Some(7)); + + let _ = req.set_correlation_id(None); + assert_eq!(req.correlation_id(), None); } diff --git a/tests/middleware_order.rs b/tests/middleware_order.rs index df33daec..6b66f55b 100644 --- a/tests/middleware_order.rs +++ b/tests/middleware_order.rs @@ -64,7 +64,7 @@ async fn middleware_applied_in_reverse_order() { let (mut client, server) = duplex(256); - let env = Envelope::new(1, 7, vec![b'X']); + let env = Envelope::new(1, Some(7), vec![b'X']); let serializer = BincodeSerializer; let bytes = serializer.serialize(&env).expect("serialization failed"); // Use the default 4-byte big-endian length prefix for framing @@ -88,6 +88,9 @@ async fn middleware_applied_in_reverse_order() { let (resp, _) = serializer .deserialize::(&frame) .expect("deserialize failed"); - let (_, bytes) = resp.into_parts(); - assert_eq!(bytes, vec![b'X', b'A', b'B', b'B', b'A']); + let parts = wireframe::app::Packet::into_parts(resp); + let correlation_id = parts.correlation_id(); + let payload = parts.payload(); + assert_eq!(payload, vec![b'X', b'A', b'B', b'B', b'A']); + assert_eq!(correlation_id, Some(7)); } diff --git a/tests/packet_parts.rs b/tests/packet_parts.rs new file mode 100644 index 00000000..727964b9 --- /dev/null +++ b/tests/packet_parts.rs @@ -0,0 +1,29 @@ +//! Tests for `PacketParts` conversions and helpers. + +use wireframe::app::{Envelope, Packet, PacketParts}; + +#[test] +fn envelope_from_parts_round_trip() { + let env = Envelope::new(2, Some(5), vec![1, 2]); + let parts = env.into_parts(); + let rebuilt = Envelope::from(parts); + let parts = rebuilt.into_parts(); + let id = parts.id(); + let correlation_id = parts.correlation_id(); + let payload = parts.payload(); + assert_eq!(id, 2); + assert_eq!(correlation_id, Some(5)); + assert_eq!(payload, vec![1, 2]); +} + +#[rstest::rstest( + start, source, expected, + case(PacketParts::new(1, None, vec![]), Some(42), Some(42)), + case(PacketParts::new(1, Some(7), vec![]), None, Some(7)), + case(PacketParts::new(1, None, vec![]), None, None), + case(PacketParts::new(1, Some(7), vec![]), Some(8), Some(8)), +)] +fn inherit_variants(start: PacketParts, source: Option, expected: Option) { + let got = start.inherit_correlation(source); + assert_eq!(got.correlation_id(), expected); +} diff --git a/tests/routes.rs b/tests/routes.rs index 0d4957dd..e57a3b63 100644 --- a/tests/routes.rs +++ b/tests/routes.rs @@ -11,32 +11,39 @@ use bytes::BytesMut; use rstest::rstest; use wireframe::{ Serializer, - app::WireframeApp, + app::{Packet, PacketParts, WireframeApp}, frame::{FrameProcessor, LengthPrefixedProcessor}, message::Message, serializer::BincodeSerializer, }; use wireframe_testing::{drive_with_bincode, drive_with_frames}; -#[derive(bincode::Encode, bincode::BorrowDecode, PartialEq, Debug)] +#[derive(bincode::Encode, bincode::BorrowDecode, PartialEq, Debug, Clone)] struct TestEnvelope { id: u32, - correlation_id: u64, - msg: Vec, + correlation_id: Option, + payload: Vec, } -impl wireframe::app::Packet for TestEnvelope { +impl Packet for TestEnvelope { + #[inline] fn id(&self) -> u32 { self.id } - fn correlation_id(&self) -> u64 { self.correlation_id } + #[inline] + fn correlation_id(&self) -> Option { self.correlation_id } - fn into_parts(self) -> (u32, u64, Vec) { (self.id, self.correlation_id, self.msg) } + fn into_parts(self) -> PacketParts { + PacketParts::new(self.id, self.correlation_id, self.payload) + } - fn from_parts(id: u32, correlation_id: u64, msg: Vec) -> Self { + fn from_parts(parts: PacketParts) -> Self { + let id = parts.id(); + let correlation_id = parts.correlation_id(); + let payload = parts.payload(); Self { id, correlation_id, - msg, + payload, } } } @@ -66,8 +73,8 @@ async fn handler_receives_message_and_echoes_response() { let msg_bytes = Echo(42).to_bytes().expect("encode failed"); let env = TestEnvelope { id: 1, - correlation_id: 99, - msg: msg_bytes, + correlation_id: Some(99), + payload: msg_bytes, }; let out = drive_with_bincode(app, env) @@ -82,12 +89,45 @@ async fn handler_receives_message_and_echoes_response() { let (resp_env, _) = BincodeSerializer .deserialize::(&frame) .expect("deserialize failed"); - assert_eq!(resp_env.correlation_id, 99); - let (echo, _) = Echo::from_bytes(&resp_env.msg).expect("decode echo failed"); + assert_eq!(resp_env.correlation_id, Some(99)); + let (echo, _) = Echo::from_bytes(&resp_env.payload).expect("decode echo failed"); assert_eq!(echo, Echo(42)); assert_eq!(called.load(Ordering::SeqCst), 1); } +#[tokio::test] +async fn handler_echoes_with_none_correlation_id() { + let app = WireframeApp::<_, _, TestEnvelope>::new() + .expect("failed to create app") + .frame_processor(LengthPrefixedProcessor::default()) + .route( + 1, + std::sync::Arc::new(|_: &TestEnvelope| Box::pin(async {})), + ) + .expect("route registration failed"); + + let msg_bytes = Echo(7).to_bytes().expect("encode failed"); + let env = TestEnvelope { + id: 1, + correlation_id: None, + payload: msg_bytes, + }; + + let out = drive_with_bincode(app, env).await.expect("drive failed"); + let mut buf = BytesMut::from(&out[..]); + let frame = LengthPrefixedProcessor::default() + .decode(&mut buf) + .expect("decode failed") + .expect("missing frame"); + let (resp_env, _) = BincodeSerializer + .deserialize::(&frame) + .expect("deserialize failed"); + + assert_eq!(resp_env.correlation_id, None); + let (echo, _) = Echo::from_bytes(&resp_env.payload).expect("decode echo failed"); + assert_eq!(echo, Echo(7)); +} + #[tokio::test] async fn multiple_frames_processed_in_sequence() { let app = WireframeApp::<_, _, TestEnvelope>::new() @@ -104,8 +144,8 @@ async fn multiple_frames_processed_in_sequence() { let msg_bytes = Echo(id).to_bytes().expect("encode failed"); let env = TestEnvelope { id: 1, - correlation_id: u64::from(id), - msg: msg_bytes, + correlation_id: Some(u64::from(id)), + payload: msg_bytes, }; let env_bytes = BincodeSerializer .serialize(&env) @@ -130,7 +170,7 @@ async fn multiple_frames_processed_in_sequence() { let (env1, _) = BincodeSerializer .deserialize::(&first) .expect("deserialize failed"); - let (echo1, _) = Echo::from_bytes(&env1.msg).expect("decode echo failed"); + let (echo1, _) = Echo::from_bytes(&env1.payload).expect("decode echo failed"); let second = LengthPrefixedProcessor::default() .decode(&mut buf) .expect("decode failed") @@ -138,9 +178,64 @@ async fn multiple_frames_processed_in_sequence() { let (env2, _) = BincodeSerializer .deserialize::(&second) .expect("deserialize failed"); - let (echo2, _) = Echo::from_bytes(&env2.msg).expect("decode echo failed"); - assert_eq!(env1.correlation_id, 1); - assert_eq!(env2.correlation_id, 2); + let (echo2, _) = Echo::from_bytes(&env2.payload).expect("decode echo failed"); + assert_eq!(env1.correlation_id, Some(1)); + assert_eq!(env2.correlation_id, Some(2)); assert_eq!(echo1, Echo(1)); assert_eq!(echo2, Echo(2)); } + +#[rstest] +#[case(None)] +#[case(Some(1))] +#[case(Some(2))] +#[tokio::test] +async fn single_frame_propagates_correlation_id(#[case] cid: Option) { + let app = WireframeApp::<_, _, TestEnvelope>::new() + .expect("failed to create app") + .frame_processor(LengthPrefixedProcessor::default()) + .route( + 1, + std::sync::Arc::new(|_: &TestEnvelope| Box::pin(async {})), + ) + .expect("route registration failed"); + + let msg_bytes = Echo(5).to_bytes().expect("encode failed"); + let env = TestEnvelope { + id: 1, + correlation_id: cid, + payload: msg_bytes, + }; + let env_bytes = BincodeSerializer.serialize(&env).expect("serialize failed"); + + let mut framed = BytesMut::new(); + LengthPrefixedProcessor::default() + .encode(&env_bytes, &mut framed) + .expect("encode failed"); + + let out = drive_with_frames(app, vec![framed.to_vec()]) + .await + .expect("drive failed"); + let mut buf = BytesMut::from(&out[..]); + let frame = LengthPrefixedProcessor::default() + .decode(&mut buf) + .expect("decode failed") + .expect("missing"); + let (resp, _) = BincodeSerializer + .deserialize::(&frame) + .expect("deserialize failed"); + + assert_eq!(resp.correlation_id, cid); +} + +#[test] +fn packet_from_parts_round_trips() { + let env = TestEnvelope { + id: 5, + correlation_id: Some(9), + payload: vec![1, 2, 3], + }; + let parts = env.clone().into_parts(); + let rebuilt = TestEnvelope::from_parts(parts); + assert_eq!(rebuilt, env); +} diff --git a/tests/world.rs b/tests/world.rs index cb32dd4a..06e11024 100644 --- a/tests/world.rs +++ b/tests/world.rs @@ -135,8 +135,8 @@ impl CorrelationWorld { pub async fn process(&mut self) { let cid = self.cid; let stream: FrameStream = Box::pin(try_stream! { - yield Envelope::new(1, cid, vec![1]); - yield Envelope::new(1, cid, vec![2]); + yield Envelope::new(1, Some(cid), vec![1]); + yield Envelope::new(1, Some(cid), vec![2]); }); let (queues, handle) = PushQueues::bounded(1, 1); let shutdown = CancellationToken::new(); @@ -150,6 +150,10 @@ impl CorrelationWorld { /// Panics if any frame has a `correlation_id` that does not match the /// expected value. pub fn verify(&self) { - assert!(self.frames.iter().all(|f| f.correlation_id() == self.cid)); + assert!( + self.frames + .iter() + .all(|f| f.correlation_id() == Some(self.cid)) + ); } }