diff --git a/docs/multi-packet-and-streaming-responses-design.md b/docs/multi-packet-and-streaming-responses-design.md index 9cbf79f7..5dbc8b5c 100644 --- a/docs/multi-packet-and-streaming-responses-design.md +++ b/docs/multi-packet-and-streaming-responses-design.md @@ -275,6 +275,10 @@ not hang. channel and stamps it onto every serialised frame. This preserves protocol invariants without requiring handlers to mutate frames post-creation and mirrors the message attribution strategy outlined in the capability roadmap. +- Implementation stores the expected identifier alongside a closure built from + the new `CorrelatableFrame` trait, ensuring frames can be stamped in a + generic actor without constraining other protocols. Debug builds assert the + stamped frame exposes the expected identifier so regressions fail fast. Debug-mode assertions must guard this stamping by checking `frame.correlation_id == request.correlation_id` before a frame is dispatched. diff --git a/docs/roadmap.md b/docs/roadmap.md index 15ca7d84..7550d8bb 100644 --- a/docs/roadmap.md +++ b/docs/roadmap.md @@ -207,13 +207,13 @@ stream. - [x] Emit tracing and metrics for each forwarded frame so streaming traffic remains visible to observability pipelines. - - [ ] Each sent frame must carry the correct `correlation_id` from the + - [x] Each sent frame must carry the correct `correlation_id` from the initial request. - - [ ] Capture the originating request's `correlation_id` before handing + - [x] Capture the originating request's `correlation_id` before handing control to the multi-packet dispatcher. - - [ ] Stamp the stored `correlation_id` onto every frame emitted from the + - [x] Stamp the stored `correlation_id` onto every frame emitted from the channel before it is queued for transmission. - - [ ] Guard against accidental omission by asserting in debug builds and + - [x] Guard against accidental omission by asserting in debug builds and covering the behaviour with targeted tests. - [ ] When the channel closes, send the end-of-stream marker frame. diff --git a/src/app/envelope.rs b/src/app/envelope.rs index 61931dda..d229fd58 100644 --- a/src/app/envelope.rs +++ b/src/app/envelope.rs @@ -6,7 +6,7 @@ //! deserialisation. See [`crate::app::builder::WireframeApp`] for how envelopes //! are used when registering routes. -use crate::message::Message; +use crate::{correlation::CorrelatableFrame, message::Message}; /// Envelope-like type used to wrap incoming and outgoing messages. /// @@ -101,6 +101,14 @@ impl Packet for Envelope { fn from_parts(parts: PacketParts) -> Self { parts.into() } } +impl CorrelatableFrame for Envelope { + fn correlation_id(&self) -> Option { self.correlation_id } + + fn set_correlation_id(&mut self, correlation_id: Option) { + self.correlation_id = correlation_id; + } +} + impl PacketParts { /// Construct a new set of packet parts. #[must_use] diff --git a/src/connection.rs b/src/connection.rs index 3026d3e9..438843f4 100644 --- a/src/connection.rs +++ b/src/connection.rs @@ -46,6 +46,7 @@ impl Drop for ActiveConnection { pub fn active_connection_count() -> u64 { ACTIVE_CONNECTIONS.load(Ordering::Relaxed) } use crate::{ + correlation::CorrelatableFrame, fairness::FairnessTracker, hooks::{ConnectionContext, ProtocolHooks}, push::{FrameLike, PushHandle, PushQueues}, @@ -120,7 +121,7 @@ pub struct ConnectionActor { /// Optional multi-packet channel drained after low-priority frames. /// This preserves fairness with queued sources. /// The actor emits the protocol terminator when the sender closes the channel. - multi_packet: Option>, + multi_packet: MultiPacketContext, shutdown: CancellationToken, counter: Option, hooks: ProtocolHooks, @@ -136,6 +137,56 @@ struct DrainContext<'a, F> { state: &'a mut ActorState, } +/// Multi-packet correlation stamping state. +/// +/// Tracks the active receiver and how frames should be stamped before emission. +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +enum MultiPacketStamp { + /// Stamping is disabled because no multi-packet channel is active. + Disabled, + /// Stamping is enabled and frames are stamped with the provided identifier. + Enabled(Option), +} + +/// Multi-packet channel state tracking the active receiver and stamping config. +struct MultiPacketContext { + channel: Option>, + stamp: MultiPacketStamp, +} + +impl MultiPacketContext { + const fn new() -> Self { + Self { + channel: None, + stamp: MultiPacketStamp::Disabled, + } + } + + fn install(&mut self, channel: Option>, stamp: MultiPacketStamp) { + debug_assert_eq!( + channel.is_some(), + matches!(stamp, MultiPacketStamp::Enabled(_)), + "channel presence must match stamp: channel is Some iff stamp is \ + MultiPacketStamp::Enabled(...)", + ); + self.channel = channel; + self.stamp = stamp; + } + + fn clear(&mut self) { + self.channel = None; + self.stamp = MultiPacketStamp::Disabled; + } + + fn channel_mut(&mut self) -> Option<&mut mpsc::Receiver> { self.channel.as_mut() } + + fn take_channel(&mut self) -> Option> { self.channel.take() } + + fn stamp(&self) -> MultiPacketStamp { self.stamp } + + fn is_active(&self) -> bool { self.channel.is_some() } +} + /// Queue variants processed by the connection actor. #[derive(Clone, Copy)] enum QueueKind { @@ -146,7 +197,7 @@ enum QueueKind { impl ConnectionActor where - F: FrameLike, + F: FrameLike + CorrelatableFrame, E: std::fmt::Debug, { /// Create a new `ConnectionActor` from the provided components. @@ -197,7 +248,7 @@ where high_rx: Some(queues.high_priority_rx), low_rx: Some(queues.low_priority_rx), response, - multi_packet: None, + multi_packet: MultiPacketContext::new(), shutdown, counter: Some(counter), hooks, @@ -221,7 +272,7 @@ where /// Set or replace the current streaming response. pub fn set_response(&mut self, stream: Option>) { debug_assert!( - self.multi_packet.is_none(), + !self.multi_packet.is_active(), concat!( "ConnectionActor invariant violated: cannot set response while a ", "multi_packet channel is active" @@ -239,7 +290,78 @@ where "response stream is active" ), ); - self.multi_packet = channel; + let stamp = if channel.is_some() { + MultiPacketStamp::Enabled(None) + } else { + MultiPacketStamp::Disabled + }; + self.multi_packet.install(channel, stamp); + } + + /// Set or replace the current multi-packet response channel and stamp correlation identifiers. + /// + /// # Examples + /// + /// ```no_run + /// # use tokio::sync::mpsc; + /// # use tokio_util::sync::CancellationToken; + /// # use wireframe::{ConnectionActor, push::PushQueues}; + /// # let (queues, handle) = PushQueues::::builder() + /// # .high_capacity(1) + /// # .low_capacity(1) + /// # .build() + /// # .expect("failed to build PushQueues"); + /// # let shutdown = CancellationToken::new(); + /// # let mut actor = ConnectionActor::new(queues, handle, None, shutdown); + /// # let (_tx, rx) = mpsc::channel(4); + /// actor.set_multi_packet_with_correlation(Some(rx), Some(7)); + /// ``` + pub fn set_multi_packet_with_correlation( + &mut self, + channel: Option>, + correlation_id: Option, + ) { + debug_assert!( + self.response.is_none(), + concat!( + "ConnectionActor invariant violated: cannot set multi_packet while a ", + "response stream is active" + ), + ); + let stamp = if channel.is_some() { + MultiPacketStamp::Enabled(correlation_id) + } else { + MultiPacketStamp::Disabled + }; + self.multi_packet.install(channel, stamp); + } + + fn clear_multi_packet(&mut self) { self.multi_packet.clear(); } + + fn apply_multi_packet_correlation(&mut self, frame: &mut F) { + match self.multi_packet.stamp() { + MultiPacketStamp::Enabled(Some(expected)) => { + frame.set_correlation_id(Some(expected)); + debug_assert_eq!( + frame.correlation_id(), + Some(expected), + "multi-packet frame correlation mismatch: expected={:?}, got={:?}", + Some(expected), + frame.correlation_id(), + ); + } + MultiPacketStamp::Enabled(None) => { + frame.set_correlation_id(None); + debug_assert!( + frame.correlation_id().is_none(), + "multi-packet frame correlation unexpectedly present: got={:?}", + frame.correlation_id(), + ); + } + MultiPacketStamp::Disabled => { + unreachable!("multi-packet correlation invoked without configuration"); + } + } } /// Replace the low-priority queue used for tests. @@ -271,11 +393,11 @@ where } debug_assert!( - usize::from(self.response.is_some()) + usize::from(self.multi_packet.is_some()) <= 1, + usize::from(self.response.is_some()) + usize::from(self.multi_packet.is_active()) <= 1, "ConnectionActor invariant violated: at most one of response or multi_packet may be \ active" ); - let mut state = ActorState::new(self.response.is_some(), self.multi_packet.is_some()); + let mut state = ActorState::new(self.response.is_some(), self.multi_packet.is_active()); while !state.is_done() { self.poll_sources(&mut state, out).await?; @@ -302,22 +424,17 @@ where async fn next_event(&mut self, state: &ActorState) -> Event { let high_available = self.high_rx.is_some(); let low_available = self.low_rx.is_some(); - let multi_available = self.multi_packet.is_some() && !state.is_shutting_down(); + let multi_available = self.multi_packet.is_active() && !state.is_shutting_down(); let resp_available = self.response.is_some() && !state.is_shutting_down(); tokio::select! { biased; () = Self::await_shutdown(self.shutdown.clone()), if state.is_active() => Event::Shutdown, - res = Self::poll_queue(self.high_rx.as_mut()), if high_available => Event::High(res), - res = Self::poll_queue(self.low_rx.as_mut()), if low_available => Event::Low(res), - - res = Self::poll_queue(self.multi_packet.as_mut()), if multi_available => Event::MultiPacket(res), - + res = Self::poll_queue(self.multi_packet.channel_mut()), if multi_available => Event::MultiPacket(res), res = Self::poll_response(self.response.as_mut()), if resp_available => Event::Response(res), - else => Event::Idle, } } @@ -372,7 +489,16 @@ where let DrainContext { out, state } = ctx; match res { Some(frame) => { - self.process_frame_with_hooks_and_metrics(frame, out); + match kind { + QueueKind::Multi + if matches!(self.multi_packet.stamp(), MultiPacketStamp::Enabled(_)) => + { + self.emit_multi_packet_frame(frame, out); + } + _ => { + self.process_frame_with_hooks_and_metrics(frame, out); + } + } match kind { QueueKind::High => self.after_high(out, state), QueueKind::Low | QueueKind::Multi => self.after_low(), @@ -403,10 +529,16 @@ where self.process_queue(QueueKind::Multi, res, DrainContext { out, state }); } + fn emit_multi_packet_frame(&mut self, frame: F, out: &mut Vec) { + let mut frame = frame; + self.apply_multi_packet_correlation(&mut frame); + self.process_frame_with_hooks_and_metrics(frame, out); + } + /// Handle a closed multi-packet channel by emitting the protocol terminator and notifying /// hooks. fn handle_multi_packet_closed(&mut self, state: &mut ActorState, out: &mut Vec) { - let rx = self.multi_packet.take(); + let rx = self.multi_packet.take_channel(); self.handle_multi_packet_closed_with(rx, state, out); } @@ -428,9 +560,10 @@ where } state.mark_closed(); if let Some(frame) = self.hooks.stream_end_frame(&mut self.ctx) { - self.process_frame_with_hooks_and_metrics(frame, out); + self.emit_multi_packet_frame(frame, out); self.after_low(); } + self.clear_multi_packet(); self.hooks.on_command_end(&mut self.ctx); } @@ -480,9 +613,10 @@ where if let Some(rx) = &mut self.low_rx { rx.close(); } - if let Some(mut rx) = self.multi_packet.take() { + if let Some(mut rx) = self.multi_packet.take_channel() { rx.close(); state.mark_closed(); + self.clear_multi_packet(); } if self.response.take().is_some() { state.mark_closed(); @@ -522,16 +656,10 @@ where fn try_opportunistic_drain(&mut self, kind: QueueKind, ctx: DrainContext<'_, F>) -> bool { let DrainContext { out, state } = ctx; match kind { - QueueKind::High => { - debug_assert!( - false, - concat!( - "try_opportunistic_drain(High) is unsupported; ", - "High is handled by biased polling", - ), - ); - false - } + QueueKind::High => unreachable!(concat!( + "try_opportunistic_drain(High) is unsupported; ", + "High is handled by biased polling", + )), QueueKind::Low => { let res = match self.low_rx.as_mut() { Some(receiver) => receiver.try_recv(), @@ -552,23 +680,21 @@ where } } QueueKind::Multi => { - let Some(mut rx) = self.multi_packet.take() else { - return false; + let result = match self.multi_packet.channel_mut() { + Some(rx) => rx.try_recv(), + None => return false, }; - match rx.try_recv() { + match result { Ok(frame) => { - self.process_frame_with_hooks_and_metrics(frame, out); + self.emit_multi_packet_frame(frame, out); self.after_low(); - self.multi_packet = Some(rx); true } - Err(TryRecvError::Empty) => { - self.multi_packet = Some(rx); - false - } + Err(TryRecvError::Empty) => false, Err(TryRecvError::Disconnected) => { - self.handle_multi_packet_closed_with(Some(rx), state, out); + let rx = self.multi_packet.take_channel(); + self.handle_multi_packet_closed_with(rx, state, out); false } } diff --git a/src/connection/test_support.rs b/src/connection/test_support.rs index 8ac198a7..26f1f754 100644 --- a/src/connection/test_support.rs +++ b/src/connection/test_support.rs @@ -96,7 +96,7 @@ impl ActorHarness { /// Returns `true` when the multi-packet queue is still available. #[must_use] - pub fn has_multi_queue(&self) -> bool { self.actor.multi_packet.is_some() } + pub fn has_multi_queue(&self) -> bool { self.actor.multi_packet.is_active() } /// Process a multi-packet poll result. pub fn process_multi_packet(&mut self, res: Option) { @@ -180,3 +180,46 @@ impl ActorStateHarness { pub async fn poll_queue_next(rx: Option<&mut mpsc::Receiver>) -> Option { ConnectionActor::::poll_queue(rx).await } + +#[cfg(test)] +mod tests { + use tokio::sync::mpsc; + + use super::*; + + #[test] + fn has_multi_queue_false_by_default() { + let harness = ActorHarness::new().expect("build ActorHarness"); + assert!( + !harness.has_multi_queue(), + "multi-packet queue should start inactive" + ); + } + + #[test] + fn has_multi_queue_true_after_install() { + let mut harness = ActorHarness::new().expect("build ActorHarness"); + let (_tx, rx) = mpsc::channel(1); + harness.set_multi_queue(Some(rx)); + assert!( + harness.has_multi_queue(), + "multi-packet queue should be active after install" + ); + } + + #[test] + fn has_multi_queue_false_after_clear() { + let mut harness = ActorHarness::new().expect("build ActorHarness"); + let (_tx, rx) = mpsc::channel(1); + harness.set_multi_queue(Some(rx)); + assert!( + harness.has_multi_queue(), + "multi-packet queue should be active after install" + ); + harness.set_multi_queue(None); + assert!( + !harness.has_multi_queue(), + "multi-packet queue should be inactive after clear" + ); + } +} diff --git a/src/correlation.rs b/src/correlation.rs new file mode 100644 index 00000000..479adb97 --- /dev/null +++ b/src/correlation.rs @@ -0,0 +1,63 @@ +//! Traits for working with correlation identifiers on frames. +//! +//! `CorrelatableFrame` abstracts over frame types that carry an optional +//! correlation identifier, allowing generic components such as the connection +//! actor to stamp or inspect identifiers without knowing the concrete frame +//! representation. + +/// Access and mutate correlation identifiers on frames. +pub trait CorrelatableFrame { + /// Return the correlation identifier associated with this frame, if any. + fn correlation_id(&self) -> Option; + + /// Set or clear the correlation identifier. + fn set_correlation_id(&mut self, correlation_id: Option); +} + +impl CorrelatableFrame for u8 { + fn correlation_id(&self) -> Option { None } + + fn set_correlation_id(&mut self, _correlation_id: Option) {} +} + +impl CorrelatableFrame for Vec { + fn correlation_id(&self) -> Option { None } + + fn set_correlation_id(&mut self, _correlation_id: Option) {} +} + +#[cfg(test)] +mod tests { + use rstest::rstest; + + use super::*; + use crate::app::Envelope; + + #[rstest] + #[case(None)] + #[case(Some(27))] + fn envelope_correlation_round_trip(#[case] initial: Option) { + let mut frame = Envelope::new(7, initial, vec![1, 2, 3]); + assert_eq!(frame.correlation_id(), initial); + + frame.set_correlation_id(Some(99)); + assert_eq!(frame.correlation_id(), Some(99)); + + frame.set_correlation_id(None); + assert_eq!(frame.correlation_id(), None); + } + + #[rstest] + #[case::byte(0u8)] + #[case::buffer(Vec::::new())] + fn noop_implementations_ignore_correlation(#[case] mut frame: T) + where + T: CorrelatableFrame, + { + assert_eq!(frame.correlation_id(), None); + frame.set_correlation_id(Some(42)); + assert_eq!(frame.correlation_id(), None); + frame.set_correlation_id(None); + assert_eq!(frame.correlation_id(), None); + } +} diff --git a/src/lib.rs b/src/lib.rs index 3845c5e0..44cca9ea 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -11,6 +11,7 @@ pub use app::error::Result; pub mod serializer; pub use serializer::{BincodeSerializer, Serializer}; pub mod connection; +pub mod correlation; pub mod extractor; mod fairness; pub mod frame; @@ -28,6 +29,7 @@ pub mod server; pub mod session; pub use connection::ConnectionActor; +pub use correlation::CorrelatableFrame; pub use hooks::{ConnectionContext, ProtocolHooks, WireframeProtocol}; pub use metrics::{CONNECTIONS_ACTIVE, Direction, ERRORS_TOTAL, FRAMES_PROCESSED}; pub use response::{FrameStream, Response, WireframeError}; diff --git a/tests/correlation_id.rs b/tests/correlation_id.rs index 586d8434..63b39313 100644 --- a/tests/correlation_id.rs +++ b/tests/correlation_id.rs @@ -1,10 +1,14 @@ #![cfg(not(loom))] //! Tests for `correlation_id` propagation in streaming responses. use async_stream::try_stream; +use rstest::rstest; +use tokio::sync::mpsc; use tokio_util::sync::CancellationToken; use wireframe::{ - app::{Envelope, Packet}, + CorrelatableFrame, + app::Envelope, connection::ConnectionActor, + hooks::{ConnectionContext, ProtocolHooks}, push::PushQueues, response::FrameStream, }; @@ -28,3 +32,74 @@ async fn stream_frames_carry_request_correlation_id() { actor.run(&mut out).await.expect("actor run failed"); assert!(out.iter().all(|e| e.correlation_id() == Some(cid))); } + +async fn run_multi_packet_channel( + request_correlation: Option, + frame_correlations: &[Option], + hooks: ProtocolHooks, +) -> Vec { + let capacity = frame_correlations.len().max(1); + let (tx, rx) = mpsc::channel(capacity); + for (idx, correlation) in frame_correlations.iter().enumerate() { + let marker = (idx + 1) as u64; + let payload = marker.to_le_bytes().to_vec(); + tx.send(Envelope::new(1, *correlation, payload)) + .await + .expect("send frame"); + } + drop(tx); + + let (queues, handle) = PushQueues::::builder() + .high_capacity(2) + .low_capacity(2) + .unlimited() + .build() + .expect("failed to build PushQueues"); + let shutdown = CancellationToken::new(); + let mut actor: ConnectionActor = + ConnectionActor::with_hooks(queues, handle, None, shutdown, hooks); + actor.set_multi_packet_with_correlation(Some(rx), request_correlation); + + let mut out = Vec::new(); + actor.run(&mut out).await.expect("actor run failed"); + out +} + +#[rstest] +#[case::stamps_request(Some(7), vec![None, Some(99)], vec![Some(7), Some(7)])] +#[case::clears_when_absent(None, vec![None, Some(13)], vec![None, None])] +#[case::preserves_matching(Some(17), vec![Some(17), Some(17)], vec![Some(17), Some(17)])] +#[tokio::test] +async fn multi_packet_frames_apply_expected_correlation( + #[case] request: Option, + #[case] initial: Vec>, + #[case] expected: Vec>, +) { + let frames = run_multi_packet_channel(request, &initial, ProtocolHooks::default()).await; + let correlations: Vec> = frames + .iter() + .map(CorrelatableFrame::correlation_id) + .collect(); + assert_eq!(correlations, expected); +} + +#[rstest] +#[case::terminator_stamped(Some(11), Some(11))] +#[case::terminator_cleared(None, None)] +#[tokio::test] +async fn multi_packet_terminator_applies_correlation( + #[case] request: Option, + #[case] expected: Option, +) { + let hooks = ProtocolHooks { + stream_end: Some(Box::new(|_ctx: &mut ConnectionContext| { + Some(Envelope::new(255, None, vec![])) + })), + ..ProtocolHooks::default() + }; + + let frames = run_multi_packet_channel(request, &[], hooks).await; + assert_eq!(frames.len(), 1, "terminator frame missing"); + let terminator = frames.last().expect("terminator frame missing"); + assert_eq!(terminator.correlation_id(), expected); +} diff --git a/tests/features/correlation_id.feature b/tests/features/correlation_id.feature index bdb56b19..d14a2938 100644 --- a/tests/features/correlation_id.feature +++ b/tests/features/correlation_id.feature @@ -3,3 +3,13 @@ Feature: Multi-packet response correlation Given a correlation id 7 When a stream of frames is processed Then each emitted frame uses correlation id 7 + + Scenario: Multi-packet responses reuse the request correlation id + Given a correlation id 11 + When a multi-packet channel emits frames + Then each emitted frame uses correlation id 11 + + Scenario: Multi-packet responses clear correlation ids without a request id + Given no correlation id + When a multi-packet channel emits frames + Then each emitted frame has no correlation id diff --git a/tests/steps/correlation_steps.rs b/tests/steps/correlation_steps.rs index 95b41c80..bd58da0a 100644 --- a/tests/steps/correlation_steps.rs +++ b/tests/steps/correlation_steps.rs @@ -4,13 +4,25 @@ use cucumber::{given, then, when}; use crate::world::CorrelationWorld; #[given(expr = "a correlation id {int}")] -fn given_cid(world: &mut CorrelationWorld, id: u64) { world.set_cid(id); } +fn given_cid(world: &mut CorrelationWorld, id: u64) { world.set_expected(Some(id)); } + +#[given("no correlation id")] +fn given_no_correlation(world: &mut CorrelationWorld) { world.set_expected(None); } #[when("a stream of frames is processed")] async fn when_process(world: &mut CorrelationWorld) { world.process().await; } +#[when("a multi-packet channel emits frames")] +async fn when_process_multi(world: &mut CorrelationWorld) { world.process_multi().await; } + #[then(expr = "each emitted frame uses correlation id {int}")] fn then_verify(world: &mut CorrelationWorld, id: u64) { - assert_eq!(world.cid(), id); + assert_eq!(world.expected(), Some(id)); + world.verify(); +} + +#[then("each emitted frame has no correlation id")] +fn then_verify_absent(world: &mut CorrelationWorld) { + assert_eq!(world.expected(), None); world.verify(); } diff --git a/tests/world.rs b/tests/world.rs index 8520dd84..dd5c2c5b 100644 --- a/tests/world.rs +++ b/tests/world.rs @@ -9,7 +9,10 @@ use std::{net::SocketAddr, sync::Arc}; use async_stream::try_stream; use cucumber::World; -use tokio::{net::TcpStream, sync::oneshot}; +use tokio::{ + net::TcpStream, + sync::{mpsc, oneshot}, +}; use tokio_util::sync::CancellationToken; use wireframe::{ app::{Envelope, Packet}, @@ -137,22 +140,24 @@ impl PanicWorld { #[derive(Debug, Default, World)] pub struct CorrelationWorld { - cid: u64, + expected: Option, frames: Vec, } impl CorrelationWorld { - pub fn set_cid(&mut self, cid: u64) { self.cid = cid; } + pub fn set_expected(&mut self, expected: Option) { self.expected = expected; } #[must_use] - pub fn cid(&self) -> u64 { self.cid } + pub fn expected(&self) -> Option { self.expected } /// Run the connection actor and collect frames for later verification. /// /// # Panics /// Panics if the actor fails to run successfully. pub async fn process(&mut self) { - let cid = self.cid; + let cid = self + .expected + .expect("streaming scenario requires a correlation id"); let stream: FrameStream = Box::pin(try_stream! { yield Envelope::new(1, Some(cid), vec![1]); yield Envelope::new(1, Some(cid), vec![2]); @@ -163,16 +168,42 @@ impl CorrelationWorld { actor.run(&mut self.frames).await.expect("actor run failed"); } - /// Verify that all received frames carry the expected correlation ID. + /// Run the connection actor for a multi-packet channel and collect frames. + /// + /// # Panics + /// Panics if sending to the channel or running the actor fails. + pub async fn process_multi(&mut self) { + let expected = self.expected; + let (tx, rx) = mpsc::channel(4); + tx.send(Envelope::new(1, None, vec![1])) + .await + .expect("send frame"); + tx.send(Envelope::new(1, Some(99), vec![2])) + .await + .expect("send frame"); + drop(tx); + + let (queues, handle) = build_small_queues::(); + let shutdown = CancellationToken::new(); + let mut actor: ConnectionActor = + ConnectionActor::new(queues, handle, None, shutdown); + actor.set_multi_packet_with_correlation(Some(rx), expected); + actor.run(&mut self.frames).await.expect("actor run failed"); + } + + /// Verify that all received frames respect the configured correlation expectation. /// /// # Panics - /// Panics if any frame has a `correlation_id` that does not match `self.cid`. + /// Panics if any frame violates the stored correlation expectation. pub fn verify(&self) { - assert!( - self.frames - .iter() - .all(|f| f.correlation_id() == Some(self.cid)) - ); + match self.expected { + Some(cid) => { + assert!(self.frames.iter().all(|f| f.correlation_id() == Some(cid))); + } + None => { + assert!(self.frames.iter().all(|f| f.correlation_id().is_none())); + } + } } }