Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions src/app.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ use bytes::BytesMut;
use tokio::io::{self, AsyncWrite, AsyncWriteExt};

use crate::{
frame::{FrameProcessor, LengthPrefixedProcessor},
frame::{FrameProcessor, LengthFormat, LengthPrefixedProcessor},
message::Message,
serializer::{BincodeSerializer, Serializer},
};
Expand Down Expand Up @@ -147,10 +147,10 @@ where
S: Serializer + Default,
C: Send + 'static,
{
frame_processor: Box::new(LengthPrefixedProcessor::default()),
///
/// Initialises empty routes, services, middleware, and application data. Sets the
/// default frame processor and serializer, with no connection lifecycle hooks.
/// Initialises empty routes, services, middleware, and application data.
/// Sets the default frame processor and serializer, with no connection
/// lifecycle hooks.
fn default() -> Self {
Self {
routes: HashMap::new(),
Expand Down
88 changes: 61 additions & 27 deletions src/frame.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,35 +25,69 @@ pub struct LengthFormat {
}

impl LengthFormat {
if bytes.len() < self.bytes {
return Err(io::Error::new(
io::ErrorKind::UnexpectedEof,
"length prefix truncated",
));
}
if !matches!(self.bytes, 1 | 2 | 4 | 8) {
return Err(io::Error::new(
io::ErrorKind::InvalidInput,
"unsupported length prefix size",
));
}
/// Creates a new `LengthFormat` with the specified number of bytes and
/// endianness for the length prefix.
///
/// # Parameters
/// - `bytes`: The number of bytes used for the length prefix.
/// - `endianness`: The byte order for encoding and decoding the length prefix.
///
/// # Returns
/// A `LengthFormat` configured with the given size and endianness.
#[must_use]
pub const fn new(bytes: usize, endianness: Endianness) -> Self { Self { bytes, endianness } }

let mut slice = &bytes[..self.bytes];
let len = match self.endianness {
Endianness::Big => slice.get_uint(self.bytes),
Endianness::Little => slice.get_uint_le(self.bytes),
if !matches!(self.bytes, 1 | 2 | 4 | 8) {
return Err(io::Error::new(
io::ErrorKind::InvalidInput,
"unsupported length prefix size",
));
}
/// Creates a `LengthFormat` for a 2-byte big-endian length prefix.
#[must_use]
pub const fn u16_be() -> Self { Self::new(2, Endianness::Big) }

/// Creates a `LengthFormat` for a 2-byte little-endian length prefix.
#[must_use]
pub const fn u16_le() -> Self { Self::new(2, Endianness::Little) }

/// Creates a `LengthFormat` for a 4-byte big-endian length prefix.
#[must_use]
pub const fn u32_be() -> Self { Self::new(4, Endianness::Big) }

/// Creates a `LengthFormat` for a 4-byte little-endian length prefix.
#[must_use]
pub const fn u32_le() -> Self { Self::new(4, Endianness::Little) }

let len_u64 = u64::try_from(len)
.map_err(|_| io::Error::new(io::ErrorKind::InvalidInput, "frame too large"))?;
match self.endianness {
Endianness::Big => dst.put_uint(len_u64, self.bytes),
Endianness::Little => dst.put_uint_le(len_u64, self.bytes),
/// Reads a length prefix from a byte slice according to the configured prefix size and
/// endianness.
///
/// # Parameters
/// - `bytes`: The byte slice containing the length prefix. Must be at least as long as the
/// configured prefix size.
///
/// # Returns
/// The decoded length as a `usize` if successful.
///
/// # Errors
/// Returns an error if the prefix size is unsupported or if the decoded length does not fit in
/// a `usize`.
fn read_len(&self, bytes: &[u8]) -> io::Result<usize> {
let len = match (self.bytes, self.endianness) {
(1, _) => u64::from(u8::from_ne_bytes([bytes[0]])),
(2, Endianness::Big) => u64::from(u16::from_be_bytes([bytes[0], bytes[1]])),
(2, Endianness::Little) => u64::from(u16::from_le_bytes([bytes[0], bytes[1]])),
(4, Endianness::Big) => {
u64::from(u32::from_be_bytes([bytes[0], bytes[1], bytes[2], bytes[3]]))
}
(4, Endianness::Little) => {
u64::from(u32::from_le_bytes([bytes[0], bytes[1], bytes[2], bytes[3]]))
}
(8, Endianness::Big) => u64::from_be_bytes([
bytes[0], bytes[1], bytes[2], bytes[3], bytes[4], bytes[5], bytes[6], bytes[7],
]),
(8, Endianness::Little) => u64::from_le_bytes([
bytes[0], bytes[1], bytes[2], bytes[3], bytes[4], bytes[5], bytes[6], bytes[7],
]),
_ => {
return Err(io::Error::new(
io::ErrorKind::InvalidInput,
"unsupported length prefix size",
));
}
};
usize::try_from(len).map_err(|_| io::Error::other("frame too large"))
Expand Down
61 changes: 24 additions & 37 deletions tests/response.rs
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,8 @@ async fn send_response_encodes_and_frames() {
}

#[tokio::test]
/// Tests that decoding with an incomplete length prefix header returns `None` and does not consume any bytes from the buffer.
/// Tests that decoding with an incomplete length prefix header returns `None` and does not consume
/// any bytes from the buffer.
///
/// This ensures that the decoder waits for the full header before attempting to decode a frame.
async fn length_prefixed_decode_requires_complete_header() {
Expand Down Expand Up @@ -94,6 +95,16 @@ impl tokio::io::AsyncWrite for FailingWriter {
_: &mut std::task::Context<'_>,
) -> std::task::Poll<std::io::Result<()>> {
std::task::Poll::Ready(Ok(()))
}

fn poll_shutdown(
self: std::pin::Pin<&mut Self>,
_: &mut std::task::Context<'_>,
) -> std::task::Poll<std::io::Result<()>> {
std::task::Poll::Ready(Ok(()))
}
}

#[rstest]
#[case(LengthFormat::u16_be(), vec![1, 2, 3, 4], vec![0x00, 0x04])]
#[case(LengthFormat::u32_le(), vec![9, 8, 7], vec![3, 0, 0, 0])]
Expand All @@ -102,15 +113,27 @@ fn custom_length_roundtrip(
#[case] frame: Vec<u8>,
#[case] prefix: Vec<u8>,
) {
let processor = LengthPrefixedProcessor::new(fmt);
let mut buf = BytesMut::new();
processor.encode(&frame, &mut buf).unwrap();
assert_eq!(&buf[..prefix.len()], &prefix[..]);
let decoded = processor.decode(&mut buf).unwrap().unwrap();
assert_eq!(decoded, frame);
}

#[tokio::test]
async fn send_response_propagates_write_error() {
let app = WireframeApp::new()
.unwrap()
.frame_processor(LengthPrefixedProcessor::default());

let mut writer = FailingWriter;
let err = app
.send_response(&mut writer, &TestResp(3))
.await
.expect_err("expected error");
assert!(matches!(err, wireframe::app::SendError::Io(_)));
}

#[test]
fn encode_fails_for_unsupported_prefix_size() {
Expand All @@ -131,11 +154,6 @@ fn decode_fails_for_unsupported_prefix_size() {
let err = processor.decode(&mut buf).expect_err("expected error");
assert_eq!(err.kind(), std::io::ErrorKind::InvalidInput);
}
.send_response(&mut writer, &TestResp(3))
.await
.expect_err("expected error");
assert!(matches!(err, wireframe::app::SendError::Io(_)));
}

#[tokio::test]
/// Tests that `send_response` returns a serialization error when encoding fails.
Expand All @@ -150,34 +168,3 @@ async fn send_response_returns_encode_error() {
.expect_err("expected error");
assert!(matches!(err, wireframe::app::SendError::Serialize(_)));
}

#[test]
/// Tests roundtrip encoding and decoding of a frame using a two-byte big-endian length prefix.
///
/// Verifies that a frame encoded with a `LengthPrefixedProcessor` configured for a 2-byte
/// big-endian length format can be correctly decoded back to its original contents.
fn custom_two_byte_big_endian_roundtrip() {
let fmt = LengthFormat::u16_be();
let processor = LengthPrefixedProcessor::new(fmt);
let frame = vec![1, 2, 3, 4];
let mut buf = BytesMut::new();
processor.encode(&frame, &mut buf).unwrap();
let decoded = processor.decode(&mut buf).unwrap().unwrap();
assert_eq!(decoded, frame);
}

#[test]
/// Tests roundtrip encoding and decoding of a frame using a four-byte little-endian length prefix.
///
/// Verifies that the encoded buffer contains the correct little-endian length prefix and that
/// decoding restores the original frame.
fn custom_four_byte_little_endian_roundtrip() {
let fmt = LengthFormat::u32_le();
let processor = LengthPrefixedProcessor::new(fmt);
let frame = vec![9, 8, 7];
let mut buf = BytesMut::new();
processor.encode(&frame, &mut buf).unwrap();
assert_eq!(&buf[..4], u32::try_from(frame.len()).unwrap().to_le_bytes());
let decoded = processor.decode(&mut buf).unwrap().unwrap();
assert_eq!(decoded, frame);
}