Skip to content
4 changes: 4 additions & 0 deletions docs/multi-packet-and-streaming-responses-design.md
Original file line number Diff line number Diff line change
Expand Up @@ -275,6 +275,10 @@ not hang.
channel and stamps it onto every serialised frame. This preserves protocol
invariants without requiring handlers to mutate frames post-creation and
mirrors the message attribution strategy outlined in the capability roadmap.
- Implementation stores the expected identifier alongside a closure built from
the new `CorrelatableFrame` trait, ensuring frames can be stamped in a
generic actor without constraining other protocols. Debug builds assert the
stamped frame exposes the expected identifier so regressions fail fast.

Debug-mode assertions must guard this stamping by checking
`frame.correlation_id == request.correlation_id` before a frame is dispatched.
Expand Down
8 changes: 4 additions & 4 deletions docs/roadmap.md
Original file line number Diff line number Diff line change
Expand Up @@ -207,13 +207,13 @@ stream.
- [x] Emit tracing and metrics for each forwarded frame so streaming
traffic remains visible to observability pipelines.

- [ ] Each sent frame must carry the correct `correlation_id` from the
- [x] Each sent frame must carry the correct `correlation_id` from the
initial request.
- [ ] Capture the originating request's `correlation_id` before handing
- [x] Capture the originating request's `correlation_id` before handing
control to the multi-packet dispatcher.
- [ ] Stamp the stored `correlation_id` onto every frame emitted from the
- [x] Stamp the stored `correlation_id` onto every frame emitted from the
channel before it is queued for transmission.
- [ ] Guard against accidental omission by asserting in debug builds and
- [x] Guard against accidental omission by asserting in debug builds and
covering the behaviour with targeted tests.

- [ ] When the channel closes, send the end-of-stream marker frame.
Expand Down
10 changes: 9 additions & 1 deletion src/app/envelope.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
//! deserialisation. See [`crate::app::builder::WireframeApp`] for how envelopes
//! are used when registering routes.

use crate::message::Message;
use crate::{correlation::CorrelatableFrame, message::Message};

/// Envelope-like type used to wrap incoming and outgoing messages.
///
Expand Down Expand Up @@ -101,6 +101,14 @@ impl Packet for Envelope {
fn from_parts(parts: PacketParts) -> Self { parts.into() }
}

impl CorrelatableFrame for Envelope {
Comment thread
leynos marked this conversation as resolved.
fn correlation_id(&self) -> Option<u64> { self.correlation_id }

fn set_correlation_id(&mut self, correlation_id: Option<u64>) {
self.correlation_id = correlation_id;
}
}

