Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions docs/wireframe-testing-crate.md
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,17 @@ let bytes = drive_with_bincode(app, Ping(1)).await.unwrap();
assert_eq!(bytes, [0, 1]);
```

### Helper macros

Two small macros, `push_expect!` and `recv_expect!`, reduce boilerplate in test
code. They await a future and panic with a message including the call site when
the future resolves to an error.

```rust
push_expect!(handle.push_high_priority(42));
let (_, frame) = recv_expect!(queues.recv());
```

## Example Usage

```rust
Expand Down
7 changes: 0 additions & 7 deletions examples/metadata_routing.rs
Original file line number Diff line number Diff line change
Expand Up @@ -59,13 +59,6 @@ impl FrameMetadata for HeaderSerializer {
#[derive(bincode::Decode, bincode::Encode)]
struct Ping;

#[derive(bincode::Decode, bincode::Encode)]
#[expect(
dead_code,
reason = "placeholder for demonstration of metadata routing"
)]
struct Pong;

#[tokio::main]
async fn main() -> io::Result<()> {
let app = WireframeApp::new()
Expand Down
31 changes: 17 additions & 14 deletions src/connection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,10 @@ use std::{
};

use futures::StreamExt;
use tokio::{sync::mpsc, time::Duration};
use tokio::{
sync::mpsc::{self, error::TryRecvError},
time::Duration,
};
use tokio_util::sync::CancellationToken;
use tracing::{info, info_span, warn};

Expand Down Expand Up @@ -393,19 +396,19 @@ where
fn after_high(&mut self, out: &mut Vec<F>, state: &mut ActorState) {
self.fairness.after_high();

if self.fairness.should_yield()
&& let Some(rx) = &mut self.low_rx
{
match rx.try_recv() {
Ok(mut frame) => {
self.hooks.before_send(&mut frame, &mut self.ctx);
out.push(frame);
self.after_low();
}
Err(mpsc::error::TryRecvError::Empty) => {}
Err(mpsc::error::TryRecvError::Disconnected) => {
self.low_rx = None;
state.mark_closed();
if self.fairness.should_yield() {
let res = self.low_rx.as_mut().map(mpsc::Receiver::try_recv);
if let Some(res) = res {
match res {
Ok(mut frame) => {
self.hooks.before_send(&mut frame, &mut self.ctx);
out.push(frame);
self.after_low();
}
Err(TryRecvError::Empty) => {}
Err(TryRecvError::Disconnected) => {
Self::handle_closed_receiver(&mut self.low_rx, state);
}
}
}
}
Expand Down
7 changes: 5 additions & 2 deletions src/frame/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,10 @@ fn bytes_to_u64_ok(
#[case] endianness: Endianness,
#[case] expected: u64,
) {
assert_eq!(bytes_to_u64(&bytes, size, endianness).unwrap(), expected);
assert_eq!(
bytes_to_u64(&bytes, size, endianness).expect("failed to convert"),
expected
);
}

#[rstest]
Expand All @@ -42,7 +45,7 @@ fn u64_to_bytes_ok(
#[case] expected: Vec<u8>,
) {
let mut buf = [0u8; 8];
let written = u64_to_bytes(value, size, endianness, &mut buf).unwrap();
let written = u64_to_bytes(value, size, endianness, &mut buf).expect("failed to encode u64");
assert_eq!(written, size);
assert_eq!(&buf[..written], expected.as_slice());
}
Expand Down
13 changes: 9 additions & 4 deletions src/push.rs
Original file line number Diff line number Diff line change
Expand Up @@ -309,7 +309,12 @@ impl<F: FrameLike> PushQueues<F> {
high_capacity: usize,
low_capacity: usize,
) -> (Self, PushHandle<F>) {
Self::bounded_with_rate_dlq(high_capacity, low_capacity, None, None).unwrap()
// `bounded_with_rate_dlq` only fails when given an invalid rate. Passing
// `None` disables rate limiting entirely so the call is infallible. The
// debug assertion guards against future regressions.
let result = Self::bounded_with_rate_dlq(high_capacity, low_capacity, None, None);
debug_assert!(result.is_ok(), "bounded_no_rate_limit should not fail");
result.expect("bounded_no_rate_limit should not fail")
}

/// Create queues with a custom rate limit in pushes per second.
Expand Down Expand Up @@ -382,9 +387,9 @@ impl<F: FrameLike> PushQueues<F> {
rate: Option<usize>,
dlq: Option<mpsc::Sender<F>>,
) -> Result<(Self, PushHandle<F>), PushConfigError> {
if let Some(r) = rate
&& (r == 0 || r > MAX_PUSH_RATE)
{
if let Some(r) = rate.filter(|r| *r == 0 || *r > MAX_PUSH_RATE) {
// Reject unsupported rates early to avoid building queues that cannot
// be used. The bounds prevent runaway resource consumption.
return Err(PushConfigError::InvalidRate(r));
}
let (high_tx, high_rx) = mpsc::channel(high_capacity);
Expand Down
85 changes: 54 additions & 31 deletions src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -507,10 +507,15 @@ async fn process_stream<F, T>(
let peer_addr = stream.peer_addr().ok();
match read_preamble::<_, T>(&mut stream).await {
Ok((preamble, leftover)) => {
if let Some(handler) = on_success.as_ref()
&& let Err(e) = handler(&preamble, &mut stream).await
{
tracing::error!(error = ?e, ?peer_addr, "preamble callback error");
if let Some(handler) = on_success.as_ref() {
match handler(&preamble, &mut stream).await {
Ok(()) => {}
Err(e) => {
// Log and continue processing if the callback fails; connection
// handling should not halt due to diagnostic hooks.
tracing::error!(error = ?e, ?peer_addr, "preamble callback error");
}
}
}
let stream = RewindStream::new(leftover, stream);
// Hand the connection to the application for processing.
Expand Down Expand Up @@ -556,14 +561,6 @@ mod tests {
message: String,
}

/// Test helper preamble carrying no data.
#[derive(Debug, Clone, PartialEq, Encode, Decode)]
#[expect(
dead_code,
reason = "used only in doctest to illustrate an empty preamble"
)]
struct EmptyPreamble;

#[fixture]
fn factory() -> impl Fn() -> WireframeApp + Send + Sync + Clone + 'static {
|| WireframeApp::default()
Expand All @@ -572,8 +569,11 @@ mod tests {
#[fixture]
fn free_port() -> SocketAddr {
let addr = SocketAddr::new(Ipv4Addr::LOCALHOST.into(), 0);
let listener = std::net::TcpListener::bind(addr).unwrap();
listener.local_addr().unwrap()
let listener =
std::net::TcpListener::bind(addr).expect("failed to bind free port listener");
listener
.local_addr()
.expect("failed to read free port listener address")
}

fn bind_server<F>(factory: F, addr: SocketAddr) -> WireframeServer<F>
Expand Down Expand Up @@ -649,7 +649,9 @@ mod tests {
free_port: SocketAddr,
) {
let server = bind_server(factory, free_port);
let bound_addr = server.local_addr().unwrap();
let bound_addr = server
.local_addr()
.expect("bound server should return local address");
assert_eq!(bound_addr.ip(), free_port.ip());
}

Expand Down Expand Up @@ -681,7 +683,10 @@ mod tests {
let server = bind_server(factory, free_port);
let local_addr = server.local_addr();
assert!(local_addr.is_some());
assert_eq!(local_addr.unwrap().ip(), free_port.ip());
assert_eq!(
local_addr.expect("local address missing").ip(),
free_port.ip()
);
}

#[rstest]
Expand Down Expand Up @@ -800,7 +805,7 @@ mod tests {
.await;

assert!(result.is_ok());
assert!(result.unwrap().is_ok());
assert!(result.expect("server run timed out").is_ok());
}

#[rstest]
Expand All @@ -827,7 +832,7 @@ mod tests {
let elapsed = start.elapsed();

assert!(result.is_ok());
assert!(result.unwrap().is_ok());
assert!(result.expect("server run timed out").is_ok());
assert!(elapsed < Duration::from_millis(500));
}

Expand Down Expand Up @@ -862,7 +867,7 @@ mod tests {
.await;

assert!(result.is_ok());
assert!(result.unwrap().is_ok());
assert!(result.expect("server run timed out").is_ok());
}

#[rstest]
Expand Down Expand Up @@ -903,7 +908,11 @@ mod tests {
) {
let token = CancellationToken::new();
let tracker = TaskTracker::new();
let listener = Arc::new(TcpListener::bind("127.0.0.1:0").await.unwrap());
let listener = Arc::new(
TcpListener::bind("127.0.0.1:0")
.await
.expect("failed to bind test listener"),
);

tracker.spawn(accept_loop::<_, ()>(
listener,
Expand Down Expand Up @@ -944,15 +953,18 @@ mod tests {
let addr1 = free_port;
let addr2 = {
let addr = SocketAddr::new(Ipv4Addr::LOCALHOST.into(), 0);
let listener = std::net::TcpListener::bind(addr).unwrap();
listener.local_addr().unwrap()
let listener =
std::net::TcpListener::bind(addr).expect("failed to bind second listener");
listener
.local_addr()
.expect("failed to get second listener address")
};

let server = server.bind(addr1).expect("Failed to bind first address");
let first_local_addr = server.local_addr().unwrap();
let first_local_addr = server.local_addr().expect("first bound address missing");

let server = server.bind(addr2).expect("Failed to bind second address");
let second_local_addr = server.local_addr().unwrap();
let second_local_addr = server.local_addr().expect("second bound address missing");

assert_ne!(first_local_addr.port(), second_local_addr.port());
assert_eq!(second_local_addr.ip(), addr2.ip());
Expand Down Expand Up @@ -1032,13 +1044,19 @@ mod tests {
let app_factory = move || {
factory()
.on_connection_setup(|| async { panic!("boom") })
.unwrap()
.expect("failed to install panic setup callback")
};
let server = WireframeServer::new(app_factory)
.workers(1)
.bind("127.0.0.1:0".parse().unwrap())
.bind(
"127.0.0.1:0"
.parse()
.expect("hard-coded socket address must be valid"),
)
.expect("bind");
let addr = server.local_addr().unwrap();
let addr = server
.local_addr()
.expect("failed to retrieve server address");

let (tx, rx) = oneshot::channel();
let handle = tokio::spawn(async move {
Expand All @@ -1047,22 +1065,27 @@ mod tests {
let _ = rx.await;
})
.await
.unwrap();
.expect("server run failed");
});

let first = TcpStream::connect(addr)
.await
.expect("first connection should succeed");
let peer_addr = first.local_addr().expect("first connection peer address");
first.writable().await.unwrap();
first.try_write(&[0; 8]).unwrap();
first
.writable()
.await
.expect("connection not writable after connect");
first
.try_write(&[0; 8])
.expect("failed to write dummy bytes");
drop(first);
TcpStream::connect(addr)
.await
.expect("second connection should succeed after panic");

let _ = tx.send(());
handle.await.unwrap();
handle.await.expect("server join error");

tokio::task::yield_now().await;

Expand Down
5 changes: 3 additions & 2 deletions tests/app_data.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,8 @@ fn shared_state_extractor_returns_data(
mut empty_payload: Payload<'static>,
) {
request.insert_state(5u32);
let extracted = SharedState::<u32>::from_message_request(&request, &mut empty_payload).unwrap();
let extracted = SharedState::<u32>::from_message_request(&request, &mut empty_payload)
.expect("failed to extract shared state");
assert_eq!(*extracted, 5);
}

Expand All @@ -42,6 +43,6 @@ fn missing_shared_state_returns_error(
) {
let err = SharedState::<u32>::from_message_request(&request, &mut empty_payload)
.err()
.unwrap();
.expect("missing state error expected");
assert!(matches!(err, ExtractError::MissingState(_)));
}
2 changes: 1 addition & 1 deletion tests/async_stream.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,6 @@ async fn async_stream_frames_processed_in_order() {

let mut actor = ConnectionActor::new(queues, handle, Some(stream), shutdown);
let mut out = Vec::new();
actor.run(&mut out).await.unwrap();
actor.run(&mut out).await.expect("actor run failed");
assert_eq!(out, vec![0, 1, 2]);
}
Loading