diff --git a/src/app.rs b/src/app.rs index abb674ea..fd7cd6fb 100644 --- a/src/app.rs +++ b/src/app.rs @@ -17,7 +17,7 @@ use bytes::BytesMut; use tokio::io::{self, AsyncWrite, AsyncWriteExt}; use crate::{ - frame::{FrameProcessor, LengthPrefixedProcessor}, + frame::{FrameProcessor, LengthFormat, LengthPrefixedProcessor}, message::Message, serializer::{BincodeSerializer, Serializer}, }; @@ -147,10 +147,10 @@ where S: Serializer + Default, C: Send + 'static, { - frame_processor: Box::new(LengthPrefixedProcessor::default()), /// - /// Initialises empty routes, services, middleware, and application data. Sets the - /// default frame processor and serializer, with no connection lifecycle hooks. + /// Initialises empty routes, services, middleware, and application data. + /// Sets the default frame processor and serializer, with no connection + /// lifecycle hooks. fn default() -> Self { Self { routes: HashMap::new(), diff --git a/src/frame.rs b/src/frame.rs index e8b531f4..97e31dc5 100644 --- a/src/frame.rs +++ b/src/frame.rs @@ -25,35 +25,69 @@ pub struct LengthFormat { } impl LengthFormat { - if bytes.len() < self.bytes { - return Err(io::Error::new( - io::ErrorKind::UnexpectedEof, - "length prefix truncated", - )); - } - if !matches!(self.bytes, 1 | 2 | 4 | 8) { - return Err(io::Error::new( - io::ErrorKind::InvalidInput, - "unsupported length prefix size", - )); - } + /// Creates a new `LengthFormat` with the specified number of bytes and + /// endianness for the length prefix. + /// + /// # Parameters + /// - `bytes`: The number of bytes used for the length prefix. + /// - `endianness`: The byte order for encoding and decoding the length prefix. + /// + /// # Returns + /// A `LengthFormat` configured with the given size and endianness. + #[must_use] + pub const fn new(bytes: usize, endianness: Endianness) -> Self { Self { bytes, endianness } } - let mut slice = &bytes[..self.bytes]; - let len = match self.endianness { - Endianness::Big => slice.get_uint(self.bytes), - Endianness::Little => slice.get_uint_le(self.bytes), - if !matches!(self.bytes, 1 | 2 | 4 | 8) { - return Err(io::Error::new( - io::ErrorKind::InvalidInput, - "unsupported length prefix size", - )); - } + /// Creates a `LengthFormat` for a 2-byte big-endian length prefix. + #[must_use] + pub const fn u16_be() -> Self { Self::new(2, Endianness::Big) } + + /// Creates a `LengthFormat` for a 2-byte little-endian length prefix. + #[must_use] + pub const fn u16_le() -> Self { Self::new(2, Endianness::Little) } + + /// Creates a `LengthFormat` for a 4-byte big-endian length prefix. + #[must_use] + pub const fn u32_be() -> Self { Self::new(4, Endianness::Big) } + + /// Creates a `LengthFormat` for a 4-byte little-endian length prefix. + #[must_use] + pub const fn u32_le() -> Self { Self::new(4, Endianness::Little) } - let len_u64 = u64::try_from(len) - .map_err(|_| io::Error::new(io::ErrorKind::InvalidInput, "frame too large"))?; - match self.endianness { - Endianness::Big => dst.put_uint(len_u64, self.bytes), - Endianness::Little => dst.put_uint_le(len_u64, self.bytes), + /// Reads a length prefix from a byte slice according to the configured prefix size and + /// endianness. + /// + /// # Parameters + /// - `bytes`: The byte slice containing the length prefix. Must be at least as long as the + /// configured prefix size. + /// + /// # Returns + /// The decoded length as a `usize` if successful. + /// + /// # Errors + /// Returns an error if the prefix size is unsupported or if the decoded length does not fit in + /// a `usize`. + fn read_len(&self, bytes: &[u8]) -> io::Result { + let len = match (self.bytes, self.endianness) { + (1, _) => u64::from(u8::from_ne_bytes([bytes[0]])), + (2, Endianness::Big) => u64::from(u16::from_be_bytes([bytes[0], bytes[1]])), + (2, Endianness::Little) => u64::from(u16::from_le_bytes([bytes[0], bytes[1]])), + (4, Endianness::Big) => { + u64::from(u32::from_be_bytes([bytes[0], bytes[1], bytes[2], bytes[3]])) + } + (4, Endianness::Little) => { + u64::from(u32::from_le_bytes([bytes[0], bytes[1], bytes[2], bytes[3]])) + } + (8, Endianness::Big) => u64::from_be_bytes([ + bytes[0], bytes[1], bytes[2], bytes[3], bytes[4], bytes[5], bytes[6], bytes[7], + ]), + (8, Endianness::Little) => u64::from_le_bytes([ + bytes[0], bytes[1], bytes[2], bytes[3], bytes[4], bytes[5], bytes[6], bytes[7], + ]), + _ => { + return Err(io::Error::new( + io::ErrorKind::InvalidInput, + "unsupported length prefix size", + )); } }; usize::try_from(len).map_err(|_| io::Error::other("frame too large")) diff --git a/tests/response.rs b/tests/response.rs index 7d4fa0dc..069334d0 100644 --- a/tests/response.rs +++ b/tests/response.rs @@ -55,7 +55,8 @@ async fn send_response_encodes_and_frames() { } #[tokio::test] -/// Tests that decoding with an incomplete length prefix header returns `None` and does not consume any bytes from the buffer. +/// Tests that decoding with an incomplete length prefix header returns `None` and does not consume +/// any bytes from the buffer. /// /// This ensures that the decoder waits for the full header before attempting to decode a frame. async fn length_prefixed_decode_requires_complete_header() { @@ -94,6 +95,16 @@ impl tokio::io::AsyncWrite for FailingWriter { _: &mut std::task::Context<'_>, ) -> std::task::Poll> { std::task::Poll::Ready(Ok(())) + } + + fn poll_shutdown( + self: std::pin::Pin<&mut Self>, + _: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + std::task::Poll::Ready(Ok(())) + } +} + #[rstest] #[case(LengthFormat::u16_be(), vec![1, 2, 3, 4], vec![0x00, 0x04])] #[case(LengthFormat::u32_le(), vec![9, 8, 7], vec![3, 0, 0, 0])] @@ -102,8 +113,15 @@ fn custom_length_roundtrip( #[case] frame: Vec, #[case] prefix: Vec, ) { + let processor = LengthPrefixedProcessor::new(fmt); + let mut buf = BytesMut::new(); + processor.encode(&frame, &mut buf).unwrap(); assert_eq!(&buf[..prefix.len()], &prefix[..]); + let decoded = processor.decode(&mut buf).unwrap().unwrap(); + assert_eq!(decoded, frame); +} +#[tokio::test] async fn send_response_propagates_write_error() { let app = WireframeApp::new() .unwrap() @@ -111,6 +129,11 @@ async fn send_response_propagates_write_error() { let mut writer = FailingWriter; let err = app + .send_response(&mut writer, &TestResp(3)) + .await + .expect_err("expected error"); + assert!(matches!(err, wireframe::app::SendError::Io(_))); +} #[test] fn encode_fails_for_unsupported_prefix_size() { @@ -131,11 +154,6 @@ fn decode_fails_for_unsupported_prefix_size() { let err = processor.decode(&mut buf).expect_err("expected error"); assert_eq!(err.kind(), std::io::ErrorKind::InvalidInput); } - .send_response(&mut writer, &TestResp(3)) - .await - .expect_err("expected error"); - assert!(matches!(err, wireframe::app::SendError::Io(_))); -} #[tokio::test] /// Tests that `send_response` returns a serialization error when encoding fails. @@ -150,34 +168,3 @@ async fn send_response_returns_encode_error() { .expect_err("expected error"); assert!(matches!(err, wireframe::app::SendError::Serialize(_))); } - -#[test] -/// Tests roundtrip encoding and decoding of a frame using a two-byte big-endian length prefix. -/// -/// Verifies that a frame encoded with a `LengthPrefixedProcessor` configured for a 2-byte -/// big-endian length format can be correctly decoded back to its original contents. -fn custom_two_byte_big_endian_roundtrip() { - let fmt = LengthFormat::u16_be(); - let processor = LengthPrefixedProcessor::new(fmt); - let frame = vec![1, 2, 3, 4]; - let mut buf = BytesMut::new(); - processor.encode(&frame, &mut buf).unwrap(); - let decoded = processor.decode(&mut buf).unwrap().unwrap(); - assert_eq!(decoded, frame); -} - -#[test] -/// Tests roundtrip encoding and decoding of a frame using a four-byte little-endian length prefix. -/// -/// Verifies that the encoded buffer contains the correct little-endian length prefix and that -/// decoding restores the original frame. -fn custom_four_byte_little_endian_roundtrip() { - let fmt = LengthFormat::u32_le(); - let processor = LengthPrefixedProcessor::new(fmt); - let frame = vec![9, 8, 7]; - let mut buf = BytesMut::new(); - processor.encode(&frame, &mut buf).unwrap(); - assert_eq!(&buf[..4], u32::try_from(frame.len()).unwrap().to_le_bytes()); - let decoded = processor.decode(&mut buf).unwrap().unwrap(); - assert_eq!(decoded, frame); -}