impl PacketParts {
/// Construct a new set of packet parts.
#[must_use]
Expand Down
202 changes: 164 additions & 38 deletions src/connection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ impl Drop for ActiveConnection {
pub fn active_connection_count() -> u64 { ACTIVE_CONNECTIONS.load(Ordering::Relaxed) }

use crate::{
correlation::CorrelatableFrame,
fairness::FairnessTracker,
hooks::{ConnectionContext, ProtocolHooks},
push::{FrameLike, PushHandle, PushQueues},
Expand Down Expand Up @@ -120,7 +121,7 @@ pub struct ConnectionActor<F, E> {
/// Optional multi-packet channel drained after low-priority frames.
/// This preserves fairness with queued sources.
/// The actor emits the protocol terminator when the sender closes the channel.
multi_packet: Option<mpsc::Receiver<F>>,
multi_packet: MultiPacketContext<F>,
shutdown: CancellationToken,
counter: Option<ActiveConnection>,
hooks: ProtocolHooks<F, E>,
Expand All @@ -136,6 +137,56 @@ struct DrainContext<'a, F> {
state: &'a mut ActorState,
}

/// Multi-packet correlation stamping state.
///
/// Tracks the active receiver and how frames should be stamped before emission.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum MultiPacketStamp {
/// Stamping is disabled because no multi-packet channel is active.
Disabled,
/// Stamping is enabled and frames are stamped with the provided identifier.
Enabled(Option<u64>),
}

/// Multi-packet channel state tracking the active receiver and stamping config.
struct MultiPacketContext<F> {
channel: Option<mpsc::Receiver<F>>,
stamp: MultiPacketStamp,
}

impl<F> MultiPacketContext<F> {
const fn new() -> Self {
Self {
channel: None,
stamp: MultiPacketStamp::Disabled,
}
}

fn install(&mut self, channel: Option<mpsc::Receiver<F>>, stamp: MultiPacketStamp) {
debug_assert_eq!(
channel.is_some(),
matches!(stamp, MultiPacketStamp::Enabled(_)),
"channel presence must match stamp: channel is Some iff stamp is \
MultiPacketStamp::Enabled(...)",
);
self.channel = channel;
self.stamp = stamp;
}

fn clear(&mut self) {
self.channel = None;
self.stamp = MultiPacketStamp::Disabled;
}

fn channel_mut(&mut self) -> Option<&mut mpsc::Receiver<F>> { self.channel.as_mut() }

fn take_channel(&mut self) -> Option<mpsc::Receiver<F>> { self.channel.take() }

fn stamp(&self) -> MultiPacketStamp { self.stamp }

fn is_active(&self) -> bool { self.channel.is_some() }
}

/// Queue variants processed by the connection actor.
#[derive(Clone, Copy)]
enum QueueKind {
Expand All @@ -146,7 +197,7 @@ enum QueueKind {

impl<F, E> ConnectionActor<F, E>
where
F: FrameLike,
F: FrameLike + CorrelatableFrame,
E: std::fmt::Debug,
{
/// Create a new `ConnectionActor` from the provided components.
Expand Down Expand Up @@ -197,7 +248,7 @@ where
high_rx: Some(queues.high_priority_rx),
low_rx: Some(queues.low_priority_rx),
response,
multi_packet: None,
multi_packet: MultiPacketContext::new(),
shutdown,
counter: Some(counter),
hooks,
Expand All @@ -221,7 +272,7 @@ where
/// Set or replace the current streaming response.
pub fn set_response(&mut self, stream: Option<FrameStream<F, E>>) {
debug_assert!(
self.multi_packet.is_none(),
!self.multi_packet.is_active(),
concat!(
"ConnectionActor invariant violated: cannot set response while a ",
"multi_packet channel is active"
Expand All @@ -239,7 +290,78 @@ where
"response stream is active"
),
);
self.multi_packet = channel;
let stamp = if channel.is_some() {
MultiPacketStamp::Enabled(None)
} else {
MultiPacketStamp::Disabled
};
self.multi_packet.install(channel, stamp);
}

/// Set or replace the current multi-packet response channel and stamp correlation identifiers.
///
/// # Examples
///
/// ```no_run
/// # use tokio::sync::mpsc;
/// # use tokio_util::sync::CancellationToken;
/// # use wireframe::{ConnectionActor, push::PushQueues};
/// # let (queues, handle) = PushQueues::<u8>::builder()
/// # .high_capacity(1)
/// # .low_capacity(1)
/// # .build()
/// # .expect("failed to build PushQueues");
/// # let shutdown = CancellationToken::new();
/// # let mut actor = ConnectionActor::new(queues, handle, None, shutdown);
/// # let (_tx, rx) = mpsc::channel(4);
/// actor.set_multi_packet_with_correlation(Some(rx), Some(7));
/// ```
pub fn set_multi_packet_with_correlation(
&mut self,
channel: Option<mpsc::Receiver<F>>,
correlation_id: Option<u64>,
) {
debug_assert!(
self.response.is_none(),
concat!(
"ConnectionActor invariant violated: cannot set multi_packet while a ",
"response stream is active"
),
);
let stamp = if channel.is_some() {
MultiPacketStamp::Enabled(correlation_id)
} else {
MultiPacketStamp::Disabled
};
self.multi_packet.install(channel, stamp);
}

fn clear_multi_packet(&mut self) { self.multi_packet.clear(); }

fn apply_multi_packet_correlation(&mut self, frame: &mut F) {
match self.multi_packet.stamp() {
MultiPacketStamp::Enabled(Some(expected)) => {
frame.set_correlation_id(Some(expected));
debug_assert_eq!(
frame.correlation_id(),
Some(expected),
"multi-packet frame correlation mismatch: expected={:?}, got={:?}",
Some(expected),
frame.correlation_id(),
);
}
MultiPacketStamp::Enabled(None) => {
frame.set_correlation_id(None);
debug_assert!(
frame.correlation_id().is_none(),
"multi-packet frame correlation unexpectedly present: got={:?}",
frame.correlation_id(),
);
}
MultiPacketStamp::Disabled => {
unreachable!("multi-packet correlation invoked without configuration");
}
}
}

/// Replace the low-priority queue used for tests.
Expand Down Expand Up @@ -271,11 +393,11 @@ where
}

debug_assert!(
usize::from(self.response.is_some()) + usize::from(self.multi_packet.is_some()) <= 1,
usize::from(self.response.is_some()) + usize::from(self.multi_packet.is_active()) <= 1,
"ConnectionActor invariant violated: at most one of response or multi_packet may be \
active"
);
let mut state = ActorState::new(self.response.is_some(), self.multi_packet.is_some());
let mut state = ActorState::new(self.response.is_some(), self.multi_packet.is_active());

while !state.is_done() {
self.poll_sources(&mut state, out).await?;
Expand All @@ -302,22 +424,17 @@ where
async fn next_event(&mut self, state: &ActorState) -> Event<F, E> {
let high_available = self.high_rx.is_some();
let low_available = self.low_rx.is_some();
let multi_available = self.multi_packet.is_some() && !state.is_shutting_down();
let multi_available = self.multi_packet.is_active() && !state.is_shutting_down();
let resp_available = self.response.is_some() && !state.is_shutting_down();

tokio::select! {
biased;

() = Self::await_shutdown(self.shutdown.clone()), if state.is_active() => Event::Shutdown,

res = Self::poll_queue(self.high_rx.as_mut()), if high_available => Event::High(res),

res = Self::poll_queue(self.low_rx.as_mut()), if low_available => Event::Low(res),

res = Self::poll_queue(self.multi_packet.as_mut()), if multi_available => Event::MultiPacket(res),

res = Self::poll_queue(self.multi_packet.channel_mut()), if multi_available => Event::MultiPacket(res),
res = Self::poll_response(self.response.as_mut()), if resp_available => Event::Response(res),

else => Event::Idle,
}
}
Expand Down Expand Up @@ -372,7 +489,16 @@ where
let DrainContext { out, state } = ctx;
match res {
Some(frame) => {
self.process_frame_with_hooks_and_metrics(frame, out);
match kind {
QueueKind::Multi
if matches!(self.multi_packet.stamp(), MultiPacketStamp::Enabled(_)) =>
{
self.emit_multi_packet_frame(frame, out);
}
_ => {
self.process_frame_with_hooks_and_metrics(frame, out);
}
}
match kind {
QueueKind::High => self.after_high(out, state),
QueueKind::Low | QueueKind::Multi => self.after_low(),
Expand Down Expand Up @@ -403,10 +529,16 @@ where
self.process_queue(QueueKind::Multi, res, DrainContext { out, state });
}

fn emit_multi_packet_frame(&mut self, frame: F, out: &mut Vec<F>) {
let mut frame = frame;
self.apply_multi_packet_correlation(&mut frame);
self.process_frame_with_hooks_and_metrics(frame, out);
}

/// Handle a closed multi-packet channel by emitting the protocol terminator and notifying
/// hooks.
fn handle_multi_packet_closed(&mut self, state: &mut ActorState, out: &mut Vec<F>) {
let rx = self.multi_packet.take();
let rx = self.multi_packet.take_channel();
self.handle_multi_packet_closed_with(rx, state, out);
}

Expand All @@ -428,9 +560,10 @@ where
}
state.mark_closed();
if let Some(frame) = self.hooks.stream_end_frame(&mut self.ctx) {
self.process_frame_with_hooks_and_metrics(frame, out);
self.emit_multi_packet_frame(frame, out);
self.after_low();
}
self.clear_multi_packet();
self.hooks.on_command_end(&mut self.ctx);
}

Expand Down Expand Up @@ -480,9 +613,10 @@ where
if let Some(rx) = &mut self.low_rx {
rx.close();
}
if let Some(mut rx) = self.multi_packet.take() {
if let Some(mut rx) = self.multi_packet.take_channel() {
rx.close();
state.mark_closed();
self.clear_multi_packet();
}
if self.response.take().is_some() {
state.mark_closed();
Expand Down Expand Up @@ -522,16 +656,10 @@ where
fn try_opportunistic_drain(&mut self, kind: QueueKind, ctx: DrainContext<'_, F>) -> bool {
let DrainContext { out, state } = ctx;
match kind {
QueueKind::High => {
debug_assert!(
false,
concat!(
"try_opportunistic_drain(High) is unsupported; ",
"High is handled by biased polling",
),
);
false
}
QueueKind::High => unreachable!(concat!(
"try_opportunistic_drain(High) is unsupported; ",
"High is handled by biased polling",
)),
QueueKind::Low => {
let res = match self.low_rx.as_mut() {
Some(receiver) => receiver.try_recv(),
Expand All @@ -552,23 +680,21 @@ where
}
}
QueueKind::Multi => {
let Some(mut rx) = self.multi_packet.take() else {
return false;
let result = match self.multi_packet.channel_mut() {
Some(rx) => rx.try_recv(),
None => return false,
};

match rx.try_recv() {
match result {
Ok(frame) => {
self.process_frame_with_hooks_and_metrics(frame, out);
self.emit_multi_packet_frame(frame, out);
self.after_low();
self.multi_packet = Some(rx);
true
}
Err(TryRecvError::Empty) => {
self.multi_packet = Some(rx);
false
}
Err(TryRecvError::Empty) => false,
Err(TryRecvError::Disconnected) => {
self.handle_multi_packet_closed_with(Some(rx), state, out);
let rx = self.multi_packet.take_channel();
self.handle_multi_packet_closed_with(rx, state, out);
false
}
}
Expand Down
Loading
Loading