diff --git a/.gitignore b/.gitignore index 70487d22..4b79068f 100644 --- a/.gitignore +++ b/.gitignore @@ -7,3 +7,4 @@ node_modules test-ledger dist +.idea diff --git a/pod/src/error.rs b/pod/src/error.rs index 2d83a231..03d5b70b 100644 --- a/pod/src/error.rs +++ b/pod/src/error.rs @@ -1,5 +1,8 @@ //! Error types -use solana_program_error::{ProgramError, ToStr}; +use { + solana_program_error::{ProgramError, ToStr}, + std::num::TryFromIntError, +}; /// Errors that may be returned by the spl-pod library. #[repr(u32)] @@ -22,6 +25,9 @@ pub enum PodSliceError { /// Provided byte buffer too large for expected type #[error("Provided byte buffer too large for expected type")] BufferTooLarge, + /// An integer conversion failed because the value was out of range for the target type + #[error("An integer conversion failed because the value was out of range for the target type")] + ValueOutOfRange, } impl From for ProgramError { @@ -36,6 +42,13 @@ impl ToStr for PodSliceError { PodSliceError::CalculationFailure => "Error in checked math operation", PodSliceError::BufferTooSmall => "Provided byte buffer too small for expected type", PodSliceError::BufferTooLarge => "Provided byte buffer too large for expected type", + PodSliceError::ValueOutOfRange => "An integer conversion failed because the value was out of range for the target type" } } } + +impl From for PodSliceError { + fn from(_: TryFromIntError) -> Self { + PodSliceError::ValueOutOfRange + } +} diff --git a/pod/src/lib.rs b/pod/src/lib.rs index e1aa65fd..b9a26da1 100644 --- a/pod/src/lib.rs +++ b/pod/src/lib.rs @@ -2,8 +2,10 @@ pub mod bytemuck; pub mod error; +pub mod list; pub mod option; pub mod optional_keys; +pub mod pod_length; pub mod primitives; pub mod slice; diff --git a/pod/src/list/list_trait.rs b/pod/src/list/list_trait.rs new file mode 100644 index 00000000..f3d27b70 --- /dev/null +++ b/pod/src/list/list_trait.rs @@ -0,0 +1,44 @@ +use { + crate::{list::ListView, pod_length::PodLength}, + bytemuck::Pod, + solana_program_error::ProgramError, + std::slice::Iter, +}; + +/// A trait to abstract the shared, read-only behavior +/// between `ListViewReadOnly` and `ListViewMut`. +pub trait List { + /// The type of the items stored in the list. + type Item: Pod; + /// Length prefix type used (`PodU16`, `PodU32`, …). + type Length: PodLength; + + /// Returns the number of items in the list. + fn len(&self) -> usize; + + /// Returns `true` if the list contains no items. + fn is_empty(&self) -> bool { + self.len() == 0 + } + + /// Returns the total number of items that can be stored in the list. + fn capacity(&self) -> usize; + + /// Returns a read-only slice of the items currently in the list. + fn as_slice(&self) -> &[Self::Item]; + + /// Returns a read-only iterator over the list. + fn iter(&self) -> Iter<'_, Self::Item> { + self.as_slice().iter() + } + + /// Returns the number of **bytes currently occupied** by the live elements + fn bytes_used(&self) -> Result { + ListView::::size_of(self.len()) + } + + /// Returns the number of **bytes reserved** by the entire backing buffer. + fn bytes_allocated(&self) -> Result { + ListView::::size_of(self.capacity()) + } +} diff --git a/pod/src/list/list_view.rs b/pod/src/list/list_view.rs new file mode 100644 index 00000000..1a3419ed --- /dev/null +++ b/pod/src/list/list_view.rs @@ -0,0 +1,624 @@ +//! `ListView`, a compact, zero-copy array wrapper. + +use { + crate::{ + bytemuck::{ + pod_from_bytes, pod_from_bytes_mut, pod_slice_from_bytes, pod_slice_from_bytes_mut, + }, + error::PodSliceError, + list::{list_view_mut::ListViewMut, list_view_read_only::ListViewReadOnly}, + pod_length::PodLength, + primitives::PodU32, + }, + bytemuck::Pod, + solana_program_error::ProgramError, + std::{ + marker::PhantomData, + mem::{align_of, size_of}, + ops::Range, + }, +}; + +/// An API for interpreting a raw buffer (`&[u8]`) as a variable-length collection of Pod elements. +/// +/// `ListView` provides a safe, zero-copy, `Vec`-like interface for a slice of +/// `Pod` data that resides in an external, pre-allocated `&[u8]` buffer. +/// It does not own the buffer itself, but acts as a view over it, which can be +/// read-only (`ListViewReadOnly`) or mutable (`ListViewMut`). +/// +/// This is useful in environments where allocations are restricted or expensive, +/// such as Solana programs, allowing for efficient reads and manipulation of +/// dynamic-length data structures. +/// +/// ## Memory Layout +/// +/// The structure assumes the underlying byte buffer is formatted as follows: +/// 1. **Length**: A length field of type `L` at the beginning of the buffer, +/// indicating the number of currently active elements in the collection. +/// Defaults to `PodU32`. The implementation uses padding to ensure that the +/// data is correctly aligned for any `Pod` type. +/// 2. **Padding**: Optional padding bytes to ensure proper alignment of the data. +/// 3. **Data**: The remaining part of the buffer, which is treated as a slice +/// of `T` elements. The capacity of the collection is the number of `T` +/// elements that can fit into this data portion. +pub struct ListView(PhantomData<(T, L)>); + +struct Layout { + length_range: Range, + data_range: Range, +} + +impl ListView { + /// Calculate the total byte size for a `ListView` holding `num_items`. + /// This includes the length prefix, padding, and data. + pub fn size_of(num_items: usize) -> Result { + let header_padding = Self::header_padding()?; + size_of::() + .checked_mul(num_items) + .and_then(|curr| curr.checked_add(size_of::())) + .and_then(|curr| curr.checked_add(header_padding)) + .ok_or_else(|| PodSliceError::CalculationFailure.into()) + } + + /// Unpack a read-only buffer into a `ListViewReadOnly` + pub fn unpack(buf: &[u8]) -> Result, ProgramError> { + let layout = Self::calculate_layout(buf.len())?; + + // Slice the buffer to get the length prefix and the data. + // The layout calculation provides the correct ranges, accounting for any + // padding between the length and the data. + // + // buf: [ L L L L | P P | D D D D D D D D ...] + // <-----> <------------------> + // len_bytes data_bytes + let len_bytes = &buf[layout.length_range]; + let data_bytes = &buf[layout.data_range]; + + let length = pod_from_bytes::(len_bytes)?; + let data = pod_slice_from_bytes::(data_bytes)?; + let capacity = data.len(); + + if (*length).into() > capacity { + return Err(PodSliceError::BufferTooSmall.into()); + } + + Ok(ListViewReadOnly { + length, + data, + capacity, + }) + } + + /// Unpack the mutable buffer into a mutable `ListViewMut` + pub fn unpack_mut(buf: &mut [u8]) -> Result, ProgramError> { + let view = Self::build_mut_view(buf)?; + if (*view.length).into() > view.capacity { + return Err(PodSliceError::BufferTooSmall.into()); + } + Ok(view) + } + + /// Initialize a buffer: sets `length = 0` and returns a mutable `ListViewMut`. + pub fn init(buf: &mut [u8]) -> Result, ProgramError> { + let view = Self::build_mut_view(buf)?; + *view.length = L::try_from(0)?; + Ok(view) + } + + /// Internal helper to build a mutable view without validation or initialization. + #[inline] + fn build_mut_view(buf: &mut [u8]) -> Result, ProgramError> { + let layout = Self::calculate_layout(buf.len())?; + + // Split the buffer to get the length prefix and the data. + // buf: [ L L L L | P P | D D D D D D D D ...] + // <---- head ---> <--- tail ---------> + let (header_bytes, data_bytes) = buf.split_at_mut(layout.data_range.start); + // header: [ L L L L | P P ] + // <-----> + // len_bytes + let len_bytes = &mut header_bytes[layout.length_range]; + + // Cast the bytes to typed data + let length = pod_from_bytes_mut::(len_bytes)?; + let data = pod_slice_from_bytes_mut::(data_bytes)?; + let capacity = data.len(); + + Ok(ListViewMut { + length, + data, + capacity, + }) + } + + /// Calculate the byte ranges for the length and data sections of the buffer + #[inline] + fn calculate_layout(buf_len: usize) -> Result { + let len_field_end = size_of::(); + let header_padding = Self::header_padding()?; + let data_start = len_field_end.saturating_add(header_padding); + + if buf_len < data_start { + return Err(PodSliceError::BufferTooSmall.into()); + } + + Ok(Layout { + length_range: 0..len_field_end, + data_range: data_start..buf_len, + }) + } + + /// Calculate the padding required to align the data part of the buffer. + /// + /// The goal is to ensure that the data field `T` starts at a memory offset + /// that is a multiple of its alignment requirement. + #[inline] + fn header_padding() -> Result { + // Enforce that the length prefix type `L` itself does not have alignment requirements + if align_of::() != 1 { + return Err(ProgramError::InvalidArgument); + } + + let length_size = size_of::(); + let data_align = align_of::(); + + // No padding is needed for alignments of 0 or 1 + if data_align == 0 || data_align == 1 { + return Ok(0); + } + + // Find how many bytes `length_size` extends past an alignment boundary + #[allow(clippy::arithmetic_side_effects)] + let remainder = length_size.wrapping_rem(data_align); + + // If already aligned (remainder is 0), no padding is needed. + // Otherwise, calculate the distance to the next alignment boundary. + if remainder == 0 { + Ok(0) + } else { + Ok(data_align.wrapping_sub(remainder)) + } + } +} + +#[cfg(test)] +mod tests { + use { + super::*, + crate::{ + list::List, + primitives::{PodU16, PodU32, PodU64}, + }, + bytemuck_derive::{Pod as DerivePod, Zeroable}, + }; + + #[test] + fn test_size_of_no_padding() { + // Case 1: T has align 1, so no padding is ever needed. + // 10 items * 1 byte/item + 4 bytes for length = 14 + assert_eq!(ListView::::size_of(10).unwrap(), 14); + + // Case 2: size_of is a multiple of align_of, so no padding needed. + // T = u32 (size 4, align 4), L = PodU32 (size 4). 4 % 4 == 0. + // 10 items * 4 bytes/item + 4 bytes for length = 44 + assert_eq!(ListView::::size_of(10).unwrap(), 44); + + // Case 3: 0 items. Size should just be size_of + padding. + // Padding is 0 here. + // 0 items * 4 bytes/item + 4 bytes for length = 4 + assert_eq!(ListView::::size_of(0).unwrap(), 4); + } + + #[test] + fn test_size_of_with_padding() { + // Case 1: Padding is required. + // T = u64 (size 8, align 8), L = PodU32 (size 4). + // Padding required to align data to 8 bytes is 4. (4 + 4 = 8) + // (10 items * 8 bytes/item) + 4 bytes for length + 4 bytes for padding = 88 + assert_eq!(ListView::::size_of(10).unwrap(), 88); + + #[repr(C, align(16))] + #[derive(DerivePod, Zeroable, Copy, Clone)] + struct Align16(u128); + + // Case 2: Custom struct with high alignment. + // size 16, align 16 + // L = PodU64 (size 8). + // Padding required to align data to 16 bytes is 8. (8 + 8 = 16) + // (10 items * 16 bytes/item) + 8 bytes for length + 8 bytes for padding = 176 + assert_eq!(ListView::::size_of(10).unwrap(), 176); + + // Case 3: 0 items with padding. + // Size should be size_of + padding. + // L = PodU32 (size 4), T = u64 (align 8). Padding is 4. + // Total size = 4 + 4 = 8 + assert_eq!(ListView::::size_of(0).unwrap(), 8); + } + + #[test] + fn test_size_of_overflow() { + // Case 1: Multiplication overflows. + // `size_of::() * usize::MAX` will overflow. + let err = ListView::::size_of(usize::MAX).unwrap_err(); + assert_eq!(err, PodSliceError::CalculationFailure.into()); + + // Case 2: Multiplication does not overflow, but subsequent addition does. + // `size_of::() * usize::MAX` does not overflow, but adding `size_of` will. + let err = ListView::::size_of(usize::MAX).unwrap_err(); + assert_eq!(err, PodSliceError::CalculationFailure.into()); + } + + #[test] + fn test_fails_with_non_aligned_length_type() { + // A custom `PodLength` type with an alignment of 4 + #[repr(C, align(4))] + #[derive(Debug, Copy, Clone, Zeroable, DerivePod)] + struct TestPodU32(u32); + + // Implement the traits for `PodLength` + impl From for usize { + fn from(val: TestPodU32) -> Self { + val.0 as usize + } + } + impl TryFrom for TestPodU32 { + type Error = PodSliceError; + fn try_from(val: usize) -> Result { + Ok(Self(u32::try_from(val)?)) + } + } + + let mut buf = [0u8; 100]; + + let err_size_of = ListView::::size_of(10).unwrap_err(); + assert_eq!(err_size_of, ProgramError::InvalidArgument); + + let err_unpack = ListView::::unpack(&buf).unwrap_err(); + assert_eq!(err_unpack, ProgramError::InvalidArgument); + + let err_init = ListView::::init(&mut buf).unwrap_err(); + assert_eq!(err_init, ProgramError::InvalidArgument); + } + + #[test] + fn test_padding_calculation() { + // `u8` has an alignment of 1, so no padding is ever needed. + assert_eq!(ListView::::header_padding().unwrap(), 0); + + // Zero-Sized Types like `()` have size 0 and align 1, requiring no padding. + assert_eq!(ListView::<(), PodU64>::header_padding().unwrap(), 0); + + // When length and data have the same alignment. + assert_eq!(ListView::::header_padding().unwrap(), 0); + assert_eq!(ListView::::header_padding().unwrap(), 0); + assert_eq!(ListView::::header_padding().unwrap(), 0); + + // When data alignment is smaller than or perfectly divides the length size. + assert_eq!(ListView::::header_padding().unwrap(), 0); // 8 % 2 = 0 + assert_eq!(ListView::::header_padding().unwrap(), 0); // 8 % 4 = 0 + + // When padding IS needed. + assert_eq!(ListView::::header_padding().unwrap(), 2); // size_of is 2. To align to 4, need 2 bytes. + assert_eq!(ListView::::header_padding().unwrap(), 6); // size_of is 2. To align to 8, need 6 bytes. + assert_eq!(ListView::::header_padding().unwrap(), 4); // size_of is 4. To align to 8, need 4 bytes. + + // Test with custom, higher alignments. + #[repr(C, align(8))] + #[derive(DerivePod, Zeroable, Copy, Clone)] + struct Align8(u64); + + // Test against different length types + assert_eq!(ListView::::header_padding().unwrap(), 6); // 2 + 6 = 8 + assert_eq!(ListView::::header_padding().unwrap(), 4); // 4 + 4 = 8 + assert_eq!(ListView::::header_padding().unwrap(), 0); // 8 is already aligned + + #[repr(C, align(16))] + #[derive(DerivePod, Zeroable, Copy, Clone)] + struct Align16(u128); + + assert_eq!(ListView::::header_padding().unwrap(), 14); // 2 + 14 = 16 + assert_eq!(ListView::::header_padding().unwrap(), 12); // 4 + 12 = 16 + assert_eq!(ListView::::header_padding().unwrap(), 8); // 8 + 8 = 16 + } + + #[test] + fn test_unpack_success_no_padding() { + // T = u32 (align 4), L = PodU32 (size 4, align 4). No padding needed. + let length: u32 = 2; + let capacity: usize = 3; + let item_size = size_of::(); + let len_size = size_of::(); + let buf_size = len_size + capacity * item_size; + let mut buf = vec![0u8; buf_size]; + + let pod_len: PodU32 = length.into(); + buf[0..len_size].copy_from_slice(bytemuck::bytes_of(&pod_len)); + + let data_start = len_size; + let items = [100u32, 200u32]; + let items_bytes = bytemuck::cast_slice(&items); + buf[data_start..(data_start + items_bytes.len())].copy_from_slice(items_bytes); + + let view_ro = ListView::::unpack(&buf).unwrap(); + assert_eq!(view_ro.len(), length as usize); + assert_eq!(view_ro.capacity(), capacity); + assert_eq!(view_ro.as_slice(), &items[..]); + + let view_mut = ListView::::unpack_mut(&mut buf).unwrap(); + assert_eq!(view_mut.len(), length as usize); + assert_eq!(view_mut.capacity(), capacity); + assert_eq!(view_mut.as_slice(), &items[..]); + } + + #[test] + fn test_unpack_success_with_padding() { + // T = u64 (align 8), L = PodU32 (size 4, align 4). Needs 4 bytes padding. + let padding = ListView::::header_padding().unwrap(); + assert_eq!(padding, 4); + + let length: u32 = 2; + let capacity: usize = 2; + let item_size = size_of::(); + let len_size = size_of::(); + let buf_size = len_size + padding + capacity * item_size; + let mut buf = vec![0u8; buf_size]; + + let pod_len: PodU32 = length.into(); + buf[0..len_size].copy_from_slice(bytemuck::bytes_of(&pod_len)); + + // Data starts after length and padding + let data_start = len_size + padding; + let items = [100u64, 200u64]; + let items_bytes = bytemuck::cast_slice(&items); + buf[data_start..(data_start + items_bytes.len())].copy_from_slice(items_bytes); + + let view_ro = ListView::::unpack(&buf).unwrap(); + assert_eq!(view_ro.len(), length as usize); + assert_eq!(view_ro.capacity(), capacity); + assert_eq!(view_ro.as_slice(), &items[..]); + + let view_mut = ListView::::unpack_mut(&mut buf).unwrap(); + assert_eq!(view_mut.len(), length as usize); + assert_eq!(view_mut.capacity(), capacity); + assert_eq!(view_mut.as_slice(), &items[..]); + } + + #[test] + fn test_unpack_success_zero_length() { + let capacity: usize = 5; + let item_size = size_of::(); + let len_size = size_of::(); + let buf_size = len_size + capacity * item_size; + let mut buf = vec![0u8; buf_size]; + + let pod_len: PodU32 = 0u32.into(); + buf[0..len_size].copy_from_slice(bytemuck::bytes_of(&pod_len)); + + let view_ro = ListView::::unpack(&buf).unwrap(); + assert_eq!(view_ro.len(), 0); + assert_eq!(view_ro.capacity(), capacity); + assert!(view_ro.is_empty()); + assert_eq!(view_ro.as_slice(), &[] as &[u32]); + + let view_mut = ListView::::unpack_mut(&mut buf).unwrap(); + assert_eq!(view_mut.len(), 0); + assert_eq!(view_mut.capacity(), capacity); + assert!(view_mut.is_empty()); + assert_eq!(view_mut.as_slice(), &[] as &[u32]); + } + + #[test] + fn test_unpack_success_full_capacity() { + let length: u64 = 3; + let capacity: usize = 3; + let item_size = size_of::(); + let len_size = size_of::(); + let buf_size = len_size + capacity * item_size; + let mut buf = vec![0u8; buf_size]; + + let pod_len: PodU64 = length.into(); + buf[0..len_size].copy_from_slice(bytemuck::bytes_of(&pod_len)); + + let data_start = len_size; + let items = [1u64, 2u64, 3u64]; + let items_bytes = bytemuck::cast_slice(&items); + buf[data_start..].copy_from_slice(items_bytes); + + let view_ro = ListView::::unpack(&buf).unwrap(); + assert_eq!(view_ro.len(), length as usize); + assert_eq!(view_ro.capacity(), capacity); + assert_eq!(view_ro.as_slice(), &items[..]); + + let view_mut = ListView::::unpack_mut(&mut buf).unwrap(); + assert_eq!(view_mut.len(), length as usize); + assert_eq!(view_mut.capacity(), capacity); + assert_eq!(view_mut.as_slice(), &items[..]); + } + + #[test] + fn test_unpack_success_different_length_type() { + // T = u64 (align 8), L = PodU16 (size 2, align 2). Needs 6 bytes padding. + let padding = ListView::::header_padding().unwrap(); + assert_eq!(padding, 6); + + let length: u16 = 1; + let capacity: usize = 1; + let item_size = size_of::(); + let len_size = size_of::(); + let buf_size = len_size + padding + capacity * item_size; + let mut buf = vec![0u8; buf_size]; + + let pod_len: PodU16 = length.into(); + buf[0..len_size].copy_from_slice(bytemuck::bytes_of(&pod_len)); + + let data_start = len_size + padding; + let items = [12345u64]; + let items_bytes = bytemuck::cast_slice(&items); + buf[data_start..].copy_from_slice(items_bytes); + + let view_ro = ListView::::unpack(&buf).unwrap(); + assert_eq!(view_ro.len(), length as usize); + assert_eq!(view_ro.capacity(), capacity); + assert_eq!(view_ro.as_slice(), &items[..]); + + let view_mut = ListView::::unpack_mut(&mut buf).unwrap(); + assert_eq!(view_mut.len(), length as usize); + assert_eq!(view_mut.capacity(), capacity); + assert_eq!(view_mut.as_slice(), &items[..]); + } + + #[test] + fn test_unpack_fail_buffer_too_small_for_header() { + // T = u64 (align 8), L = PodU32 (size 4). Header size is 8. + let header_size = ListView::::size_of(0).unwrap(); + assert_eq!(header_size, 8); + + // Provide a buffer smaller than the required header + let mut buf = vec![0u8; header_size - 1]; // 7 bytes + + let err = ListView::::unpack(&buf).unwrap_err(); + assert_eq!(err, PodSliceError::BufferTooSmall.into()); + + let err = ListView::::unpack_mut(&mut buf).unwrap_err(); + assert_eq!(err, PodSliceError::BufferTooSmall.into()); + } + + #[test] + fn test_unpack_fail_declared_length_exceeds_capacity() { + let declared_length: u32 = 4; + let capacity: usize = 3; // buffer can only hold 3 + let item_size = size_of::(); + let len_size = size_of::(); + let buf_size = len_size + capacity * item_size; + let mut buf = vec![0u8; buf_size]; + + // Write a length that is bigger than capacity + let pod_len: PodU32 = declared_length.into(); + buf[0..len_size].copy_from_slice(bytemuck::bytes_of(&pod_len)); + + let err = ListView::::unpack(&buf).unwrap_err(); + assert_eq!(err, PodSliceError::BufferTooSmall.into()); + + let err = ListView::::unpack_mut(&mut buf).unwrap_err(); + assert_eq!(err, PodSliceError::BufferTooSmall.into()); + } + + #[test] + fn test_unpack_fail_data_part_not_multiple_of_item_size() { + let len_size = size_of::(); + + // data part is 5 bytes, not a multiple of item_size (4) + let buf_size = len_size + 5; + let mut buf = vec![0u8; buf_size]; + + // bytemuck::try_cast_slice returns an alignment error, which we map to InvalidArgument + + let err = ListView::::unpack(&buf).unwrap_err(); + assert_eq!(err, ProgramError::InvalidArgument); + + let err = ListView::::unpack_mut(&mut buf).unwrap_err(); + assert_eq!(err, ProgramError::InvalidArgument); + } + + #[test] + fn test_unpack_empty_buffer() { + let mut buf = []; + let err = ListView::::unpack(&buf).unwrap_err(); + assert_eq!(err, PodSliceError::BufferTooSmall.into()); + + let err = ListView::::unpack_mut(&mut buf).unwrap_err(); + assert_eq!(err, PodSliceError::BufferTooSmall.into()); + } + + #[test] + fn test_init_success_no_padding() { + // T = u32 (align 4), L = PodU32 (size 4). No padding needed. + let capacity: usize = 5; + let len_size = size_of::(); + let buf_size = ListView::::size_of(capacity).unwrap(); + let mut buf = vec![0xFFu8; buf_size]; // Pre-fill to ensure init zeroes it + + let view = ListView::::init(&mut buf).unwrap(); + + assert_eq!(view.len(), 0); + assert_eq!(view.capacity(), capacity); + assert!(view.is_empty()); + + // Check that the underlying buffer's length was actually zeroed + let length_bytes = &buf[0..len_size]; + assert_eq!(length_bytes, &[0u8; 4]); + } + + #[test] + fn test_init_success_with_padding() { + // T = u64 (align 8), L = PodU32 (size 4). Needs 4 bytes padding. + let capacity: usize = 3; + let len_size = size_of::(); + let buf_size = ListView::::size_of(capacity).unwrap(); + let mut buf = vec![0xFFu8; buf_size]; // Pre-fill to ensure init zeroes it + + let view = ListView::::init(&mut buf).unwrap(); + + assert_eq!(view.len(), 0); + assert_eq!(view.capacity(), capacity); + assert!(view.is_empty()); + + // Check that the underlying buffer's length was actually zeroed + let length_bytes = &buf[0..len_size]; + assert_eq!(length_bytes, &[0u8; 4]); + // The padding bytes may or may not be zeroed, we don't assert on them. + } + + #[test] + fn test_init_success_zero_capacity() { + // Test initializing a buffer that can only hold the header. + // T = u64 (align 8), L = PodU32 (size 4). Header size is 8. + let buf_size = ListView::::size_of(0).unwrap(); + assert_eq!(buf_size, 8); + let mut buf = vec![0xFFu8; buf_size]; + + let view = ListView::::init(&mut buf).unwrap(); + + assert_eq!(view.len(), 0); + assert_eq!(view.capacity(), 0); + assert!(view.is_empty()); + + // Check that the underlying buffer's length was actually zeroed + let len_size = size_of::(); + let length_bytes = &buf[0..len_size]; + assert_eq!(length_bytes, &[0u8; 4]); + } + + #[test] + fn test_init_fail_buffer_too_small() { + // Header requires 4 bytes (size_of) + let mut buf = vec![0u8; 3]; + let err = ListView::::init(&mut buf).unwrap_err(); + assert_eq!(err, PodSliceError::BufferTooSmall.into()); + + // With padding, header requires 8 bytes (4 for len, 4 for pad) + let mut buf_padded = vec![0u8; 7]; + let err_padded = ListView::::init(&mut buf_padded).unwrap_err(); + assert_eq!(err_padded, PodSliceError::BufferTooSmall.into()); + } + + #[test] + fn test_init_success_default_length_type() { + // This test uses the default L=PodU32 length type by omitting it. + // T = u32 (align 4), L = PodU32 (size 4). No padding needed as 4 % 4 == 0. + let capacity = 5; + let len_size = size_of::(); // Default L is PodU32 + let buf_size = ListView::::size_of(capacity).unwrap(); + let mut buf = vec![0xFFu8; buf_size]; // Pre-fill to ensure init zeroes it + + let view = ListView::::init(&mut buf).unwrap(); + + assert_eq!(view.len(), 0); + assert_eq!(view.capacity(), capacity); + assert!(view.is_empty()); + + // Check that the underlying buffer's length (a u32) was actually zeroed + let length_bytes = &buf[0..len_size]; + assert_eq!(length_bytes, &[0u8; 4]); + } +} diff --git a/pod/src/list/list_view_mut.rs b/pod/src/list/list_view_mut.rs new file mode 100644 index 00000000..0f6847b6 --- /dev/null +++ b/pod/src/list/list_view_mut.rs @@ -0,0 +1,336 @@ +//! `ListViewMut`, a mutable, compact, zero-copy array wrapper. + +use { + crate::{ + error::PodSliceError, list::list_trait::List, pod_length::PodLength, primitives::PodU32, + }, + bytemuck::Pod, + solana_program_error::ProgramError, +}; + +#[derive(Debug)] +pub struct ListViewMut<'data, T: Pod, L: PodLength = PodU32> { + pub(crate) length: &'data mut L, + pub(crate) data: &'data mut [T], + pub(crate) capacity: usize, +} + +impl ListViewMut<'_, T, L> { + /// Add another item to the slice + pub fn push(&mut self, item: T) -> Result<(), ProgramError> { + let length = (*self.length).into(); + if length >= self.capacity { + Err(PodSliceError::BufferTooSmall.into()) + } else { + self.data[length] = item; + *self.length = L::try_from(length.saturating_add(1))?; + Ok(()) + } + } + + /// Remove and return the element at `index`, shifting all later + /// elements one position to the left. + pub fn remove(&mut self, index: usize) -> Result { + let len = (*self.length).into(); + if index >= len { + return Err(ProgramError::InvalidArgument); + } + + let removed_item = self.data[index]; + + // Move the tail left by one + let tail_start = index + .checked_add(1) + .ok_or(ProgramError::ArithmeticOverflow)?; + self.data.copy_within(tail_start..len, index); + + // Store the new length (len - 1) + let new_len = len.checked_sub(1).unwrap(); + *self.length = L::try_from(new_len)?; + + Ok(removed_item) + } + + /// Returns a mutable iterator over the current elements + pub fn iter_mut(&mut self) -> std::slice::IterMut { + let len = (*self.length).into(); + self.data[..len].iter_mut() + } +} + +impl List for ListViewMut<'_, T, L> { + type Item = T; + type Length = L; + + fn len(&self) -> usize { + (*self.length).into() + } + + fn capacity(&self) -> usize { + self.capacity + } + + fn as_slice(&self) -> &[Self::Item] { + &self.data[..self.len()] + } +} + +#[cfg(test)] +mod tests { + use { + super::*, + crate::{ + list::{List, ListView}, + primitives::{PodU16, PodU32, PodU64}, + }, + bytemuck_derive::{Pod, Zeroable}, + }; + + #[repr(C)] + #[derive(Clone, Copy, Debug, Default, PartialEq, Pod, Zeroable)] + struct TestStruct { + a: u64, + b: u32, + _padding: [u8; 4], + } + + impl TestStruct { + fn new(a: u64, b: u32) -> Self { + Self { + a, + b, + _padding: [0; 4], + } + } + } + + fn init_view_mut( + buffer: &mut Vec, + capacity: usize, + ) -> ListViewMut { + let size = ListView::::size_of(capacity).unwrap(); + buffer.resize(size, 0); + ListView::::init(buffer).unwrap() + } + + #[test] + fn test_push() { + let mut buffer = vec![]; + let mut view = init_view_mut::(&mut buffer, 3); + + assert_eq!(view.len(), 0); + assert!(view.is_empty()); + assert_eq!(view.capacity(), 3); + + // Push first item + let item1 = TestStruct::new(1, 10); + view.push(item1).unwrap(); + assert_eq!(view.len(), 1); + assert!(!view.is_empty()); + assert_eq!(view.as_slice(), &[item1]); + + // Push second item + let item2 = TestStruct::new(2, 20); + view.push(item2).unwrap(); + assert_eq!(view.len(), 2); + assert_eq!(view.as_slice(), &[item1, item2]); + + // Push third item to fill capacity + let item3 = TestStruct::new(3, 30); + view.push(item3).unwrap(); + assert_eq!(view.len(), 3); + assert_eq!(view.as_slice(), &[item1, item2, item3]); + + // Try to push beyond capacity + let item4 = TestStruct::new(4, 40); + let err = view.push(item4).unwrap_err(); + assert_eq!(err, PodSliceError::BufferTooSmall.into()); + + // Ensure state is unchanged + assert_eq!(view.len(), 3); + assert_eq!(view.as_slice(), &[item1, item2, item3]); + } + + #[test] + fn test_remove() { + let mut buffer = vec![]; + let mut view = init_view_mut::(&mut buffer, 4); + + let item1 = TestStruct::new(1, 10); + let item2 = TestStruct::new(2, 20); + let item3 = TestStruct::new(3, 30); + let item4 = TestStruct::new(4, 40); + view.push(item1).unwrap(); + view.push(item2).unwrap(); + view.push(item3).unwrap(); + view.push(item4).unwrap(); + + assert_eq!(view.len(), 4); + assert_eq!(view.as_slice(), &[item1, item2, item3, item4]); + + // Remove from the middle + let removed = view.remove(1).unwrap(); + assert_eq!(removed, item2); + assert_eq!(view.len(), 3); + assert_eq!(view.as_slice(), &[item1, item3, item4]); + + // Remove from the end + let removed = view.remove(2).unwrap(); + assert_eq!(removed, item4); + assert_eq!(view.len(), 2); + assert_eq!(view.as_slice(), &[item1, item3]); + + // Remove from the start + let removed = view.remove(0).unwrap(); + assert_eq!(removed, item1); + assert_eq!(view.len(), 1); + assert_eq!(view.as_slice(), &[item3]); + + // Remove the last element + let removed = view.remove(0).unwrap(); + assert_eq!(removed, item3); + assert_eq!(view.len(), 0); + assert!(view.is_empty()); + assert_eq!(view.as_slice(), &[]); + } + + #[test] + fn test_remove_out_of_bounds() { + let mut buffer = vec![]; + let mut view = init_view_mut::(&mut buffer, 2); + + view.push(TestStruct::new(1, 10)).unwrap(); + view.push(TestStruct::new(2, 20)).unwrap(); + + // Try to remove at index == len + let err = view.remove(2).unwrap_err(); + assert_eq!(err, ProgramError::InvalidArgument); + assert_eq!(view.len(), 2); // Unchanged + + // Try to remove at index > len + let err = view.remove(100).unwrap_err(); + assert_eq!(err, ProgramError::InvalidArgument); + assert_eq!(view.len(), 2); // Unchanged + + // Empty the view + view.remove(1).unwrap(); + view.remove(0).unwrap(); + assert!(view.is_empty()); + + // Try to remove from empty view + let err = view.remove(0).unwrap_err(); + assert_eq!(err, ProgramError::InvalidArgument); + } + + #[test] + fn test_iter_mut() { + let mut buffer = vec![]; + let mut view = init_view_mut::(&mut buffer, 4); + + let item1 = TestStruct::new(1, 10); + let item2 = TestStruct::new(2, 20); + let item3 = TestStruct::new(3, 30); + view.push(item1).unwrap(); + view.push(item2).unwrap(); + view.push(item3).unwrap(); + + assert_eq!(view.len(), 3); + assert_eq!(view.capacity(), 4); + + // Modify items using iter_mut + for item in view.iter_mut() { + item.a *= 10; + } + + let expected_item1 = TestStruct::new(10, 10); + let expected_item2 = TestStruct::new(20, 20); + let expected_item3 = TestStruct::new(30, 30); + + // Check that the underlying data is modified + assert_eq!(view.len(), 3); + assert_eq!( + view.as_slice(), + &[expected_item1, expected_item2, expected_item3] + ); + + // Check that iter_mut only iterates over `len` items, not `capacity` + assert_eq!(view.iter_mut().count(), 3); + } + + #[test] + fn test_iter_mut_empty() { + let mut buffer = vec![]; + let mut view = init_view_mut::(&mut buffer, 5); + + let mut count = 0; + for _ in view.iter_mut() { + count += 1; + } + assert_eq!(count, 0); + assert_eq!(view.iter_mut().next(), None); + } + + #[test] + fn test_zero_capacity() { + let mut buffer = vec![]; + let mut view = init_view_mut::(&mut buffer, 0); + + assert_eq!(view.len(), 0); + assert_eq!(view.capacity(), 0); + assert!(view.is_empty()); + + let err = view.push(TestStruct::new(1, 1)).unwrap_err(); + assert_eq!(err, PodSliceError::BufferTooSmall.into()); + + let err = view.remove(0).unwrap_err(); + assert_eq!(err, ProgramError::InvalidArgument); + } + + #[test] + fn test_default_length_type() { + let capacity = 2; + let mut buffer = vec![]; + let size = ListView::::size_of(capacity).unwrap(); + buffer.resize(size, 0); + + // Initialize the view *without* specifying L. The compiler uses the default. + let view = ListView::::init(&mut buffer).unwrap(); + + // Check that the capacity is correct for a PodU64 length. + assert_eq!(view.capacity(), capacity); + assert_eq!(view.len(), 0); + + // Verify the size of the length field. + assert_eq!(size_of_val(view.length), size_of::()); + } + + #[test] + fn test_bytes_used_and_allocated_mut() { + // capacity 3, start empty + let mut buffer = vec![]; + let mut view = init_view_mut::(&mut buffer, 3); + + // Empty view + assert_eq!( + view.bytes_used().unwrap(), + ListView::::size_of(0).unwrap() + ); + assert_eq!( + view.bytes_allocated().unwrap(), + ListView::::size_of(view.capacity()).unwrap() + ); + + // After pushing elements + view.push(TestStruct::new(1, 2)).unwrap(); + view.push(TestStruct::new(3, 4)).unwrap(); + view.push(TestStruct::new(5, 6)).unwrap(); + assert_eq!( + view.bytes_used().unwrap(), + ListView::::size_of(3).unwrap() + ); + assert_eq!( + view.bytes_allocated().unwrap(), + ListView::::size_of(view.capacity()).unwrap() + ); + } +} diff --git a/pod/src/list/list_view_read_only.rs b/pod/src/list/list_view_read_only.rs new file mode 100644 index 00000000..6e392544 --- /dev/null +++ b/pod/src/list/list_view_read_only.rs @@ -0,0 +1,176 @@ +//! `ListViewReadOnly`, a read-only, compact, zero-copy array wrapper. + +use { + crate::{list::list_trait::List, pod_length::PodLength, primitives::PodU32}, + bytemuck::Pod, +}; + +#[derive(Debug)] +pub struct ListViewReadOnly<'data, T: Pod, L: PodLength = PodU32> { + pub(crate) length: &'data L, + pub(crate) data: &'data [T], + pub(crate) capacity: usize, +} + +impl List for ListViewReadOnly<'_, T, L> { + type Item = T; + type Length = L; + + fn len(&self) -> usize { + (*self.length).into() + } + + fn capacity(&self) -> usize { + self.capacity + } + + fn as_slice(&self) -> &[Self::Item] { + &self.data[..self.len()] + } +} + +#[cfg(test)] +mod tests { + use { + super::*, + crate::{ + list::ListView, + pod_length::PodLength, + primitives::{PodU32, PodU64}, + }, + bytemuck_derive::{Pod as DerivePod, Zeroable}, + std::mem::size_of, + }; + + #[repr(C, align(16))] + #[derive(DerivePod, Zeroable, Copy, Clone, Debug, PartialEq)] + struct TestStruct(u128); + + /// Helper to build a byte buffer that conforms to the `ListView` layout. + fn build_test_buffer( + length: usize, + capacity: usize, + items: &[T], + ) -> Vec { + let size = ListView::::size_of(capacity).unwrap(); + let mut buffer = vec![0u8; size]; + + // Write the length prefix + let pod_len = L::try_from(length).unwrap(); + let len_bytes = bytemuck::bytes_of(&pod_len); + buffer[0..size_of::()].copy_from_slice(len_bytes); + + // Write the data items, accounting for padding + if !items.is_empty() { + let data_start = ListView::::size_of(0).unwrap(); + let items_bytes = bytemuck::cast_slice(items); + buffer[data_start..data_start.saturating_add(items_bytes.len())] + .copy_from_slice(items_bytes); + } + + buffer + } + + #[test] + fn test_len_and_capacity() { + let items = [10u32, 20, 30]; + let buffer = build_test_buffer::(items.len(), 5, &items); + let view = ListView::::unpack(&buffer).unwrap(); + + assert_eq!(view.len(), 3); + assert_eq!(view.capacity(), 5); + } + + #[test] + fn test_as_slice() { + let items = [10u32, 20, 30]; + // Buffer has capacity for 5, but we only use 3. + let buffer = build_test_buffer::(items.len(), 5, &items); + let view = ListView::::unpack(&buffer).unwrap(); + + // `as_slice()` should only return the first `len` items. + assert_eq!(view.as_slice(), &items[..]); + assert_eq!(view.as_slice().len(), view.len()); + } + + #[test] + fn test_is_empty() { + // Not empty + let buffer_full = build_test_buffer::(1, 2, &[10]); + let view_full = ListView::::unpack(&buffer_full).unwrap(); + assert!(!view_full.is_empty()); + + // Empty + let buffer_empty = build_test_buffer::(0, 2, &[]); + let view_empty = ListView::::unpack(&buffer_empty).unwrap(); + assert!(view_empty.is_empty()); + } + + #[test] + fn test_iter() { + let items = [TestStruct(1), TestStruct(2)]; + let buffer = build_test_buffer::(items.len(), 3, &items); + let view = ListView::::unpack(&buffer).unwrap(); + + let mut iter = view.iter(); + assert_eq!(iter.next(), Some(&items[0])); + assert_eq!(iter.next(), Some(&items[1])); + assert_eq!(iter.next(), None); + let collected: Vec<_> = view.iter().collect(); + assert_eq!(collected, vec![&items[0], &items[1]]); + } + + #[test] + fn test_iter_on_empty_list() { + let buffer = build_test_buffer::(0, 5, &[]); + let view = ListView::::unpack(&buffer).unwrap(); + + assert_eq!(view.iter().count(), 0); + assert_eq!(view.iter().next(), None); + } + + #[test] + fn test_zero_capacity() { + // Buffer is just big enough for the header (len + padding), no data. + let buffer = build_test_buffer::(0, 0, &[]); + let view = ListView::::unpack(&buffer).unwrap(); + + assert_eq!(view.len(), 0); + assert_eq!(view.capacity(), 0); + assert!(view.is_empty()); + assert_eq!(view.as_slice(), &[]); + } + + #[test] + fn test_with_padding() { + // Test the effect of padding by checking the total header size. + // T=AlignedStruct (align 16), L=PodU32 (size 4). + // The header size should be 16 (4 for len + 12 for padding). + let header_size = ListView::::size_of(0).unwrap(); + assert_eq!(header_size, 16); + + let items = [TestStruct(123), TestStruct(456)]; + let buffer = build_test_buffer::(items.len(), 4, &items); + let view = ListView::::unpack(&buffer).unwrap(); + + // Check if the public API works as expected despite internal padding + assert_eq!(view.len(), 2); + assert_eq!(view.capacity(), 4); + assert_eq!(view.as_slice(), &items[..]); + } + + #[test] + fn test_bytes_used_and_allocated() { + // 3 live elements, capacity 5 + let items = [10u32, 20, 30]; + let capacity = 5; + let buffer = build_test_buffer::(items.len(), capacity, &items); + let view = ListView::::unpack(&buffer).unwrap(); + + let expected_used = ListView::::size_of(view.len()).unwrap(); + let expected_cap = ListView::::size_of(view.capacity()).unwrap(); + + assert_eq!(view.bytes_used().unwrap(), expected_used); + assert_eq!(view.bytes_allocated().unwrap(), expected_cap); + } +} diff --git a/pod/src/list/mod.rs b/pod/src/list/mod.rs new file mode 100644 index 00000000..56062237 --- /dev/null +++ b/pod/src/list/mod.rs @@ -0,0 +1,9 @@ +mod list_trait; +mod list_view; +mod list_view_mut; +mod list_view_read_only; + +pub use { + list_trait::List, list_view::ListView, list_view_mut::ListViewMut, + list_view_read_only::ListViewReadOnly, +}; diff --git a/pod/src/pod_length.rs b/pod/src/pod_length.rs new file mode 100644 index 00000000..756e4b8b --- /dev/null +++ b/pod/src/pod_length.rs @@ -0,0 +1,40 @@ +use { + crate::{ + error::PodSliceError, + primitives::{PodU16, PodU32, PodU64}, + }, + bytemuck::Pod, +}; + +/// Marker trait for converting to/from Pod `uint`'s and `usize` +pub trait PodLength: Pod + Into + TryFrom {} + +/// Blanket implementation to automatically implement `PodLength` for any type +/// that satisfies the required bounds. +impl PodLength for T where T: Pod + Into + TryFrom {} + +/// Implements the `TryFrom` and `From for usize` conversions for a Pod integer type +macro_rules! impl_pod_length_for { + ($PodType:ty, $PrimitiveType:ty) => { + impl TryFrom for $PodType { + type Error = PodSliceError; + + fn try_from(val: usize) -> Result { + let primitive_val = <$PrimitiveType>::try_from(val)?; + Ok(primitive_val.into()) + } + } + + impl From<$PodType> for usize { + fn from(pod_val: $PodType) -> Self { + let primitive_val = <$PrimitiveType>::from(pod_val); + Self::try_from(primitive_val) + .expect("value out of range for usize on this platform") + } + } + }; +} + +impl_pod_length_for!(PodU16, u16); +impl_pod_length_for!(PodU32, u32); +impl_pod_length_for!(PodU64, u64); diff --git a/pod/src/slice.rs b/pod/src/slice.rs index ca885aba..1810265b 100644 --- a/pod/src/slice.rs +++ b/pod/src/slice.rs @@ -2,89 +2,65 @@ use { crate::{ - bytemuck::{ - pod_from_bytes, pod_from_bytes_mut, pod_slice_from_bytes, pod_slice_from_bytes_mut, - }, - error::PodSliceError, + list::{List, ListView, ListViewMut, ListViewReadOnly}, primitives::PodU32, }, bytemuck::Pod, solana_program_error::ProgramError, }; -const LENGTH_SIZE: usize = std::mem::size_of::(); +#[deprecated( + since = "0.6.0", + note = "This struct will be removed in the next major release (1.0.0). Please use `ListView` instead." +)] /// Special type for using a slice of `Pod`s in a zero-copy way +#[allow(deprecated)] pub struct PodSlice<'data, T: Pod> { - length: &'data PodU32, - data: &'data [T], + inner: ListViewReadOnly<'data, T, PodU32>, } + +#[allow(deprecated)] impl<'data, T: Pod> PodSlice<'data, T> { /// Unpack the buffer into a slice pub fn unpack<'a>(data: &'a [u8]) -> Result where 'a: 'data, { - if data.len() < LENGTH_SIZE { - return Err(PodSliceError::BufferTooSmall.into()); - } - let (length, data) = data.split_at(LENGTH_SIZE); - let length = pod_from_bytes::(length)?; - let _max_length = max_len_for_type::(data.len(), u32::from(*length) as usize)?; - let data = pod_slice_from_bytes(data)?; - Ok(Self { length, data }) + let inner = ListView::::unpack(data)?; + Ok(Self { inner }) } /// Get the slice data pub fn data(&self) -> &[T] { - let length = u32::from(*self.length) as usize; - &self.data[..length] + let len = self.inner.len(); + &self.inner.data[..len] } /// Get the amount of bytes used by `num_items` pub fn size_of(num_items: usize) -> Result { - std::mem::size_of::() - .checked_mul(num_items) - .and_then(|len| len.checked_add(LENGTH_SIZE)) - .ok_or_else(|| PodSliceError::CalculationFailure.into()) + ListView::::size_of(num_items) } } -/// Special type for using a slice of mutable `Pod`s in a zero-copy way +#[deprecated( + since = "0.6.0", + note = "This struct will be removed in the next major release (1.0.0). Please use `ListView` instead." +)] +/// Special type for using a slice of mutable `Pod`s in a zero-copy way. +/// Uses `ListView` under the hood. pub struct PodSliceMut<'data, T: Pod> { - length: &'data mut PodU32, - data: &'data mut [T], - max_length: usize, + inner: ListViewMut<'data, T, PodU32>, } -impl<'data, T: Pod> PodSliceMut<'data, T> { - /// Unpack the mutable buffer into a mutable slice, with the option to - /// initialize the data - fn unpack_internal<'a>(data: &'a mut [u8], init: bool) -> Result - where - 'a: 'data, - { - if data.len() < LENGTH_SIZE { - return Err(PodSliceError::BufferTooSmall.into()); - } - let (length, data) = data.split_at_mut(LENGTH_SIZE); - let length = pod_from_bytes_mut::(length)?; - if init { - *length = 0.into(); - } - let max_length = max_len_for_type::(data.len(), u32::from(*length) as usize)?; - let data = pod_slice_from_bytes_mut(data)?; - Ok(Self { - length, - data, - max_length, - }) - } +#[allow(deprecated)] +impl<'data, T: Pod> PodSliceMut<'data, T> { /// Unpack the mutable buffer into a mutable slice pub fn unpack<'a>(data: &'a mut [u8]) -> Result where 'a: 'data, { - Self::unpack_internal(data, /* init */ false) + let inner = ListView::::unpack_mut(data)?; + Ok(Self { inner }) } /// Unpack the mutable buffer into a mutable slice, and initialize the @@ -93,53 +69,22 @@ impl<'data, T: Pod> PodSliceMut<'data, T> { where 'a: 'data, { - Self::unpack_internal(data, /* init */ true) + let inner = ListView::::init(data)?; + Ok(Self { inner }) } /// Add another item to the slice pub fn push(&mut self, t: T) -> Result<(), ProgramError> { - let length = u32::from(*self.length); - if length as usize == self.max_length { - Err(PodSliceError::BufferTooSmall.into()) - } else { - self.data[length as usize] = t; - *self.length = length.saturating_add(1).into(); - Ok(()) - } + self.inner.push(t) } } -fn max_len_for_type(data_len: usize, length_val: usize) -> Result { - let item_size = std::mem::size_of::(); - let max_len = data_len - .checked_div(item_size) - .ok_or(PodSliceError::CalculationFailure)?; - - // Make sure the max length that can be stored in the buffer isn't less - // than the length value. - if max_len < length_val { - Err(PodSliceError::BufferTooSmall)? - } - - // Make sure the buffer is cleanly divisible by `size_of::`; not over or - // under allocated. - if max_len.saturating_mul(item_size) != data_len { - if max_len == 0 { - // Size of T is greater than buffer size - Err(PodSliceError::BufferTooSmall)? - } else { - Err(PodSliceError::BufferTooLarge)? - } - } - - Ok(max_len) -} - #[cfg(test)] +#[allow(deprecated)] mod tests { use { super::*, - crate::bytemuck::pod_slice_to_bytes, + crate::{bytemuck::pod_slice_to_bytes, error::PodSliceError}, bytemuck_derive::{Pod, Zeroable}, }; @@ -150,6 +95,8 @@ mod tests { test_pubkey: [u8; 32], } + const LENGTH_SIZE: usize = std::mem::size_of::(); + #[test] fn test_pod_slice() { let test_field_bytes = [0]; @@ -170,7 +117,7 @@ mod tests { let pod_slice = PodSlice::::unpack(&pod_slice_bytes).unwrap(); let pod_slice_data = pod_slice.data(); - assert_eq!(*pod_slice.length, PodU32::from(2)); + assert_eq!(pod_slice.inner.len(), 2); assert_eq!(pod_slice_to_bytes(pod_slice.data()), data_bytes); assert_eq!(pod_slice_data[0].test_field, test_field_bytes[0]); assert_eq!(pod_slice_data[0].test_pubkey, test_pubkey_bytes); @@ -187,11 +134,7 @@ mod tests { let err = PodSlice::::unpack(&pod_slice_bytes) .err() .unwrap(); - assert_eq!( - err, - PodSliceError::BufferTooLarge.into(), - "Expected an `PodSliceError::BufferTooLarge` error" - ); + assert!(matches!(err, ProgramError::InvalidArgument)); } #[test] @@ -211,7 +154,7 @@ mod tests { data[..LENGTH_SIZE].copy_from_slice(&length_le); let pod_slice = PodSlice::::unpack(&data).unwrap(); - let pod_slice_len = u32::from(*pod_slice.length); + let pod_slice_len = pod_slice.inner.len() as u32; let data = pod_slice.data(); let data_vec = data.to_vec(); @@ -228,11 +171,7 @@ mod tests { let err = PodSlice::::unpack(&pod_slice_bytes) .err() .unwrap(); - assert_eq!( - err, - PodSliceError::BufferTooSmall.into(), - "Expected an `PodSliceError::BufferTooSmall` error" - ); + assert!(matches!(err, ProgramError::InvalidArgument)); } #[test] @@ -270,9 +209,10 @@ mod tests { let mut pod_slice = PodSliceMut::::unpack(&mut pod_slice_bytes).unwrap(); - assert_eq!(*pod_slice.length, PodU32::from(1)); + assert_eq!(pod_slice.inner.len(), 1); pod_slice.push(TestStruct::default()).unwrap(); - assert_eq!(*pod_slice.length, PodU32::from(2)); + assert_eq!(pod_slice.inner.len(), 2); + let err = pod_slice .push(TestStruct::default()) .expect_err("Expected an `PodSliceError::BufferTooSmall` error"); diff --git a/tlv-account-resolution/src/state.rs b/tlv-account-resolution/src/state.rs index f17ee3d0..af01ced9 100644 --- a/tlv-account-resolution/src/state.rs +++ b/tlv-account-resolution/src/state.rs @@ -7,7 +7,10 @@ use { solana_program_error::ProgramError, solana_pubkey::Pubkey, spl_discriminator::SplDiscriminate, - spl_pod::slice::{PodSlice, PodSliceMut}, + spl_pod::{ + list::{self, List, ListView}, + primitives::PodU32, + }, spl_type_length_value::state::{TlvState, TlvStateBorrowed, TlvStateMut}, std::future::Future, }; @@ -170,9 +173,9 @@ impl ExtraAccountMetaList { extra_account_metas: &[ExtraAccountMeta], ) -> Result<(), ProgramError> { let mut state = TlvStateMut::unpack(data).unwrap(); - let tlv_size = PodSlice::::size_of(extra_account_metas.len())?; + let tlv_size = ListView::::size_of(extra_account_metas.len())?; let (bytes, _) = state.alloc::(tlv_size, false)?; - let mut validation_data = PodSliceMut::init(bytes)?; + let mut validation_data = ListView::::init(bytes)?; for meta in extra_account_metas { validation_data.push(*meta)?; } @@ -186,31 +189,31 @@ impl ExtraAccountMetaList { extra_account_metas: &[ExtraAccountMeta], ) -> Result<(), ProgramError> { let mut state = TlvStateMut::unpack(data).unwrap(); - let tlv_size = PodSlice::::size_of(extra_account_metas.len())?; + let tlv_size = ListView::::size_of(extra_account_metas.len())?; let bytes = state.realloc_first::(tlv_size)?; - let mut validation_data = PodSliceMut::init(bytes)?; + let mut validation_data = ListView::::init(bytes)?; for meta in extra_account_metas { validation_data.push(*meta)?; } Ok(()) } - /// Get the underlying `PodSlice` from an unpacked TLV + /// Get the underlying `ListViewReadOnly` from an unpacked TLV /// /// Due to lifetime annoyances, this function can't just take in the bytes, /// since then we would be returning a reference to a locally created /// `TlvStateBorrowed`. I hope there's a better way to do this! pub fn unpack_with_tlv_state<'a, T: SplDiscriminate>( tlv_state: &'a TlvStateBorrowed, - ) -> Result, ProgramError> { + ) -> Result, ProgramError> { let bytes = tlv_state.get_first_bytes::()?; - PodSlice::::unpack(bytes) + ListView::::unpack(bytes) } /// Get the byte size required to hold `num_items` items pub fn size_of(num_items: usize) -> Result { Ok(TlvStateBorrowed::get_base_len() - .saturating_add(PodSlice::::size_of(num_items)?)) + .saturating_add(ListView::::size_of(num_items)?)) } /// Checks provided account infos against validation data, using @@ -227,7 +230,7 @@ impl ExtraAccountMetaList { ) -> Result<(), ProgramError> { let state = TlvStateBorrowed::unpack(data).unwrap(); let extra_meta_list = ExtraAccountMetaList::unpack_with_tlv_state::(&state)?; - let extra_account_metas = extra_meta_list.data(); + let extra_account_metas = extra_meta_list.as_slice(); let initial_accounts_len = account_infos.len() - extra_account_metas.len(); @@ -281,7 +284,7 @@ impl ExtraAccountMetaList { { let state = TlvStateBorrowed::unpack(data)?; let bytes = state.get_first_bytes::()?; - let extra_account_metas = PodSlice::::unpack(bytes)?; + let extra_account_metas = ListView::::unpack(bytes)?; // Fetch account data for each of the instruction accounts let mut account_key_datas = vec![]; @@ -294,7 +297,7 @@ impl ExtraAccountMetaList { account_key_datas.push((meta.pubkey, account_data)); } - for extra_meta in extra_account_metas.data().iter() { + for extra_meta in extra_account_metas.as_slice().iter() { let mut meta = extra_meta.resolve(&instruction.data, &instruction.program_id, |usize| { account_key_datas @@ -326,9 +329,9 @@ impl ExtraAccountMetaList { ) -> Result<(), ProgramError> { let state = TlvStateBorrowed::unpack(data)?; let bytes = state.get_first_bytes::()?; - let extra_account_metas = PodSlice::::unpack(bytes)?; + let extra_account_metas = ListView::::unpack(bytes)?; - for extra_meta in extra_account_metas.data().iter() { + for extra_meta in extra_account_metas.as_slice().iter() { let mut meta = { // Create a list of `Ref`s so we can reference account data in the // resolution step @@ -1460,7 +1463,7 @@ mod tests { let state = TlvStateBorrowed::unpack(buffer).unwrap(); let unpacked_metas_pod = ExtraAccountMetaList::unpack_with_tlv_state::(&state).unwrap(); - let unpacked_metas = unpacked_metas_pod.data(); + let unpacked_metas = unpacked_metas_pod.as_slice(); assert_eq!( unpacked_metas, updated_metas, "The ExtraAccountMetas in the buffer should match the expected ones."