From b8544e5fab47c23035e66c3344e4c916a6d6f541 Mon Sep 17 00:00:00 2001 From: Gabe Rodriguez Date: Fri, 6 Jun 2025 14:11:46 +0200 Subject: [PATCH 1/7] Add support for PodList --- pod/Cargo.toml | 2 +- pod/src/lib.rs | 1 + pod/src/list.rs | 334 ++++++++++++++++++++++++++++ pod/src/slice.rs | 9 +- tlv-account-resolution/Cargo.toml | 2 +- tlv-account-resolution/src/state.rs | 7 +- type-length-value/Cargo.toml | 2 +- 7 files changed, 350 insertions(+), 7 deletions(-) create mode 100644 pod/src/list.rs diff --git a/pod/Cargo.toml b/pod/Cargo.toml index 736ccd48..fc82131a 100644 --- a/pod/Cargo.toml +++ b/pod/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "spl-pod" -version = "0.5.1" +version = "0.6.0" description = "Solana Program Library Plain Old Data (Pod)" authors = ["Anza Maintainers "] repository = "https://github.com/solana-program/libraries" diff --git a/pod/src/lib.rs b/pod/src/lib.rs index e1aa65fd..b2871e83 100644 --- a/pod/src/lib.rs +++ b/pod/src/lib.rs @@ -2,6 +2,7 @@ pub mod bytemuck; pub mod error; +pub mod list; pub mod option; pub mod optional_keys; pub mod primitives; diff --git a/pod/src/list.rs b/pod/src/list.rs new file mode 100644 index 00000000..eae3e1d2 --- /dev/null +++ b/pod/src/list.rs @@ -0,0 +1,334 @@ +use crate::bytemuck::{pod_from_bytes_mut, pod_slice_from_bytes_mut}; +use crate::error::PodSliceError; +use crate::primitives::PodU32; +use crate::slice::max_len_for_type; +use bytemuck::Pod; +use solana_program_error::ProgramError; + +const LENGTH_SIZE: usize = std::mem::size_of::(); + +/// A mutable, variable-length collection of `Pod` types backed by a byte buffer. +/// +/// `PodList` provides a safe, zero-copy, `Vec`-like interface for a slice of +/// `Pod` data that resides in an external, pre-allocated `&mut [u8]` buffer. +/// It does not own the buffer itself, but acts as a mutable view over it. +/// +/// This is useful in environments where allocations are restricted or expensive, +/// such as Solana programs, allowing for dynamic-length data structures within a +/// fixed-size account. +/// +/// ## Memory Layout +/// +/// The structure assumes the underlying byte buffer is formatted as follows: +/// 1. **Length**: A `u32` value (`PodU32`) at the beginning of the buffer, +/// indicating the number of currently active elements in the collection. +/// 2. **Data**: The remaining part of the buffer, which is treated as a slice +/// of `T` elements. The capacity of the collection is the number of `T` +/// elements that can fit into this data portion. +pub struct PodList<'data, T: Pod> { + length: &'data mut PodU32, + data: &'data mut [T], + max_length: usize, +} + +impl<'data, T: Pod> PodList<'data, T> { + /// Unpack the mutable buffer into a mutable slice, with the option to + /// initialize the data + fn unpack_internal<'a>(data: &'a mut [u8], init: bool) -> Result + where + 'a: 'data, + { + if data.len() < LENGTH_SIZE { + return Err(PodSliceError::BufferTooSmall.into()); + } + let (length, data) = data.split_at_mut(LENGTH_SIZE); + let length = pod_from_bytes_mut::(length)?; + if init { + *length = 0.into(); + } + let max_length = max_len_for_type::(data.len(), u32::from(*length) as usize)?; + let data = pod_slice_from_bytes_mut(data)?; + Ok(Self { + length, + data, + max_length, + }) + } + + /// Unpack the mutable buffer into a mutable slice + pub fn unpack<'a>(data: &'a mut [u8]) -> Result + where + 'a: 'data, + { + Self::unpack_internal(data, /* init */ false) + } + + /// Unpack the mutable buffer into a mutable slice, and initialize the + /// slice to 0-length + pub fn init<'a>(data: &'a mut [u8]) -> Result + where + 'a: 'data, + { + Self::unpack_internal(data, /* init */ true) + } + + /// Add another item to the slice + pub fn push(&mut self, t: T) -> Result<(), ProgramError> { + let length = u32::from(*self.length); + if length as usize == self.max_length { + Err(PodSliceError::BufferTooSmall.into()) + } else { + self.data[length as usize] = t; + *self.length = length.saturating_add(1).into(); + Ok(()) + } + } + + /// Remove and return the element at `index`, shifting all later + /// elements one position to the left. + pub fn remove_at(&mut self, index: usize) -> Result { + let len = u32::from(*self.length) as usize; + if index >= len { + return Err(ProgramError::InvalidArgument); + } + + let removed_item = self.data[index]; + + // Move the tail left by one + let tail_start = index + .checked_add(1) + .ok_or(ProgramError::ArithmeticOverflow)?; + self.data.copy_within(tail_start..len, index); + + // Zero-fill the now-unused slot at the end + let last = len.checked_sub(1).ok_or(ProgramError::ArithmeticOverflow)?; + self.data[last] = T::zeroed(); + + // Store the new length (len - 1) + *self.length = (last as u32).into(); + + Ok(removed_item) + } + + /// Find the first element that satisfies `predicate` and remove it, + /// returning the element. + pub fn remove_first_where

(&mut self, mut predicate: P) -> Result + where + P: FnMut(&T) -> bool, + { + if let Some(index) = self.data.iter().position(&mut predicate) { + self.remove_at(index) + } else { + Err(ProgramError::InvalidArgument) + } + } +} + +#[cfg(test)] +mod tests { + use { + super::*, + bytemuck_derive::{Pod, Zeroable}, + }; + + #[repr(C)] + #[derive(Clone, Copy, Debug, Default, PartialEq, Pod, Zeroable)] + struct TestStruct { + test_field: u8, + test_pubkey: [u8; 32], + } + + #[test] + fn test_pod_collection() { + // slice can fit 2 `TestStruct` + let mut pod_slice_bytes = [0; 70]; + // set length to 1, so we have room to push 1 more item + let len_bytes = [1, 0, 0, 0]; + pod_slice_bytes[0..4].copy_from_slice(&len_bytes); + + let mut pod_slice = PodList::::unpack(&mut pod_slice_bytes).unwrap(); + + assert_eq!(*pod_slice.length, PodU32::from(1)); + pod_slice.push(TestStruct::default()).unwrap(); + assert_eq!(*pod_slice.length, PodU32::from(2)); + let err = pod_slice + .push(TestStruct::default()) + .expect_err("Expected an `PodSliceError::BufferTooSmall` error"); + assert_eq!(err, PodSliceError::BufferTooSmall.into()); + } + + fn make_buffer(capacity: usize, items: &[u8]) -> Vec { + let buff_len = LENGTH_SIZE.checked_add(capacity).unwrap(); + let mut buf = vec![0u8; buff_len]; + buf[..LENGTH_SIZE].copy_from_slice(&(items.len() as u32).to_le_bytes()); + let end = LENGTH_SIZE.checked_add(items.len()).unwrap(); + buf[LENGTH_SIZE..end].copy_from_slice(items); + buf + } + + #[test] + fn remove_at_first_item() { + let mut buff = make_buffer(15, &[10, 20, 30, 40]); + let mut pod_list = PodList::::unpack(&mut buff).unwrap(); + let removed = pod_list.remove_at(0).unwrap(); + assert_eq!(removed, 10); + let pod_list_len = u32::from(*pod_list.length) as usize; + assert_eq!(pod_list_len, 3); + assert_eq!(pod_list.data[..pod_list_len].to_vec(), &[20, 30, 40]); + assert_eq!(pod_list.data[3], 0); + } + + #[test] + fn remove_at_middle_item() { + let mut buff = make_buffer(15, &[10, 20, 30, 40]); + let mut pod_list = PodList::::unpack(&mut buff).unwrap(); + let removed = pod_list.remove_at(2).unwrap(); + assert_eq!(removed, 30); + let pod_list_len = u32::from(*pod_list.length) as usize; + assert_eq!(pod_list_len, 3); + assert_eq!(pod_list.data[..pod_list_len].to_vec(), &[10, 20, 40]); + assert_eq!(pod_list.data[3], 0); + } + + #[test] + fn remove_at_last_item() { + let mut buff = make_buffer(15, &[10, 20, 30, 40]); + let mut pod_list = PodList::::unpack(&mut buff).unwrap(); + let removed = pod_list.remove_at(3).unwrap(); + assert_eq!(removed, 40); + let pod_list_len = u32::from(*pod_list.length) as usize; + assert_eq!(pod_list_len, 3); + assert_eq!(pod_list.data[..pod_list_len].to_vec(), &[10, 20, 30]); + assert_eq!(pod_list.data[3], 0); + } + + #[test] + fn remove_at_out_of_bounds() { + let mut buff = make_buffer(3, &[1, 2, 3]); + let original_buff = buff.clone(); + + { + let mut pod_list = PodList::::unpack(&mut buff).unwrap(); + let err = pod_list.remove_at(3).unwrap_err(); + assert_eq!(err, ProgramError::InvalidArgument); + + // pod_list should be unchanged + let pod_list_len = u32::from(*pod_list.length) as usize; + assert_eq!(pod_list_len, 3); + assert_eq!(pod_list.data[..pod_list_len].to_vec(), vec![1, 2, 3]); + } + + assert_eq!(buff, original_buff); + } + + #[test] + fn remove_at_single_element() { + let mut buff = make_buffer(1, &[10]); + let mut pod_list = PodList::::unpack(&mut buff).unwrap(); + let removed = pod_list.remove_at(0).unwrap(); + assert_eq!(removed, 10); + let pod_list_len = u32::from(*pod_list.length) as usize; + assert_eq!(pod_list_len, 0); + assert_eq!(pod_list.data[..pod_list_len].to_vec(), &[] as &[u8]); + assert_eq!(pod_list.data[0], 0); + } + + #[test] + fn remove_at_empty_slice() { + let mut buff = make_buffer(0, &[]); + let original_buff = buff.clone(); + + { + let mut pod_list = PodList::::unpack(&mut buff).unwrap(); + let err = pod_list.remove_at(0).unwrap_err(); + assert_eq!(err, ProgramError::InvalidArgument); + + // Assert list state is unchanged + let pod_list_len = u32::from(*pod_list.length) as usize; + assert_eq!(pod_list_len, 0); + } + + assert_eq!(buff, original_buff); + } + + #[test] + fn remove_first_where_first_item() { + let mut buff = make_buffer(3, &[5, 10, 15]); + let mut pod_list = PodList::::unpack(&mut buff).unwrap(); + let removed = pod_list.remove_first_where(|&x| x == 5).unwrap(); + assert_eq!(removed, 5); + let pod_list_len = u32::from(*pod_list.length) as usize; + assert_eq!(pod_list_len, 2); + assert_eq!(pod_list.data[..pod_list_len].to_vec(), &[10, 15]); + assert_eq!(pod_list.data[2], 0); + } + + #[test] + fn remove_first_where_middle_item() { + let mut buff = make_buffer(4, &[1, 2, 3, 4]); + let mut pod_list = PodList::::unpack(&mut buff).unwrap(); + let removed = pod_list.remove_first_where(|v| *v == 3).unwrap(); + assert_eq!(removed, 3); + let pod_list_len = u32::from(*pod_list.length) as usize; + assert_eq!(pod_list_len, 3); + assert_eq!(pod_list.data[..pod_list_len].to_vec(), &[1, 2, 4]); + assert_eq!(pod_list.data[3], 0); + } + + #[test] + fn remove_first_where_last_item() { + let mut buff = make_buffer(3, &[5, 10, 15]); + let mut pod_list = PodList::::unpack(&mut buff).unwrap(); + let removed = pod_list.remove_first_where(|&x| x == 15).unwrap(); + assert_eq!(removed, 15); + let pod_list_len = u32::from(*pod_list.length) as usize; + assert_eq!(pod_list_len, 2); + assert_eq!(pod_list.data[..pod_list_len].to_vec(), &[5, 10]); + assert_eq!(pod_list.data[2], 0); + } + + #[test] + fn remove_first_where_multiple_matches() { + let mut buff = make_buffer(5, &[7, 8, 8, 9, 10]); + let mut pod_list = PodList::::unpack(&mut buff).unwrap(); + let removed = pod_list.remove_first_where(|v| *v == 8).unwrap(); + assert_eq!(removed, 8); // Removed *first* 8 + let pod_list_len = u32::from(*pod_list.length) as usize; + assert_eq!(pod_list_len, 4); + // Should remove only the *first* match. + assert_eq!(pod_list.data[..pod_list_len].to_vec(), &[7, 8, 9, 10]); + assert_eq!(pod_list.data[4], 0); + } + + #[test] + fn remove_first_where_not_found() { + let mut buff = make_buffer(3, &[5, 6, 7]); + let original_buff = buff.clone(); + + { + let mut pod_list = PodList::::unpack(&mut buff).unwrap(); + let err = pod_list.remove_first_where(|v| *v == 42).unwrap_err(); + assert_eq!(err, ProgramError::InvalidArgument); + // Assert list state is unchanged + assert_eq!(u32::from(*pod_list.length) as usize, 3); + } + + assert_eq!(buff, original_buff); + } + + #[test] + fn remove_first_where_empty_slice() { + let mut buff = make_buffer(0, &[]); + let original_buff = buff.clone(); + + { + let mut pod_list = PodList::::unpack(&mut buff).unwrap(); + let err = pod_list.remove_first_where(|_| true).unwrap_err(); + assert_eq!(err, ProgramError::InvalidArgument); + // Assert list state is unchanged + assert_eq!(u32::from(*pod_list.length) as usize, 0); + } + + assert_eq!(buff, original_buff); + } +} diff --git a/pod/src/slice.rs b/pod/src/slice.rs index ca885aba..0082c7a4 100644 --- a/pod/src/slice.rs +++ b/pod/src/slice.rs @@ -49,12 +49,18 @@ impl<'data, T: Pod> PodSlice<'data, T> { } } +#[deprecated( + since = "0.6.0", + note = "This struct will be removed in the next major release (1.0.0). Please use `PodList` instead." +)] /// Special type for using a slice of mutable `Pod`s in a zero-copy way pub struct PodSliceMut<'data, T: Pod> { length: &'data mut PodU32, data: &'data mut [T], max_length: usize, } + +#[allow(deprecated)] impl<'data, T: Pod> PodSliceMut<'data, T> { /// Unpack the mutable buffer into a mutable slice, with the option to /// initialize the data @@ -109,7 +115,7 @@ impl<'data, T: Pod> PodSliceMut<'data, T> { } } -fn max_len_for_type(data_len: usize, length_val: usize) -> Result { +pub fn max_len_for_type(data_len: usize, length_val: usize) -> Result { let item_size = std::mem::size_of::(); let max_len = data_len .checked_div(item_size) @@ -136,6 +142,7 @@ fn max_len_for_type(data_len: usize, length_val: usize) -> Result::size_of(extra_account_metas.len())?; let (bytes, _) = state.alloc::(tlv_size, false)?; - let mut validation_data = PodSliceMut::init(bytes)?; + let mut validation_data = PodList::init(bytes)?; for meta in extra_account_metas { validation_data.push(*meta)?; } @@ -188,7 +189,7 @@ impl ExtraAccountMetaList { let mut state = TlvStateMut::unpack(data).unwrap(); let tlv_size = PodSlice::::size_of(extra_account_metas.len())?; let bytes = state.realloc_first::(tlv_size)?; - let mut validation_data = PodSliceMut::init(bytes)?; + let mut validation_data = PodList::init(bytes)?; for meta in extra_account_metas { validation_data.push(*meta)?; } diff --git a/type-length-value/Cargo.toml b/type-length-value/Cargo.toml index f5d66470..7c7ba6d6 100644 --- a/type-length-value/Cargo.toml +++ b/type-length-value/Cargo.toml @@ -22,7 +22,7 @@ solana-msg = "2.2.1" solana-program-error = "2.2.2" spl-discriminator = { version = "0.4.0", path = "../discriminator" } spl-type-length-value-derive = { version = "0.2", path = "./derive", optional = true } -spl-pod = { version = "0.5.1", path = "../pod" } +spl-pod = { version = "0.6.0", path = "../pod" } thiserror = "2.0" [lib] From d832bc0097129f52cca29cc280d6057ed23b8211 Mon Sep 17 00:00:00 2001 From: Gabe Rodriguez Date: Wed, 9 Jul 2025 13:53:33 +0200 Subject: [PATCH 2/7] Review updates --- pod/Cargo.toml | 2 +- pod/src/list.rs | 451 +++++++++++++++++++--------- pod/src/primitives.rs | 41 +++ pod/src/slice.rs | 82 +++-- tlv-account-resolution/Cargo.toml | 2 +- tlv-account-resolution/src/state.rs | 7 +- type-length-value/Cargo.toml | 2 +- 7 files changed, 391 insertions(+), 196 deletions(-) diff --git a/pod/Cargo.toml b/pod/Cargo.toml index fc82131a..736ccd48 100644 --- a/pod/Cargo.toml +++ b/pod/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "spl-pod" -version = "0.6.0" +version = "0.5.1" description = "Solana Program Library Plain Old Data (Pod)" authors = ["Anza Maintainers "] repository = "https://github.com/solana-program/libraries" diff --git a/pod/src/list.rs b/pod/src/list.rs index eae3e1d2..480761bd 100644 --- a/pod/src/list.rs +++ b/pod/src/list.rs @@ -1,15 +1,37 @@ -use crate::bytemuck::{pod_from_bytes_mut, pod_slice_from_bytes_mut}; -use crate::error::PodSliceError; -use crate::primitives::PodU32; -use crate::slice::max_len_for_type; -use bytemuck::Pod; -use solana_program_error::ProgramError; - -const LENGTH_SIZE: usize = std::mem::size_of::(); +use { + crate::{ + bytemuck::{pod_from_bytes_mut, pod_slice_from_bytes_mut}, + error::PodSliceError, + primitives::{PodLength, PodU64}, + }, + bytemuck::Pod, + core::mem::{align_of, size_of}, + solana_program_error::ProgramError, +}; + +/// Calculate padding needed between types for alignment +#[inline] +fn calculate_padding() -> Result { + let length_size = size_of::(); + let data_align = align_of::(); + + // Calculate how many bytes we need to add to length_size + // to make it a multiple of data_align + let remainder = length_size + .checked_rem(data_align) + .ok_or(ProgramError::ArithmeticOverflow)?; + if remainder == 0 { + Ok(0) + } else { + data_align + .checked_sub(remainder) + .ok_or(ProgramError::ArithmeticOverflow) + } +} /// A mutable, variable-length collection of `Pod` types backed by a byte buffer. /// -/// `PodList` provides a safe, zero-copy, `Vec`-like interface for a slice of +/// `ListView` provides a safe, zero-copy, `Vec`-like interface for a slice of /// `Pod` data that resides in an external, pre-allocated `&mut [u8]` buffer. /// It does not own the buffer itself, but acts as a mutable view over it. /// @@ -20,34 +42,54 @@ const LENGTH_SIZE: usize = std::mem::size_of::(); /// ## Memory Layout /// /// The structure assumes the underlying byte buffer is formatted as follows: -/// 1. **Length**: A `u32` value (`PodU32`) at the beginning of the buffer, -/// indicating the number of currently active elements in the collection. -/// 2. **Data**: The remaining part of the buffer, which is treated as a slice +/// 1. **Length**: A length field of type `L` at the beginning of the buffer, +/// indicating the number of currently active elements in the collection. Defaults to `PodU64` so the offset is then compatible with 1, 2, 4 and 8 bytes. +/// 2. **Padding**: Optional padding bytes to ensure proper alignment of the data. +/// 3. **Data**: The remaining part of the buffer, which is treated as a slice /// of `T` elements. The capacity of the collection is the number of `T` /// elements that can fit into this data portion. -pub struct PodList<'data, T: Pod> { - length: &'data mut PodU32, +pub struct ListView<'data, T: Pod, L: PodLength = PodU64> { + length: &'data mut L, data: &'data mut [T], max_length: usize, } -impl<'data, T: Pod> PodList<'data, T> { +impl<'data, T: Pod, L: PodLength> ListView<'data, T, L> { /// Unpack the mutable buffer into a mutable slice, with the option to /// initialize the data - fn unpack_internal<'a>(data: &'a mut [u8], init: bool) -> Result - where - 'a: 'data, - { - if data.len() < LENGTH_SIZE { + #[inline(always)] + fn unpack_internal(buf: &'data mut [u8], init: bool) -> Result { + // Split the buffer to get the length prefix. + // buf: [ L L L L | P P D D D D D D D D ...] + // <-------> <----------------------> + // len_bytes tail + let length_size = size_of::(); + if buf.len() < length_size { return Err(PodSliceError::BufferTooSmall.into()); } - let (length, data) = data.split_at_mut(LENGTH_SIZE); - let length = pod_from_bytes_mut::(length)?; + let (len_bytes, tail) = buf.split_at_mut(length_size); + + // Skip alignment padding to find the start of the data. + // tail: [P P | D D D D D D D D ...] + // <-> <-------------------> + // padding data_bytes + let padding = calculate_padding::()?; + let data_bytes = tail + .get_mut(padding..) + .ok_or(PodSliceError::BufferTooSmall)?; + + // Cast the bytes to typed data + let length = pod_from_bytes_mut::(len_bytes)?; + let data = pod_slice_from_bytes_mut::(data_bytes)?; + let max_length = data.len(); + + // Initialize the list or validate its current length. if init { - *length = 0.into(); + *length = L::from_usize(0)?; + } else if length.as_usize() > max_length { + return Err(PodSliceError::BufferTooSmall.into()); } - let max_length = max_len_for_type::(data.len(), u32::from(*length) as usize)?; - let data = pod_slice_from_bytes_mut(data)?; + Ok(Self { length, data, @@ -60,7 +102,7 @@ impl<'data, T: Pod> PodList<'data, T> { where 'a: 'data, { - Self::unpack_internal(data, /* init */ false) + Self::unpack_internal(data, false) } /// Unpack the mutable buffer into a mutable slice, and initialize the @@ -69,17 +111,17 @@ impl<'data, T: Pod> PodList<'data, T> { where 'a: 'data, { - Self::unpack_internal(data, /* init */ true) + Self::unpack_internal(data, true) } /// Add another item to the slice pub fn push(&mut self, t: T) -> Result<(), ProgramError> { - let length = u32::from(*self.length); - if length as usize == self.max_length { + let length = self.length.as_usize(); + if length == self.max_length { Err(PodSliceError::BufferTooSmall.into()) } else { - self.data[length as usize] = t; - *self.length = length.saturating_add(1).into(); + self.data[length] = t; + *self.length = L::from_usize(length.saturating_add(1))?; Ok(()) } } @@ -87,7 +129,7 @@ impl<'data, T: Pod> PodList<'data, T> { /// Remove and return the element at `index`, shifting all later /// elements one position to the left. pub fn remove_at(&mut self, index: usize) -> Result { - let len = u32::from(*self.length) as usize; + let len = self.length.as_usize(); if index >= len { return Err(ProgramError::InvalidArgument); } @@ -101,33 +143,58 @@ impl<'data, T: Pod> PodList<'data, T> { self.data.copy_within(tail_start..len, index); // Zero-fill the now-unused slot at the end - let last = len.checked_sub(1).ok_or(ProgramError::ArithmeticOverflow)?; + let last = len.saturating_sub(1); self.data[last] = T::zeroed(); // Store the new length (len - 1) - *self.length = (last as u32).into(); + *self.length = L::from_usize(last)?; Ok(removed_item) } /// Find the first element that satisfies `predicate` and remove it, /// returning the element. - pub fn remove_first_where

(&mut self, mut predicate: P) -> Result + pub fn remove_first_where

(&mut self, predicate: P) -> Result where P: FnMut(&T) -> bool, { - if let Some(index) = self.data.iter().position(&mut predicate) { + if let Some(index) = self.data.iter().position(predicate) { self.remove_at(index) } else { Err(ProgramError::InvalidArgument) } } + + /// Get the amount of bytes used by `num_items` + pub fn size_of(num_items: usize) -> Result { + let padding_size = calculate_padding::()?; + let header_size = size_of::().saturating_add(padding_size); + + let data_size = size_of::() + .checked_mul(num_items) + .ok_or(PodSliceError::CalculationFailure)?; + + header_size + .checked_add(data_size) + .ok_or(PodSliceError::CalculationFailure.into()) + } + + /// Get the current number of items in collection + pub fn len(&self) -> usize { + self.length.as_usize() + } + + /// Returns true if the collection is empty + pub fn is_empty(&self) -> bool { + self.len() == 0 + } } #[cfg(test)] mod tests { use { super::*, + crate::primitives::{PodU16, PodU32, PodU64}, bytemuck_derive::{Pod, Zeroable}, }; @@ -139,196 +206,292 @@ mod tests { } #[test] - fn test_pod_collection() { - // slice can fit 2 `TestStruct` - let mut pod_slice_bytes = [0; 70]; - // set length to 1, so we have room to push 1 more item - let len_bytes = [1, 0, 0, 0]; - pod_slice_bytes[0..4].copy_from_slice(&len_bytes); + fn init_and_push() { + let size = ListView::::size_of(2).unwrap(); + let mut buffer = vec![0u8; size]; - let mut pod_slice = PodList::::unpack(&mut pod_slice_bytes).unwrap(); + let mut pod_slice = ListView::::init(&mut buffer).unwrap(); - assert_eq!(*pod_slice.length, PodU32::from(1)); pod_slice.push(TestStruct::default()).unwrap(); - assert_eq!(*pod_slice.length, PodU32::from(2)); - let err = pod_slice - .push(TestStruct::default()) - .expect_err("Expected an `PodSliceError::BufferTooSmall` error"); + assert_eq!(*pod_slice.length, PodU64::from(1)); + assert_eq!(pod_slice.len(), 1); + + pod_slice.push(TestStruct::default()).unwrap(); + assert_eq!(*pod_slice.length, PodU64::from(2)); + assert_eq!(pod_slice.len(), 2); + + // Buffer should be full now + let err = pod_slice.push(TestStruct::default()).unwrap_err(); assert_eq!(err, PodSliceError::BufferTooSmall.into()); } - fn make_buffer(capacity: usize, items: &[u8]) -> Vec { - let buff_len = LENGTH_SIZE.checked_add(capacity).unwrap(); + fn make_buffer(capacity: usize, items: &[u8]) -> Vec { + let length_size = size_of::(); + let padding_size = calculate_padding::().unwrap(); + let header_size = length_size.saturating_add(padding_size); + let buff_len = header_size.checked_add(capacity).unwrap(); let mut buf = vec![0u8; buff_len]; - buf[..LENGTH_SIZE].copy_from_slice(&(items.len() as u32).to_le_bytes()); - let end = LENGTH_SIZE.checked_add(items.len()).unwrap(); - buf[LENGTH_SIZE..end].copy_from_slice(items); + + // Write the length + let length = L::from_usize(items.len()).unwrap(); + let length_bytes = bytemuck::bytes_of(&length); + buf[..length_size].copy_from_slice(length_bytes); + + // Copy the data after the header + let data_end = header_size.checked_add(items.len()).unwrap(); + buf[header_size..data_end].copy_from_slice(items); buf } #[test] fn remove_at_first_item() { - let mut buff = make_buffer(15, &[10, 20, 30, 40]); - let mut pod_list = PodList::::unpack(&mut buff).unwrap(); - let removed = pod_list.remove_at(0).unwrap(); + let mut buff = make_buffer::(15, &[10, 20, 30, 40]); + let mut list_view = ListView::::unpack(&mut buff).unwrap(); + let removed = list_view.remove_at(0).unwrap(); assert_eq!(removed, 10); - let pod_list_len = u32::from(*pod_list.length) as usize; - assert_eq!(pod_list_len, 3); - assert_eq!(pod_list.data[..pod_list_len].to_vec(), &[20, 30, 40]); - assert_eq!(pod_list.data[3], 0); + assert_eq!(list_view.len(), 3); + assert_eq!(list_view.data[..list_view.len()].to_vec(), &[20, 30, 40]); + assert_eq!(list_view.data[3], 0); } #[test] fn remove_at_middle_item() { - let mut buff = make_buffer(15, &[10, 20, 30, 40]); - let mut pod_list = PodList::::unpack(&mut buff).unwrap(); - let removed = pod_list.remove_at(2).unwrap(); + let mut buff = make_buffer::(15, &[10, 20, 30, 40]); + let mut list_view = ListView::::unpack(&mut buff).unwrap(); + let removed = list_view.remove_at(2).unwrap(); assert_eq!(removed, 30); - let pod_list_len = u32::from(*pod_list.length) as usize; - assert_eq!(pod_list_len, 3); - assert_eq!(pod_list.data[..pod_list_len].to_vec(), &[10, 20, 40]); - assert_eq!(pod_list.data[3], 0); + assert_eq!(list_view.len(), 3); + assert_eq!(list_view.data[..list_view.len()].to_vec(), &[10, 20, 40]); + assert_eq!(list_view.data[3], 0); } #[test] fn remove_at_last_item() { - let mut buff = make_buffer(15, &[10, 20, 30, 40]); - let mut pod_list = PodList::::unpack(&mut buff).unwrap(); - let removed = pod_list.remove_at(3).unwrap(); + let mut buff = make_buffer::(15, &[10, 20, 30, 40]); + let mut list_view = ListView::::unpack(&mut buff).unwrap(); + let removed = list_view.remove_at(3).unwrap(); assert_eq!(removed, 40); - let pod_list_len = u32::from(*pod_list.length) as usize; - assert_eq!(pod_list_len, 3); - assert_eq!(pod_list.data[..pod_list_len].to_vec(), &[10, 20, 30]); - assert_eq!(pod_list.data[3], 0); + assert_eq!(list_view.len(), 3); + assert_eq!(list_view.data[..list_view.len()].to_vec(), &[10, 20, 30]); + assert_eq!(list_view.data[3], 0); } #[test] fn remove_at_out_of_bounds() { - let mut buff = make_buffer(3, &[1, 2, 3]); + let mut buff = make_buffer::(3, &[1, 2, 3]); let original_buff = buff.clone(); - { - let mut pod_list = PodList::::unpack(&mut buff).unwrap(); - let err = pod_list.remove_at(3).unwrap_err(); - assert_eq!(err, ProgramError::InvalidArgument); + let mut list_view = ListView::::unpack(&mut buff).unwrap(); + let err = list_view.remove_at(3).unwrap_err(); + assert_eq!(err, ProgramError::InvalidArgument); - // pod_list should be unchanged - let pod_list_len = u32::from(*pod_list.length) as usize; - assert_eq!(pod_list_len, 3); - assert_eq!(pod_list.data[..pod_list_len].to_vec(), vec![1, 2, 3]); - } + // list_view should be unchanged + assert_eq!(list_view.len(), 3); + assert_eq!(list_view.data[..list_view.len()].to_vec(), vec![1, 2, 3]); assert_eq!(buff, original_buff); } #[test] fn remove_at_single_element() { - let mut buff = make_buffer(1, &[10]); - let mut pod_list = PodList::::unpack(&mut buff).unwrap(); - let removed = pod_list.remove_at(0).unwrap(); + let mut buff = make_buffer::(1, &[10]); + let mut list_view = ListView::::unpack(&mut buff).unwrap(); + let removed = list_view.remove_at(0).unwrap(); assert_eq!(removed, 10); - let pod_list_len = u32::from(*pod_list.length) as usize; - assert_eq!(pod_list_len, 0); - assert_eq!(pod_list.data[..pod_list_len].to_vec(), &[] as &[u8]); - assert_eq!(pod_list.data[0], 0); + assert_eq!(list_view.len(), 0); + assert_eq!(list_view.data[..list_view.len()].to_vec(), &[] as &[u8]); + assert_eq!(list_view.data[0], 0); } #[test] fn remove_at_empty_slice() { - let mut buff = make_buffer(0, &[]); + let mut buff = make_buffer::(0, &[]); let original_buff = buff.clone(); - { - let mut pod_list = PodList::::unpack(&mut buff).unwrap(); - let err = pod_list.remove_at(0).unwrap_err(); - assert_eq!(err, ProgramError::InvalidArgument); + let mut list_view = ListView::::unpack(&mut buff).unwrap(); + let err = list_view.remove_at(0).unwrap_err(); + assert_eq!(err, ProgramError::InvalidArgument); - // Assert list state is unchanged - let pod_list_len = u32::from(*pod_list.length) as usize; - assert_eq!(pod_list_len, 0); - } + // Assert list state is unchanged + assert_eq!(list_view.len(), 0); assert_eq!(buff, original_buff); } #[test] fn remove_first_where_first_item() { - let mut buff = make_buffer(3, &[5, 10, 15]); - let mut pod_list = PodList::::unpack(&mut buff).unwrap(); - let removed = pod_list.remove_first_where(|&x| x == 5).unwrap(); + let mut buff = make_buffer::(3, &[5, 10, 15]); + let mut list_view = ListView::::unpack(&mut buff).unwrap(); + let removed = list_view.remove_first_where(|&x| x == 5).unwrap(); assert_eq!(removed, 5); - let pod_list_len = u32::from(*pod_list.length) as usize; - assert_eq!(pod_list_len, 2); - assert_eq!(pod_list.data[..pod_list_len].to_vec(), &[10, 15]); - assert_eq!(pod_list.data[2], 0); + assert_eq!(list_view.len(), 2); + assert_eq!(list_view.data[..list_view.len()].to_vec(), &[10, 15]); + assert_eq!(list_view.data[2], 0); } #[test] fn remove_first_where_middle_item() { - let mut buff = make_buffer(4, &[1, 2, 3, 4]); - let mut pod_list = PodList::::unpack(&mut buff).unwrap(); - let removed = pod_list.remove_first_where(|v| *v == 3).unwrap(); + let mut buff = make_buffer::(4, &[1, 2, 3, 4]); + let mut list_view = ListView::::unpack(&mut buff).unwrap(); + let removed = list_view.remove_first_where(|v| *v == 3).unwrap(); assert_eq!(removed, 3); - let pod_list_len = u32::from(*pod_list.length) as usize; - assert_eq!(pod_list_len, 3); - assert_eq!(pod_list.data[..pod_list_len].to_vec(), &[1, 2, 4]); - assert_eq!(pod_list.data[3], 0); + assert_eq!(list_view.len(), 3); + assert_eq!(list_view.data[..list_view.len()].to_vec(), &[1, 2, 4]); + assert_eq!(list_view.data[3], 0); } #[test] fn remove_first_where_last_item() { - let mut buff = make_buffer(3, &[5, 10, 15]); - let mut pod_list = PodList::::unpack(&mut buff).unwrap(); - let removed = pod_list.remove_first_where(|&x| x == 15).unwrap(); + let mut buff = make_buffer::(3, &[5, 10, 15]); + let mut list_view = ListView::::unpack(&mut buff).unwrap(); + let removed = list_view.remove_first_where(|&x| x == 15).unwrap(); assert_eq!(removed, 15); - let pod_list_len = u32::from(*pod_list.length) as usize; - assert_eq!(pod_list_len, 2); - assert_eq!(pod_list.data[..pod_list_len].to_vec(), &[5, 10]); - assert_eq!(pod_list.data[2], 0); + assert_eq!(list_view.len(), 2); + assert_eq!(list_view.data[..list_view.len()].to_vec(), &[5, 10]); + assert_eq!(list_view.data[2], 0); } #[test] fn remove_first_where_multiple_matches() { - let mut buff = make_buffer(5, &[7, 8, 8, 9, 10]); - let mut pod_list = PodList::::unpack(&mut buff).unwrap(); - let removed = pod_list.remove_first_where(|v| *v == 8).unwrap(); + let mut buff = make_buffer::(5, &[7, 8, 8, 9, 10]); + let mut list_view = ListView::::unpack(&mut buff).unwrap(); + let removed = list_view.remove_first_where(|v| *v == 8).unwrap(); assert_eq!(removed, 8); // Removed *first* 8 - let pod_list_len = u32::from(*pod_list.length) as usize; - assert_eq!(pod_list_len, 4); + assert_eq!(list_view.len(), 4); // Should remove only the *first* match. - assert_eq!(pod_list.data[..pod_list_len].to_vec(), &[7, 8, 9, 10]); - assert_eq!(pod_list.data[4], 0); + assert_eq!(list_view.data[..list_view.len()].to_vec(), &[7, 8, 9, 10]); + assert_eq!(list_view.data[4], 0); } #[test] fn remove_first_where_not_found() { - let mut buff = make_buffer(3, &[5, 6, 7]); + let mut buff = make_buffer::(3, &[5, 6, 7]); let original_buff = buff.clone(); - { - let mut pod_list = PodList::::unpack(&mut buff).unwrap(); - let err = pod_list.remove_first_where(|v| *v == 42).unwrap_err(); - assert_eq!(err, ProgramError::InvalidArgument); - // Assert list state is unchanged - assert_eq!(u32::from(*pod_list.length) as usize, 3); - } + let mut list_view = ListView::::unpack(&mut buff).unwrap(); + let err = list_view.remove_first_where(|v| *v == 42).unwrap_err(); + assert_eq!(err, ProgramError::InvalidArgument); + // Assert list state is unchanged + assert_eq!(list_view.len(), 3); assert_eq!(buff, original_buff); } #[test] fn remove_first_where_empty_slice() { - let mut buff = make_buffer(0, &[]); + let mut buff = make_buffer::(0, &[]); let original_buff = buff.clone(); - { - let mut pod_list = PodList::::unpack(&mut buff).unwrap(); - let err = pod_list.remove_first_where(|_| true).unwrap_err(); - assert_eq!(err, ProgramError::InvalidArgument); - // Assert list state is unchanged - assert_eq!(u32::from(*pod_list.length) as usize, 0); - } + let mut list_view = ListView::::unpack(&mut buff).unwrap(); + let err = list_view.remove_first_where(|_| true).unwrap_err(); + assert_eq!(err, ProgramError::InvalidArgument); + // Assert list state is unchanged + assert_eq!(list_view.len(), 0); assert_eq!(buff, original_buff); } + + #[test] + fn test_different_length_types() { + // Test with u16 length + let mut buff16 = make_buffer::(5, &[1, 2, 3]); + let list16 = ListView::::unpack(&mut buff16).unwrap(); + assert_eq!(list16.length.as_usize(), 3); + assert_eq!(list16.len(), 3); + + // Test with u32 length + let mut buff32 = make_buffer::(5, &[4, 5, 6]); + let list32 = ListView::::unpack(&mut buff32).unwrap(); + assert_eq!(list32.length.as_usize(), 3); + assert_eq!(list32.len(), 3); + + // Test with u64 length + let mut buff64 = make_buffer::(5, &[7, 8, 9]); + let list64 = ListView::::unpack(&mut buff64).unwrap(); + assert_eq!(list64.length.as_usize(), 3); + assert_eq!(list64.len(), 3); + } + + #[test] + fn test_calculate_padding() { + // When length and data have same alignment, no padding needed + assert_eq!(calculate_padding::().unwrap(), 0); + assert_eq!(calculate_padding::().unwrap(), 0); + assert_eq!(calculate_padding::().unwrap(), 0); + + // When data alignment is smaller than or divides length size + assert_eq!(calculate_padding::().unwrap(), 0); + assert_eq!(calculate_padding::().unwrap(), 0); + assert_eq!(calculate_padding::().unwrap(), 0); + assert_eq!(calculate_padding::().unwrap(), 0); + assert_eq!(calculate_padding::().unwrap(), 0); + + // When padding is needed + assert_eq!(calculate_padding::().unwrap(), 2); // 2 + 2 = 4 (align to 4) + assert_eq!(calculate_padding::().unwrap(), 6); // 2 + 6 = 8 (align to 8) + assert_eq!(calculate_padding::().unwrap(), 4); // 4 + 4 = 8 (align to 8) + + // Test with custom aligned structs + #[repr(C, align(8))] + #[derive(Pod, Zeroable, Copy, Clone)] + struct Align8 { + _data: [u8; 8], + } + + #[repr(C, align(16))] + #[derive(Pod, Zeroable, Copy, Clone)] + struct Align16 { + _data: [u8; 16], + } + + assert_eq!(calculate_padding::().unwrap(), 6); // 2 + 6 = 8 + assert_eq!(calculate_padding::().unwrap(), 4); // 4 + 4 = 8 + assert_eq!(calculate_padding::().unwrap(), 0); // 8 % 8 = 0 + + assert_eq!(calculate_padding::().unwrap(), 14); // 2 + 14 = 16 + assert_eq!(calculate_padding::().unwrap(), 12); // 4 + 12 = 16 + assert_eq!(calculate_padding::().unwrap(), 8); // 8 + 8 = 16 + } + + #[test] + fn test_alignment_in_practice() { + // u32 length with u64 data - needs 4 bytes padding + let size = ListView::::size_of(2).unwrap(); + let mut buffer = vec![0u8; size]; + let list = ListView::::init(&mut buffer).unwrap(); + + // Check that data pointer is 8-byte aligned + let data_ptr = list.data.as_ptr() as usize; + assert_eq!(data_ptr % 8, 0); + + // u16 length with u64 data - needs 6 bytes padding + let size = ListView::::size_of(2).unwrap(); + let mut buffer = vec![0u8; size]; + let list = ListView::::init(&mut buffer).unwrap(); + + let data_ptr = list.data.as_ptr() as usize; + assert_eq!(data_ptr % 8, 0); + } + + #[test] + fn test_length_too_large() { + // Create a buffer with capacity for 2 items + let capacity = 2; + let length_size = size_of::(); + let padding_size = calculate_padding::().unwrap(); + let header_size = length_size.saturating_add(padding_size); + let buff_len = header_size.checked_add(capacity).unwrap(); + let mut buffer = vec![0u8; buff_len]; + + // Manually write a length value that exceeds the capacity + let invalid_length = PodU32::from_usize(capacity + 1).unwrap(); + let length_bytes = bytemuck::bytes_of(&invalid_length); + buffer[..length_size].copy_from_slice(length_bytes); + + // Attempting to unpack should return BufferTooSmall error + match ListView::::unpack(&mut buffer) { + Err(err) => assert_eq!(err, PodSliceError::BufferTooSmall.into()), + Ok(_) => panic!("Expected BufferTooSmall error, but unpack succeeded"), + } + } } diff --git a/pod/src/primitives.rs b/pod/src/primitives.rs index 5eb694ed..4ae28957 100644 --- a/pod/src/primitives.rs +++ b/pod/src/primitives.rs @@ -127,6 +127,47 @@ impl_int_conversion!(PodI64, i64); pub struct PodU128(pub [u8; 16]); impl_int_conversion!(PodU128, u128); +/// Trait for types that can be used as length fields in Pod data structures +pub trait PodLength: bytemuck::Pod + Copy { + fn as_usize(&self) -> usize; + + fn from_usize(val: usize) -> Result; +} + +impl PodLength for PodU16 { + fn as_usize(&self) -> usize { + u16::from(*self) as usize + } + + fn from_usize(val: usize) -> Result { + u16::try_from(val) + .map(Into::into) + .map_err(|_| crate::error::PodSliceError::CalculationFailure) + } +} + +impl PodLength for PodU32 { + fn as_usize(&self) -> usize { + u32::from(*self) as usize + } + + fn from_usize(val: usize) -> Result { + u32::try_from(val) + .map(Into::into) + .map_err(|_| crate::error::PodSliceError::CalculationFailure) + } +} + +impl PodLength for PodU64 { + fn as_usize(&self) -> usize { + u64::from(*self) as usize + } + + fn from_usize(val: usize) -> Result { + Ok((val as u64).into()) + } +} + #[cfg(test)] mod tests { use {super::*, crate::bytemuck::pod_from_bytes}; diff --git a/pod/src/slice.rs b/pod/src/slice.rs index 0082c7a4..8ee4826d 100644 --- a/pod/src/slice.rs +++ b/pod/src/slice.rs @@ -2,10 +2,9 @@ use { crate::{ - bytemuck::{ - pod_from_bytes, pod_from_bytes_mut, pod_slice_from_bytes, pod_slice_from_bytes_mut, - }, + bytemuck::{pod_from_bytes, pod_slice_from_bytes}, error::PodSliceError, + list::ListView, primitives::PodU32, }, bytemuck::Pod, @@ -51,46 +50,23 @@ impl<'data, T: Pod> PodSlice<'data, T> { #[deprecated( since = "0.6.0", - note = "This struct will be removed in the next major release (1.0.0). Please use `PodList` instead." + note = "This struct will be removed in the next major release (1.0.0). Please use `ListView` instead." )] -/// Special type for using a slice of mutable `Pod`s in a zero-copy way +/// Special type for using a slice of mutable `Pod`s in a zero-copy way. +/// Uses `ListView` under the hood. pub struct PodSliceMut<'data, T: Pod> { - length: &'data mut PodU32, - data: &'data mut [T], - max_length: usize, + inner: ListView<'data, T, PodU32>, } #[allow(deprecated)] impl<'data, T: Pod> PodSliceMut<'data, T> { - /// Unpack the mutable buffer into a mutable slice, with the option to - /// initialize the data - fn unpack_internal<'a>(data: &'a mut [u8], init: bool) -> Result - where - 'a: 'data, - { - if data.len() < LENGTH_SIZE { - return Err(PodSliceError::BufferTooSmall.into()); - } - let (length, data) = data.split_at_mut(LENGTH_SIZE); - let length = pod_from_bytes_mut::(length)?; - if init { - *length = 0.into(); - } - let max_length = max_len_for_type::(data.len(), u32::from(*length) as usize)?; - let data = pod_slice_from_bytes_mut(data)?; - Ok(Self { - length, - data, - max_length, - }) - } - /// Unpack the mutable buffer into a mutable slice pub fn unpack<'a>(data: &'a mut [u8]) -> Result where 'a: 'data, { - Self::unpack_internal(data, /* init */ false) + let inner = ListView::::unpack(data)?; + Ok(Self { inner }) } /// Unpack the mutable buffer into a mutable slice, and initialize the @@ -99,23 +75,17 @@ impl<'data, T: Pod> PodSliceMut<'data, T> { where 'a: 'data, { - Self::unpack_internal(data, /* init */ true) + let inner = ListView::::init(data)?; + Ok(Self { inner }) } /// Add another item to the slice pub fn push(&mut self, t: T) -> Result<(), ProgramError> { - let length = u32::from(*self.length); - if length as usize == self.max_length { - Err(PodSliceError::BufferTooSmall.into()) - } else { - self.data[length as usize] = t; - *self.length = length.saturating_add(1).into(); - Ok(()) - } + self.inner.push(t) } } -pub fn max_len_for_type(data_len: usize, length_val: usize) -> Result { +fn max_len_for_type(data_len: usize, length_val: usize) -> Result { let item_size = std::mem::size_of::(); let max_len = data_len .checked_div(item_size) @@ -275,11 +245,33 @@ mod tests { let len_bytes = [1, 0, 0, 0]; pod_slice_bytes[0..4].copy_from_slice(&len_bytes); - let mut pod_slice = PodSliceMut::::unpack(&mut pod_slice_bytes).unwrap(); + // Verify initial length + assert_eq!( + u32::from_le_bytes([ + pod_slice_bytes[0], + pod_slice_bytes[1], + pod_slice_bytes[2], + pod_slice_bytes[3] + ]), + 1 + ); - assert_eq!(*pod_slice.length, PodU32::from(1)); + let mut pod_slice = PodSliceMut::::unpack(&mut pod_slice_bytes).unwrap(); pod_slice.push(TestStruct::default()).unwrap(); - assert_eq!(*pod_slice.length, PodU32::from(2)); + + // Check length after push + assert_eq!( + u32::from_le_bytes([ + pod_slice_bytes[0], + pod_slice_bytes[1], + pod_slice_bytes[2], + pod_slice_bytes[3] + ]), + 2 + ); + + // Test that buffer is full + let mut pod_slice = PodSliceMut::::unpack(&mut pod_slice_bytes).unwrap(); let err = pod_slice .push(TestStruct::default()) .expect_err("Expected an `PodSliceError::BufferTooSmall` error"); diff --git a/tlv-account-resolution/Cargo.toml b/tlv-account-resolution/Cargo.toml index 8639306e..12a4c38a 100644 --- a/tlv-account-resolution/Cargo.toml +++ b/tlv-account-resolution/Cargo.toml @@ -24,7 +24,7 @@ solana-program-error = "2.2.2" solana-pubkey = { version = "2.2.1", features = ["curve25519"] } spl-discriminator = { version = "0.4.0", path = "../discriminator" } spl-program-error = { version = "0.7.0", path = "../program-error" } -spl-pod = { version = "0.6.0", path = "../pod" } +spl-pod = { version = "0.5.1", path = "../pod" } spl-type-length-value = { version = "0.8.0", path = "../type-length-value" } thiserror = "2.0" diff --git a/tlv-account-resolution/src/state.rs b/tlv-account-resolution/src/state.rs index e901fff3..74ea8e36 100644 --- a/tlv-account-resolution/src/state.rs +++ b/tlv-account-resolution/src/state.rs @@ -1,6 +1,5 @@ //! State transition types -use spl_pod::list::PodList; use { crate::{account::ExtraAccountMeta, error::AccountResolutionError}, solana_account_info::AccountInfo, @@ -8,7 +7,7 @@ use { solana_program_error::ProgramError, solana_pubkey::Pubkey, spl_discriminator::SplDiscriminate, - spl_pod::slice::PodSlice, + spl_pod::{list::ListView, slice::PodSlice}, spl_type_length_value::state::{TlvState, TlvStateBorrowed, TlvStateMut}, std::future::Future, }; @@ -173,7 +172,7 @@ impl ExtraAccountMetaList { let mut state = TlvStateMut::unpack(data).unwrap(); let tlv_size = PodSlice::::size_of(extra_account_metas.len())?; let (bytes, _) = state.alloc::(tlv_size, false)?; - let mut validation_data = PodList::init(bytes)?; + let mut validation_data = ListView::::init(bytes)?; for meta in extra_account_metas { validation_data.push(*meta)?; } @@ -189,7 +188,7 @@ impl ExtraAccountMetaList { let mut state = TlvStateMut::unpack(data).unwrap(); let tlv_size = PodSlice::::size_of(extra_account_metas.len())?; let bytes = state.realloc_first::(tlv_size)?; - let mut validation_data = PodList::init(bytes)?; + let mut validation_data = ListView::::init(bytes)?; for meta in extra_account_metas { validation_data.push(*meta)?; } diff --git a/type-length-value/Cargo.toml b/type-length-value/Cargo.toml index 7c7ba6d6..f5d66470 100644 --- a/type-length-value/Cargo.toml +++ b/type-length-value/Cargo.toml @@ -22,7 +22,7 @@ solana-msg = "2.2.1" solana-program-error = "2.2.2" spl-discriminator = { version = "0.4.0", path = "../discriminator" } spl-type-length-value-derive = { version = "0.2", path = "./derive", optional = true } -spl-pod = { version = "0.6.0", path = "../pod" } +spl-pod = { version = "0.5.1", path = "../pod" } thiserror = "2.0" [lib] From fa4e9253f33900f43ae01f9b7fe8b5ed5b95d4c9 Mon Sep 17 00:00:00 2001 From: Gabe Rodriguez Date: Wed, 9 Jul 2025 14:18:07 +0200 Subject: [PATCH 3/7] Fix test + add more specific doc string --- pod/src/slice.rs | 4 +++- tlv-account-resolution/src/state.rs | 6 +++--- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/pod/src/slice.rs b/pod/src/slice.rs index 8ee4826d..34433466 100644 --- a/pod/src/slice.rs +++ b/pod/src/slice.rs @@ -50,7 +50,9 @@ impl<'data, T: Pod> PodSlice<'data, T> { #[deprecated( since = "0.6.0", - note = "This struct will be removed in the next major release (1.0.0). Please use `ListView` instead." + note = "This struct will be removed in the next major release (1.0.0). \ + Please use `ListView` instead. If using with existing data initialized by PodSliceMut, \ + you need to specifiy PodU32 length (e.g. ListView::::init(bytes))" )] /// Special type for using a slice of mutable `Pod`s in a zero-copy way. /// Uses `ListView` under the hood. diff --git a/tlv-account-resolution/src/state.rs b/tlv-account-resolution/src/state.rs index 74ea8e36..6bed11d6 100644 --- a/tlv-account-resolution/src/state.rs +++ b/tlv-account-resolution/src/state.rs @@ -7,7 +7,7 @@ use { solana_program_error::ProgramError, solana_pubkey::Pubkey, spl_discriminator::SplDiscriminate, - spl_pod::{list::ListView, slice::PodSlice}, + spl_pod::{list::ListView, primitives::PodU32, slice::PodSlice}, spl_type_length_value::state::{TlvState, TlvStateBorrowed, TlvStateMut}, std::future::Future, }; @@ -172,7 +172,7 @@ impl ExtraAccountMetaList { let mut state = TlvStateMut::unpack(data).unwrap(); let tlv_size = PodSlice::::size_of(extra_account_metas.len())?; let (bytes, _) = state.alloc::(tlv_size, false)?; - let mut validation_data = ListView::::init(bytes)?; + let mut validation_data = ListView::::init(bytes)?; for meta in extra_account_metas { validation_data.push(*meta)?; } @@ -188,7 +188,7 @@ impl ExtraAccountMetaList { let mut state = TlvStateMut::unpack(data).unwrap(); let tlv_size = PodSlice::::size_of(extra_account_metas.len())?; let bytes = state.realloc_first::(tlv_size)?; - let mut validation_data = ListView::::init(bytes)?; + let mut validation_data = ListView::::init(bytes)?; for meta in extra_account_metas { validation_data.push(*meta)?; } From f93022bd4693470266fa73fca7b1a8a5181ebbc8 Mon Sep 17 00:00:00 2001 From: Gabe Rodriguez Date: Tue, 15 Jul 2025 15:08:01 +0200 Subject: [PATCH 4/7] Review updates --- pod/src/error.rs | 12 +++- pod/src/lib.rs | 1 + pod/src/list.rs | 157 ++++++++++++------------------------------ pod/src/pod_length.rs | 54 +++++++++++++++ pod/src/primitives.rs | 41 ----------- 5 files changed, 110 insertions(+), 155 deletions(-) create mode 100644 pod/src/pod_length.rs diff --git a/pod/src/error.rs b/pod/src/error.rs index 2d83a231..b34fbf73 100644 --- a/pod/src/error.rs +++ b/pod/src/error.rs @@ -1,5 +1,9 @@ //! Error types -use solana_program_error::{ProgramError, ToStr}; +use { + solana_msg::msg, + solana_program_error::{ToStr, ProgramError}, + std::num::TryFromIntError, +}; /// Errors that may be returned by the spl-pod library. #[repr(u32)] @@ -39,3 +43,9 @@ impl ToStr for PodSliceError { } } } + +impl From for PodSliceError { + fn from(_: TryFromIntError) -> Self { + PodSliceError::CalculationFailure + } +} diff --git a/pod/src/lib.rs b/pod/src/lib.rs index b2871e83..b9a26da1 100644 --- a/pod/src/lib.rs +++ b/pod/src/lib.rs @@ -5,6 +5,7 @@ pub mod error; pub mod list; pub mod option; pub mod optional_keys; +pub mod pod_length; pub mod primitives; pub mod slice; diff --git a/pod/src/list.rs b/pod/src/list.rs index 480761bd..44ae527a 100644 --- a/pod/src/list.rs +++ b/pod/src/list.rs @@ -2,7 +2,8 @@ use { crate::{ bytemuck::{pod_from_bytes_mut, pod_slice_from_bytes_mut}, error::PodSliceError, - primitives::{PodLength, PodU64}, + pod_length::PodLength, + primitives::PodU64, }, bytemuck::Pod, core::mem::{align_of, size_of}, @@ -29,15 +30,15 @@ fn calculate_padding() -> Result { } } -/// A mutable, variable-length collection of `Pod` types backed by a byte buffer. +/// An API for interpreting a raw buffer (`&[u8]`) as a mutable, variable-length collection of Pod elements. /// /// `ListView` provides a safe, zero-copy, `Vec`-like interface for a slice of /// `Pod` data that resides in an external, pre-allocated `&mut [u8]` buffer. /// It does not own the buffer itself, but acts as a mutable view over it. /// /// This is useful in environments where allocations are restricted or expensive, -/// such as Solana programs, allowing for dynamic-length data structures within a -/// fixed-size account. +/// such as Solana programs, allowing for efficient reads and manipulation of +/// dynamic-length data structures. /// /// ## Memory Layout /// @@ -85,8 +86,8 @@ impl<'data, T: Pod, L: PodLength> ListView<'data, T, L> { // Initialize the list or validate its current length. if init { - *length = L::from_usize(0)?; - } else if length.as_usize() > max_length { + *length = L::try_from(0)?; + } else if (*length).into() > max_length { return Err(PodSliceError::BufferTooSmall.into()); } @@ -115,21 +116,21 @@ impl<'data, T: Pod, L: PodLength> ListView<'data, T, L> { } /// Add another item to the slice - pub fn push(&mut self, t: T) -> Result<(), ProgramError> { - let length = self.length.as_usize(); - if length == self.max_length { + pub fn push(&mut self, item: T) -> Result<(), ProgramError> { + let length = (*self.length).into(); + if length >= self.max_length { Err(PodSliceError::BufferTooSmall.into()) } else { - self.data[length] = t; - *self.length = L::from_usize(length.saturating_add(1))?; + self.data[length] = item; + *self.length = L::try_from(length.saturating_add(1))?; Ok(()) } } /// Remove and return the element at `index`, shifting all later /// elements one position to the left. - pub fn remove_at(&mut self, index: usize) -> Result { - let len = self.length.as_usize(); + pub fn remove(&mut self, index: usize) -> Result { + let len = (*self.length).into(); if index >= len { return Err(ProgramError::InvalidArgument); } @@ -147,24 +148,11 @@ impl<'data, T: Pod, L: PodLength> ListView<'data, T, L> { self.data[last] = T::zeroed(); // Store the new length (len - 1) - *self.length = L::from_usize(last)?; + *self.length = L::try_from(last)?; Ok(removed_item) } - /// Find the first element that satisfies `predicate` and remove it, - /// returning the element. - pub fn remove_first_where

(&mut self, predicate: P) -> Result - where - P: FnMut(&T) -> bool, - { - if let Some(index) = self.data.iter().position(predicate) { - self.remove_at(index) - } else { - Err(ProgramError::InvalidArgument) - } - } - /// Get the amount of bytes used by `num_items` pub fn size_of(num_items: usize) -> Result { let padding_size = calculate_padding::()?; @@ -181,13 +169,25 @@ impl<'data, T: Pod, L: PodLength> ListView<'data, T, L> { /// Get the current number of items in collection pub fn len(&self) -> usize { - self.length.as_usize() + (*self.length).into() } /// Returns true if the collection is empty pub fn is_empty(&self) -> bool { self.len() == 0 } + + /// Returns an iterator over the current elements + pub fn iter(&self) -> std::slice::Iter { + let len = (*self.length).into(); + self.data[..len].iter() + } + + /// Returns a mutable iterator over the current elements + pub fn iter_mut(&mut self) -> std::slice::IterMut { + let len = (*self.length).into(); + self.data[..len].iter_mut() + } } #[cfg(test)] @@ -225,7 +225,11 @@ mod tests { assert_eq!(err, PodSliceError::BufferTooSmall.into()); } - fn make_buffer(capacity: usize, items: &[u8]) -> Vec { + fn make_buffer + TryFrom>(capacity: usize, items: &[u8]) -> Vec + where + PodSliceError: From<>::Error>, + >::Error: std::fmt::Debug, + { let length_size = size_of::(); let padding_size = calculate_padding::().unwrap(); let header_size = length_size.saturating_add(padding_size); @@ -233,7 +237,7 @@ mod tests { let mut buf = vec![0u8; buff_len]; // Write the length - let length = L::from_usize(items.len()).unwrap(); + let length = L::try_from(items.len()).unwrap(); let length_bytes = bytemuck::bytes_of(&length); buf[..length_size].copy_from_slice(length_bytes); @@ -247,7 +251,7 @@ mod tests { fn remove_at_first_item() { let mut buff = make_buffer::(15, &[10, 20, 30, 40]); let mut list_view = ListView::::unpack(&mut buff).unwrap(); - let removed = list_view.remove_at(0).unwrap(); + let removed = list_view.remove(0).unwrap(); assert_eq!(removed, 10); assert_eq!(list_view.len(), 3); assert_eq!(list_view.data[..list_view.len()].to_vec(), &[20, 30, 40]); @@ -258,7 +262,7 @@ mod tests { fn remove_at_middle_item() { let mut buff = make_buffer::(15, &[10, 20, 30, 40]); let mut list_view = ListView::::unpack(&mut buff).unwrap(); - let removed = list_view.remove_at(2).unwrap(); + let removed = list_view.remove(2).unwrap(); assert_eq!(removed, 30); assert_eq!(list_view.len(), 3); assert_eq!(list_view.data[..list_view.len()].to_vec(), &[10, 20, 40]); @@ -269,7 +273,7 @@ mod tests { fn remove_at_last_item() { let mut buff = make_buffer::(15, &[10, 20, 30, 40]); let mut list_view = ListView::::unpack(&mut buff).unwrap(); - let removed = list_view.remove_at(3).unwrap(); + let removed = list_view.remove(3).unwrap(); assert_eq!(removed, 40); assert_eq!(list_view.len(), 3); assert_eq!(list_view.data[..list_view.len()].to_vec(), &[10, 20, 30]); @@ -282,7 +286,7 @@ mod tests { let original_buff = buff.clone(); let mut list_view = ListView::::unpack(&mut buff).unwrap(); - let err = list_view.remove_at(3).unwrap_err(); + let err = list_view.remove(3).unwrap_err(); assert_eq!(err, ProgramError::InvalidArgument); // list_view should be unchanged @@ -296,7 +300,7 @@ mod tests { fn remove_at_single_element() { let mut buff = make_buffer::(1, &[10]); let mut list_view = ListView::::unpack(&mut buff).unwrap(); - let removed = list_view.remove_at(0).unwrap(); + let removed = list_view.remove(0).unwrap(); assert_eq!(removed, 10); assert_eq!(list_view.len(), 0); assert_eq!(list_view.data[..list_view.len()].to_vec(), &[] as &[u8]); @@ -309,82 +313,9 @@ mod tests { let original_buff = buff.clone(); let mut list_view = ListView::::unpack(&mut buff).unwrap(); - let err = list_view.remove_at(0).unwrap_err(); - assert_eq!(err, ProgramError::InvalidArgument); - - // Assert list state is unchanged - assert_eq!(list_view.len(), 0); - - assert_eq!(buff, original_buff); - } - - #[test] - fn remove_first_where_first_item() { - let mut buff = make_buffer::(3, &[5, 10, 15]); - let mut list_view = ListView::::unpack(&mut buff).unwrap(); - let removed = list_view.remove_first_where(|&x| x == 5).unwrap(); - assert_eq!(removed, 5); - assert_eq!(list_view.len(), 2); - assert_eq!(list_view.data[..list_view.len()].to_vec(), &[10, 15]); - assert_eq!(list_view.data[2], 0); - } - - #[test] - fn remove_first_where_middle_item() { - let mut buff = make_buffer::(4, &[1, 2, 3, 4]); - let mut list_view = ListView::::unpack(&mut buff).unwrap(); - let removed = list_view.remove_first_where(|v| *v == 3).unwrap(); - assert_eq!(removed, 3); - assert_eq!(list_view.len(), 3); - assert_eq!(list_view.data[..list_view.len()].to_vec(), &[1, 2, 4]); - assert_eq!(list_view.data[3], 0); - } - - #[test] - fn remove_first_where_last_item() { - let mut buff = make_buffer::(3, &[5, 10, 15]); - let mut list_view = ListView::::unpack(&mut buff).unwrap(); - let removed = list_view.remove_first_where(|&x| x == 15).unwrap(); - assert_eq!(removed, 15); - assert_eq!(list_view.len(), 2); - assert_eq!(list_view.data[..list_view.len()].to_vec(), &[5, 10]); - assert_eq!(list_view.data[2], 0); - } - - #[test] - fn remove_first_where_multiple_matches() { - let mut buff = make_buffer::(5, &[7, 8, 8, 9, 10]); - let mut list_view = ListView::::unpack(&mut buff).unwrap(); - let removed = list_view.remove_first_where(|v| *v == 8).unwrap(); - assert_eq!(removed, 8); // Removed *first* 8 - assert_eq!(list_view.len(), 4); - // Should remove only the *first* match. - assert_eq!(list_view.data[..list_view.len()].to_vec(), &[7, 8, 9, 10]); - assert_eq!(list_view.data[4], 0); - } - - #[test] - fn remove_first_where_not_found() { - let mut buff = make_buffer::(3, &[5, 6, 7]); - let original_buff = buff.clone(); - - let mut list_view = ListView::::unpack(&mut buff).unwrap(); - let err = list_view.remove_first_where(|v| *v == 42).unwrap_err(); + let err = list_view.remove(0).unwrap_err(); assert_eq!(err, ProgramError::InvalidArgument); - // Assert list state is unchanged - assert_eq!(list_view.len(), 3); - - assert_eq!(buff, original_buff); - } - - #[test] - fn remove_first_where_empty_slice() { - let mut buff = make_buffer::(0, &[]); - let original_buff = buff.clone(); - let mut list_view = ListView::::unpack(&mut buff).unwrap(); - let err = list_view.remove_first_where(|_| true).unwrap_err(); - assert_eq!(err, ProgramError::InvalidArgument); // Assert list state is unchanged assert_eq!(list_view.len(), 0); @@ -396,19 +327,19 @@ mod tests { // Test with u16 length let mut buff16 = make_buffer::(5, &[1, 2, 3]); let list16 = ListView::::unpack(&mut buff16).unwrap(); - assert_eq!(list16.length.as_usize(), 3); + assert_eq!(list16.len(), 3); assert_eq!(list16.len(), 3); // Test with u32 length let mut buff32 = make_buffer::(5, &[4, 5, 6]); let list32 = ListView::::unpack(&mut buff32).unwrap(); - assert_eq!(list32.length.as_usize(), 3); + assert_eq!(list32.len(), 3); assert_eq!(list32.len(), 3); // Test with u64 length let mut buff64 = make_buffer::(5, &[7, 8, 9]); let list64 = ListView::::unpack(&mut buff64).unwrap(); - assert_eq!(list64.length.as_usize(), 3); + assert_eq!(list64.len(), 3); assert_eq!(list64.len(), 3); } @@ -484,7 +415,7 @@ mod tests { let mut buffer = vec![0u8; buff_len]; // Manually write a length value that exceeds the capacity - let invalid_length = PodU32::from_usize(capacity + 1).unwrap(); + let invalid_length = PodU32::try_from(capacity + 1).unwrap(); let length_bytes = bytemuck::bytes_of(&invalid_length); buffer[..length_size].copy_from_slice(length_bytes); diff --git a/pod/src/pod_length.rs b/pod/src/pod_length.rs new file mode 100644 index 00000000..52e4a7eb --- /dev/null +++ b/pod/src/pod_length.rs @@ -0,0 +1,54 @@ +use { + crate::{ + error::PodSliceError, + primitives::{PodU16, PodU32, PodU64}, + }, + bytemuck::Pod, +}; + +/// Marker trait for converting to/from Pod `uint`'s and `usize` +pub trait PodLength: Pod + Into + TryFrom {} + +impl PodLength for T where T: Pod + Into + TryFrom {} + +impl TryFrom for PodU16 { + type Error = PodSliceError; + + fn try_from(val: usize) -> Result { + Ok(u16::try_from(val)?.into()) + } +} + +impl From for usize { + fn from(pod: PodU16) -> Self { + u16::from(pod) as usize + } +} + +impl TryFrom for PodU32 { + type Error = PodSliceError; + + fn try_from(val: usize) -> Result { + Ok(u32::try_from(val)?.into()) + } +} + +impl From for usize { + fn from(pod: PodU32) -> Self { + u32::from(pod) as usize + } +} + +impl TryFrom for PodU64 { + type Error = PodSliceError; + + fn try_from(val: usize) -> Result { + Ok(u64::try_from(val)?.into()) + } +} + +impl From for usize { + fn from(pod: PodU64) -> Self { + u64::from(pod) as usize + } +} diff --git a/pod/src/primitives.rs b/pod/src/primitives.rs index 4ae28957..5eb694ed 100644 --- a/pod/src/primitives.rs +++ b/pod/src/primitives.rs @@ -127,47 +127,6 @@ impl_int_conversion!(PodI64, i64); pub struct PodU128(pub [u8; 16]); impl_int_conversion!(PodU128, u128); -/// Trait for types that can be used as length fields in Pod data structures -pub trait PodLength: bytemuck::Pod + Copy { - fn as_usize(&self) -> usize; - - fn from_usize(val: usize) -> Result; -} - -impl PodLength for PodU16 { - fn as_usize(&self) -> usize { - u16::from(*self) as usize - } - - fn from_usize(val: usize) -> Result { - u16::try_from(val) - .map(Into::into) - .map_err(|_| crate::error::PodSliceError::CalculationFailure) - } -} - -impl PodLength for PodU32 { - fn as_usize(&self) -> usize { - u32::from(*self) as usize - } - - fn from_usize(val: usize) -> Result { - u32::try_from(val) - .map(Into::into) - .map_err(|_| crate::error::PodSliceError::CalculationFailure) - } -} - -impl PodLength for PodU64 { - fn as_usize(&self) -> usize { - u64::from(*self) as usize - } - - fn from_usize(val: usize) -> Result { - Ok((val as u64).into()) - } -} - #[cfg(test)] mod tests { use {super::*, crate::bytemuck::pod_from_bytes}; From 1e9df09d022f3ee75f254c209185e88ea2bfe731 Mon Sep 17 00:00:00 2001 From: Gabe Rodriguez Date: Thu, 17 Jul 2025 17:23:57 +0200 Subject: [PATCH 5/7] ListView entrypoint pattern --- .gitignore | 1 + pod/src/error.rs | 9 +- pod/src/list.rs | 428 -------------------- pod/src/list/list_view.rs | 584 ++++++++++++++++++++++++++++ pod/src/list/list_view_mut.rs | 306 +++++++++++++++ pod/src/list/list_view_read_only.rs | 156 ++++++++ pod/src/list/list_viewable.rs | 27 ++ pod/src/list/mod.rs | 9 + pod/src/pod_length.rs | 63 ++- pod/src/slice.rs | 115 ++---- tlv-account-resolution/src/state.rs | 29 +- 11 files changed, 1158 insertions(+), 569 deletions(-) delete mode 100644 pod/src/list.rs create mode 100644 pod/src/list/list_view.rs create mode 100644 pod/src/list/list_view_mut.rs create mode 100644 pod/src/list/list_view_read_only.rs create mode 100644 pod/src/list/list_viewable.rs create mode 100644 pod/src/list/mod.rs diff --git a/.gitignore b/.gitignore index 70487d22..4b79068f 100644 --- a/.gitignore +++ b/.gitignore @@ -7,3 +7,4 @@ node_modules test-ledger dist +.idea diff --git a/pod/src/error.rs b/pod/src/error.rs index b34fbf73..03d5b70b 100644 --- a/pod/src/error.rs +++ b/pod/src/error.rs @@ -1,7 +1,6 @@ //! Error types use { - solana_msg::msg, - solana_program_error::{ToStr, ProgramError}, + solana_program_error::{ProgramError, ToStr}, std::num::TryFromIntError, }; @@ -26,6 +25,9 @@ pub enum PodSliceError { /// Provided byte buffer too large for expected type #[error("Provided byte buffer too large for expected type")] BufferTooLarge, + /// An integer conversion failed because the value was out of range for the target type + #[error("An integer conversion failed because the value was out of range for the target type")] + ValueOutOfRange, } impl From for ProgramError { @@ -40,12 +42,13 @@ impl ToStr for PodSliceError { PodSliceError::CalculationFailure => "Error in checked math operation", PodSliceError::BufferTooSmall => "Provided byte buffer too small for expected type", PodSliceError::BufferTooLarge => "Provided byte buffer too large for expected type", + PodSliceError::ValueOutOfRange => "An integer conversion failed because the value was out of range for the target type" } } } impl From for PodSliceError { fn from(_: TryFromIntError) -> Self { - PodSliceError::CalculationFailure + PodSliceError::ValueOutOfRange } } diff --git a/pod/src/list.rs b/pod/src/list.rs deleted file mode 100644 index 44ae527a..00000000 --- a/pod/src/list.rs +++ /dev/null @@ -1,428 +0,0 @@ -use { - crate::{ - bytemuck::{pod_from_bytes_mut, pod_slice_from_bytes_mut}, - error::PodSliceError, - pod_length::PodLength, - primitives::PodU64, - }, - bytemuck::Pod, - core::mem::{align_of, size_of}, - solana_program_error::ProgramError, -}; - -/// Calculate padding needed between types for alignment -#[inline] -fn calculate_padding() -> Result { - let length_size = size_of::(); - let data_align = align_of::(); - - // Calculate how many bytes we need to add to length_size - // to make it a multiple of data_align - let remainder = length_size - .checked_rem(data_align) - .ok_or(ProgramError::ArithmeticOverflow)?; - if remainder == 0 { - Ok(0) - } else { - data_align - .checked_sub(remainder) - .ok_or(ProgramError::ArithmeticOverflow) - } -} - -/// An API for interpreting a raw buffer (`&[u8]`) as a mutable, variable-length collection of Pod elements. -/// -/// `ListView` provides a safe, zero-copy, `Vec`-like interface for a slice of -/// `Pod` data that resides in an external, pre-allocated `&mut [u8]` buffer. -/// It does not own the buffer itself, but acts as a mutable view over it. -/// -/// This is useful in environments where allocations are restricted or expensive, -/// such as Solana programs, allowing for efficient reads and manipulation of -/// dynamic-length data structures. -/// -/// ## Memory Layout -/// -/// The structure assumes the underlying byte buffer is formatted as follows: -/// 1. **Length**: A length field of type `L` at the beginning of the buffer, -/// indicating the number of currently active elements in the collection. Defaults to `PodU64` so the offset is then compatible with 1, 2, 4 and 8 bytes. -/// 2. **Padding**: Optional padding bytes to ensure proper alignment of the data. -/// 3. **Data**: The remaining part of the buffer, which is treated as a slice -/// of `T` elements. The capacity of the collection is the number of `T` -/// elements that can fit into this data portion. -pub struct ListView<'data, T: Pod, L: PodLength = PodU64> { - length: &'data mut L, - data: &'data mut [T], - max_length: usize, -} - -impl<'data, T: Pod, L: PodLength> ListView<'data, T, L> { - /// Unpack the mutable buffer into a mutable slice, with the option to - /// initialize the data - #[inline(always)] - fn unpack_internal(buf: &'data mut [u8], init: bool) -> Result { - // Split the buffer to get the length prefix. - // buf: [ L L L L | P P D D D D D D D D ...] - // <-------> <----------------------> - // len_bytes tail - let length_size = size_of::(); - if buf.len() < length_size { - return Err(PodSliceError::BufferTooSmall.into()); - } - let (len_bytes, tail) = buf.split_at_mut(length_size); - - // Skip alignment padding to find the start of the data. - // tail: [P P | D D D D D D D D ...] - // <-> <-------------------> - // padding data_bytes - let padding = calculate_padding::()?; - let data_bytes = tail - .get_mut(padding..) - .ok_or(PodSliceError::BufferTooSmall)?; - - // Cast the bytes to typed data - let length = pod_from_bytes_mut::(len_bytes)?; - let data = pod_slice_from_bytes_mut::(data_bytes)?; - let max_length = data.len(); - - // Initialize the list or validate its current length. - if init { - *length = L::try_from(0)?; - } else if (*length).into() > max_length { - return Err(PodSliceError::BufferTooSmall.into()); - } - - Ok(Self { - length, - data, - max_length, - }) - } - - /// Unpack the mutable buffer into a mutable slice - pub fn unpack<'a>(data: &'a mut [u8]) -> Result - where - 'a: 'data, - { - Self::unpack_internal(data, false) - } - - /// Unpack the mutable buffer into a mutable slice, and initialize the - /// slice to 0-length - pub fn init<'a>(data: &'a mut [u8]) -> Result - where - 'a: 'data, - { - Self::unpack_internal(data, true) - } - - /// Add another item to the slice - pub fn push(&mut self, item: T) -> Result<(), ProgramError> { - let length = (*self.length).into(); - if length >= self.max_length { - Err(PodSliceError::BufferTooSmall.into()) - } else { - self.data[length] = item; - *self.length = L::try_from(length.saturating_add(1))?; - Ok(()) - } - } - - /// Remove and return the element at `index`, shifting all later - /// elements one position to the left. - pub fn remove(&mut self, index: usize) -> Result { - let len = (*self.length).into(); - if index >= len { - return Err(ProgramError::InvalidArgument); - } - - let removed_item = self.data[index]; - - // Move the tail left by one - let tail_start = index - .checked_add(1) - .ok_or(ProgramError::ArithmeticOverflow)?; - self.data.copy_within(tail_start..len, index); - - // Zero-fill the now-unused slot at the end - let last = len.saturating_sub(1); - self.data[last] = T::zeroed(); - - // Store the new length (len - 1) - *self.length = L::try_from(last)?; - - Ok(removed_item) - } - - /// Get the amount of bytes used by `num_items` - pub fn size_of(num_items: usize) -> Result { - let padding_size = calculate_padding::()?; - let header_size = size_of::().saturating_add(padding_size); - - let data_size = size_of::() - .checked_mul(num_items) - .ok_or(PodSliceError::CalculationFailure)?; - - header_size - .checked_add(data_size) - .ok_or(PodSliceError::CalculationFailure.into()) - } - - /// Get the current number of items in collection - pub fn len(&self) -> usize { - (*self.length).into() - } - - /// Returns true if the collection is empty - pub fn is_empty(&self) -> bool { - self.len() == 0 - } - - /// Returns an iterator over the current elements - pub fn iter(&self) -> std::slice::Iter { - let len = (*self.length).into(); - self.data[..len].iter() - } - - /// Returns a mutable iterator over the current elements - pub fn iter_mut(&mut self) -> std::slice::IterMut { - let len = (*self.length).into(); - self.data[..len].iter_mut() - } -} - -#[cfg(test)] -mod tests { - use { - super::*, - crate::primitives::{PodU16, PodU32, PodU64}, - bytemuck_derive::{Pod, Zeroable}, - }; - - #[repr(C)] - #[derive(Clone, Copy, Debug, Default, PartialEq, Pod, Zeroable)] - struct TestStruct { - test_field: u8, - test_pubkey: [u8; 32], - } - - #[test] - fn init_and_push() { - let size = ListView::::size_of(2).unwrap(); - let mut buffer = vec![0u8; size]; - - let mut pod_slice = ListView::::init(&mut buffer).unwrap(); - - pod_slice.push(TestStruct::default()).unwrap(); - assert_eq!(*pod_slice.length, PodU64::from(1)); - assert_eq!(pod_slice.len(), 1); - - pod_slice.push(TestStruct::default()).unwrap(); - assert_eq!(*pod_slice.length, PodU64::from(2)); - assert_eq!(pod_slice.len(), 2); - - // Buffer should be full now - let err = pod_slice.push(TestStruct::default()).unwrap_err(); - assert_eq!(err, PodSliceError::BufferTooSmall.into()); - } - - fn make_buffer + TryFrom>(capacity: usize, items: &[u8]) -> Vec - where - PodSliceError: From<>::Error>, - >::Error: std::fmt::Debug, - { - let length_size = size_of::(); - let padding_size = calculate_padding::().unwrap(); - let header_size = length_size.saturating_add(padding_size); - let buff_len = header_size.checked_add(capacity).unwrap(); - let mut buf = vec![0u8; buff_len]; - - // Write the length - let length = L::try_from(items.len()).unwrap(); - let length_bytes = bytemuck::bytes_of(&length); - buf[..length_size].copy_from_slice(length_bytes); - - // Copy the data after the header - let data_end = header_size.checked_add(items.len()).unwrap(); - buf[header_size..data_end].copy_from_slice(items); - buf - } - - #[test] - fn remove_at_first_item() { - let mut buff = make_buffer::(15, &[10, 20, 30, 40]); - let mut list_view = ListView::::unpack(&mut buff).unwrap(); - let removed = list_view.remove(0).unwrap(); - assert_eq!(removed, 10); - assert_eq!(list_view.len(), 3); - assert_eq!(list_view.data[..list_view.len()].to_vec(), &[20, 30, 40]); - assert_eq!(list_view.data[3], 0); - } - - #[test] - fn remove_at_middle_item() { - let mut buff = make_buffer::(15, &[10, 20, 30, 40]); - let mut list_view = ListView::::unpack(&mut buff).unwrap(); - let removed = list_view.remove(2).unwrap(); - assert_eq!(removed, 30); - assert_eq!(list_view.len(), 3); - assert_eq!(list_view.data[..list_view.len()].to_vec(), &[10, 20, 40]); - assert_eq!(list_view.data[3], 0); - } - - #[test] - fn remove_at_last_item() { - let mut buff = make_buffer::(15, &[10, 20, 30, 40]); - let mut list_view = ListView::::unpack(&mut buff).unwrap(); - let removed = list_view.remove(3).unwrap(); - assert_eq!(removed, 40); - assert_eq!(list_view.len(), 3); - assert_eq!(list_view.data[..list_view.len()].to_vec(), &[10, 20, 30]); - assert_eq!(list_view.data[3], 0); - } - - #[test] - fn remove_at_out_of_bounds() { - let mut buff = make_buffer::(3, &[1, 2, 3]); - let original_buff = buff.clone(); - - let mut list_view = ListView::::unpack(&mut buff).unwrap(); - let err = list_view.remove(3).unwrap_err(); - assert_eq!(err, ProgramError::InvalidArgument); - - // list_view should be unchanged - assert_eq!(list_view.len(), 3); - assert_eq!(list_view.data[..list_view.len()].to_vec(), vec![1, 2, 3]); - - assert_eq!(buff, original_buff); - } - - #[test] - fn remove_at_single_element() { - let mut buff = make_buffer::(1, &[10]); - let mut list_view = ListView::::unpack(&mut buff).unwrap(); - let removed = list_view.remove(0).unwrap(); - assert_eq!(removed, 10); - assert_eq!(list_view.len(), 0); - assert_eq!(list_view.data[..list_view.len()].to_vec(), &[] as &[u8]); - assert_eq!(list_view.data[0], 0); - } - - #[test] - fn remove_at_empty_slice() { - let mut buff = make_buffer::(0, &[]); - let original_buff = buff.clone(); - - let mut list_view = ListView::::unpack(&mut buff).unwrap(); - let err = list_view.remove(0).unwrap_err(); - assert_eq!(err, ProgramError::InvalidArgument); - - // Assert list state is unchanged - assert_eq!(list_view.len(), 0); - - assert_eq!(buff, original_buff); - } - - #[test] - fn test_different_length_types() { - // Test with u16 length - let mut buff16 = make_buffer::(5, &[1, 2, 3]); - let list16 = ListView::::unpack(&mut buff16).unwrap(); - assert_eq!(list16.len(), 3); - assert_eq!(list16.len(), 3); - - // Test with u32 length - let mut buff32 = make_buffer::(5, &[4, 5, 6]); - let list32 = ListView::::unpack(&mut buff32).unwrap(); - assert_eq!(list32.len(), 3); - assert_eq!(list32.len(), 3); - - // Test with u64 length - let mut buff64 = make_buffer::(5, &[7, 8, 9]); - let list64 = ListView::::unpack(&mut buff64).unwrap(); - assert_eq!(list64.len(), 3); - assert_eq!(list64.len(), 3); - } - - #[test] - fn test_calculate_padding() { - // When length and data have same alignment, no padding needed - assert_eq!(calculate_padding::().unwrap(), 0); - assert_eq!(calculate_padding::().unwrap(), 0); - assert_eq!(calculate_padding::().unwrap(), 0); - - // When data alignment is smaller than or divides length size - assert_eq!(calculate_padding::().unwrap(), 0); - assert_eq!(calculate_padding::().unwrap(), 0); - assert_eq!(calculate_padding::().unwrap(), 0); - assert_eq!(calculate_padding::().unwrap(), 0); - assert_eq!(calculate_padding::().unwrap(), 0); - - // When padding is needed - assert_eq!(calculate_padding::().unwrap(), 2); // 2 + 2 = 4 (align to 4) - assert_eq!(calculate_padding::().unwrap(), 6); // 2 + 6 = 8 (align to 8) - assert_eq!(calculate_padding::().unwrap(), 4); // 4 + 4 = 8 (align to 8) - - // Test with custom aligned structs - #[repr(C, align(8))] - #[derive(Pod, Zeroable, Copy, Clone)] - struct Align8 { - _data: [u8; 8], - } - - #[repr(C, align(16))] - #[derive(Pod, Zeroable, Copy, Clone)] - struct Align16 { - _data: [u8; 16], - } - - assert_eq!(calculate_padding::().unwrap(), 6); // 2 + 6 = 8 - assert_eq!(calculate_padding::().unwrap(), 4); // 4 + 4 = 8 - assert_eq!(calculate_padding::().unwrap(), 0); // 8 % 8 = 0 - - assert_eq!(calculate_padding::().unwrap(), 14); // 2 + 14 = 16 - assert_eq!(calculate_padding::().unwrap(), 12); // 4 + 12 = 16 - assert_eq!(calculate_padding::().unwrap(), 8); // 8 + 8 = 16 - } - - #[test] - fn test_alignment_in_practice() { - // u32 length with u64 data - needs 4 bytes padding - let size = ListView::::size_of(2).unwrap(); - let mut buffer = vec![0u8; size]; - let list = ListView::::init(&mut buffer).unwrap(); - - // Check that data pointer is 8-byte aligned - let data_ptr = list.data.as_ptr() as usize; - assert_eq!(data_ptr % 8, 0); - - // u16 length with u64 data - needs 6 bytes padding - let size = ListView::::size_of(2).unwrap(); - let mut buffer = vec![0u8; size]; - let list = ListView::::init(&mut buffer).unwrap(); - - let data_ptr = list.data.as_ptr() as usize; - assert_eq!(data_ptr % 8, 0); - } - - #[test] - fn test_length_too_large() { - // Create a buffer with capacity for 2 items - let capacity = 2; - let length_size = size_of::(); - let padding_size = calculate_padding::().unwrap(); - let header_size = length_size.saturating_add(padding_size); - let buff_len = header_size.checked_add(capacity).unwrap(); - let mut buffer = vec![0u8; buff_len]; - - // Manually write a length value that exceeds the capacity - let invalid_length = PodU32::try_from(capacity + 1).unwrap(); - let length_bytes = bytemuck::bytes_of(&invalid_length); - buffer[..length_size].copy_from_slice(length_bytes); - - // Attempting to unpack should return BufferTooSmall error - match ListView::::unpack(&mut buffer) { - Err(err) => assert_eq!(err, PodSliceError::BufferTooSmall.into()), - Ok(_) => panic!("Expected BufferTooSmall error, but unpack succeeded"), - } - } -} diff --git a/pod/src/list/list_view.rs b/pod/src/list/list_view.rs new file mode 100644 index 00000000..53be0fdc --- /dev/null +++ b/pod/src/list/list_view.rs @@ -0,0 +1,584 @@ +//! `ListView`, a compact, zero-copy array wrapper. + +use { + crate::{ + bytemuck::{ + pod_from_bytes, pod_from_bytes_mut, pod_slice_from_bytes, pod_slice_from_bytes_mut, + }, + error::PodSliceError, + list::{list_view_mut::ListViewMut, list_view_read_only::ListViewReadOnly}, + pod_length::PodLength, + primitives::PodU64, + }, + bytemuck::Pod, + solana_program_error::ProgramError, + std::{ + marker::PhantomData, + mem::{align_of, size_of}, + ops::Range, + }, +}; + +/// An API for interpreting a raw buffer (`&[u8]`) as a variable-length collection of Pod elements. +/// +/// `ListView` provides a safe, zero-copy, `Vec`-like interface for a slice of +/// `Pod` data that resides in an external, pre-allocated `&[u8]` buffer. +/// It does not own the buffer itself, but acts as a view over it, which can be +/// read-only (`ListViewReadOnly`) or mutable (`ListViewMut`). +/// +/// This is useful in environments where allocations are restricted or expensive, +/// such as Solana programs, allowing for efficient reads and manipulation of +/// dynamic-length data structures. +/// +/// ## Memory Layout +/// +/// The structure assumes the underlying byte buffer is formatted as follows: +/// 1. **Length**: A length field of type `L` at the beginning of the buffer, +/// indicating the number of currently active elements in the collection. +/// Defaults to `PodU64` so the offset is then compatible with 1, 2, 4 and 8 bytes. +/// 2. **Padding**: Optional padding bytes to ensure proper alignment of the data. +/// 3. **Data**: The remaining part of the buffer, which is treated as a slice +/// of `T` elements. The capacity of the collection is the number of `T` +/// elements that can fit into this data portion. +pub struct ListView(PhantomData<(T, L)>); + +struct Layout { + length_range: Range, + data_range: Range, +} + +impl ListView { + /// Calculate the total byte size for a `ListView` holding `num_items`. + /// This includes the length prefix, padding, and data. + pub fn size_of(num_items: usize) -> Result { + size_of::() + .checked_mul(num_items) + .and_then(|curr| curr.checked_add(size_of::())) + .and_then(|curr| curr.checked_add(Self::header_padding())) + .ok_or_else(|| PodSliceError::CalculationFailure.into()) + } + + /// Unpack a read-only buffer into a `ListViewReadOnly` + pub fn unpack(buf: &[u8]) -> Result, ProgramError> { + let layout = Self::calculate_layout(buf.len())?; + + // Slice the buffer to get the length prefix and the data. + // The layout calculation provides the correct ranges, accounting for any + // padding between the length and the data. + // + // buf: [ L L L L | P P | D D D D D D D D ...] + // <-----> <------------------> + // len_bytes data_bytes + let len_bytes = &buf[layout.length_range]; + let data_bytes = &buf[layout.data_range]; + + let length = pod_from_bytes::(len_bytes)?; + let data = pod_slice_from_bytes::(data_bytes)?; + let capacity = data.len(); + + if (*length).into() > capacity { + return Err(PodSliceError::BufferTooSmall.into()); + } + + Ok(ListViewReadOnly { + length, + data, + capacity, + }) + } + + /// Unpack the mutable buffer into a mutable `ListViewMut` + pub fn unpack_mut(buf: &mut [u8]) -> Result, ProgramError> { + let view = Self::build_mut_view(buf)?; + if (*view.length).into() > view.capacity { + return Err(PodSliceError::BufferTooSmall.into()); + } + Ok(view) + } + + /// Initialize a buffer: sets `length = 0` and returns a mutable `ListViewMut`. + pub fn init(buf: &mut [u8]) -> Result, ProgramError> { + let view = Self::build_mut_view(buf)?; + *view.length = L::try_from(0)?; + Ok(view) + } + + /// Internal helper to build a mutable view without validation or initialization. + #[inline] + fn build_mut_view(buf: &mut [u8]) -> Result, ProgramError> { + let layout = Self::calculate_layout(buf.len())?; + + // Split the buffer to get the length prefix and the data. + // buf: [ L L L L | P P | D D D D D D D D ...] + // <---- head ---> <--- tail ---------> + let (header_bytes, data_bytes) = buf.split_at_mut(layout.data_range.start); + // header: [ L L L L | P P ] + // <-----> + // len_bytes + let len_bytes = &mut header_bytes[layout.length_range]; + + // Cast the bytes to typed data + let length = pod_from_bytes_mut::(len_bytes)?; + let data = pod_slice_from_bytes_mut::(data_bytes)?; + let capacity = data.len(); + + Ok(ListViewMut { + length, + data, + capacity, + }) + } + + /// Calculate the byte ranges for the length and data sections of the buffer + #[inline] + fn calculate_layout(buf_len: usize) -> Result { + let len_field_end = size_of::(); + let data_start = len_field_end.saturating_add(Self::header_padding()); + + if buf_len < data_start { + return Err(PodSliceError::BufferTooSmall.into()); + } + + Ok(Layout { + length_range: 0..len_field_end, + data_range: data_start..buf_len, + }) + } + + /// Calculate the padding required to align the data part of the buffer. + /// + /// The goal is to ensure that the data field `T` starts at a memory offset + /// that is a multiple of its alignment requirement. + #[inline] + const fn header_padding() -> usize { + let length_size = size_of::(); + let data_align = align_of::(); + + // No padding is needed for alignments of 0 or 1 + if data_align == 0 || data_align == 1 { + return 0; + } + + // Find how many bytes `length_size` extends past an alignment boundary + #[allow(clippy::arithmetic_side_effects)] + let remainder = length_size.wrapping_rem(data_align); + + // If already aligned (remainder is 0), no padding is needed. + // Otherwise, calculate the distance to the next alignment boundary. + if remainder == 0 { + 0 + } else { + data_align.wrapping_sub(remainder) + } + } +} + +#[cfg(test)] +mod tests { + use { + super::*, + crate::{ + list::ListViewable, + primitives::{PodU16, PodU32}, + }, + bytemuck_derive::{Pod as DerivePod, Zeroable}, + }; + + #[test] + fn test_size_of_no_padding() { + // Case 1: T has align 1, so no padding is ever needed. + // 10 items * 1 byte/item + 4 bytes for length = 14 + assert_eq!(ListView::::size_of(10).unwrap(), 14); + + // Case 2: size_of is a multiple of align_of, so no padding needed. + // T = u32 (size 4, align 4), L = PodU64 (size 8). 8 % 4 == 0. + // 10 items * 4 bytes/item + 8 bytes for length = 48 + assert_eq!(ListView::::size_of(10).unwrap(), 48); + + // Case 3: 0 items. Size should just be size_of + padding. + // Padding is 0 here. + // 0 items * 4 bytes/item + 8 bytes for length = 8 + assert_eq!(ListView::::size_of(0).unwrap(), 8); + } + + #[test] + fn test_size_of_with_padding() { + // Case 1: Padding is required. + // T = u64 (size 8, align 8), L = PodU32 (size 4). + // Padding required to align data to 8 bytes is 4. (4 + 4 = 8) + // (10 items * 8 bytes/item) + 4 bytes for length + 4 bytes for padding = 88 + assert_eq!(ListView::::size_of(10).unwrap(), 88); + + #[repr(C, align(16))] + #[derive(DerivePod, Zeroable, Copy, Clone)] + struct Align16(u128); + + // Case 2: Custom struct with high alignment. + // size 16, align 16 + // L = PodU64 (size 8). + // Padding required to align data to 16 bytes is 8. (8 + 8 = 16) + // (10 items * 16 bytes/item) + 8 bytes for length + 8 bytes for padding = 176 + assert_eq!(ListView::::size_of(10).unwrap(), 176); + + // Case 3: 0 items with padding. + // Size should be size_of + padding. + // L = PodU32 (size 4), T = u64 (align 8). Padding is 4. + // Total size = 4 + 4 = 8 + assert_eq!(ListView::::size_of(0).unwrap(), 8); + } + + #[test] + fn test_size_of_overflow() { + // Case 1: Multiplication overflows. + // `size_of::() * usize::MAX` will overflow. + let err = ListView::::size_of(usize::MAX).unwrap_err(); + assert_eq!(err, PodSliceError::CalculationFailure.into()); + + // Case 2: Multiplication does not overflow, but subsequent addition does. + // `size_of::() * usize::MAX` does not overflow, but adding `size_of` will. + let err = ListView::::size_of(usize::MAX).unwrap_err(); + assert_eq!(err, PodSliceError::CalculationFailure.into()); + } + + #[test] + fn test_padding_calculation() { + // `u8` has an alignment of 1, so no padding is ever needed. + assert_eq!(ListView::::header_padding(), 0); + + // Zero-Sized Types like `()` have size 0 and align 1, requiring no padding. + assert_eq!(ListView::<(), PodU64>::header_padding(), 0); + + // When length and data have the same alignment. + assert_eq!(ListView::::header_padding(), 0); + assert_eq!(ListView::::header_padding(), 0); + assert_eq!(ListView::::header_padding(), 0); + + // When data alignment is smaller than or perfectly divides the length size. + assert_eq!(ListView::::header_padding(), 0); // 8 % 2 = 0 + assert_eq!(ListView::::header_padding(), 0); // 8 % 4 = 0 + + // When padding IS needed. + assert_eq!(ListView::::header_padding(), 2); // size_of is 2. To align to 4, need 2 bytes. + assert_eq!(ListView::::header_padding(), 6); // size_of is 2. To align to 8, need 6 bytes. + assert_eq!(ListView::::header_padding(), 4); // size_of is 4. To align to 8, need 4 bytes. + + // Test with custom, higher alignments. + #[repr(C, align(8))] + #[derive(DerivePod, Zeroable, Copy, Clone)] + struct Align8(u64); + + // Test against different length types + assert_eq!(ListView::::header_padding(), 6); // 2 + 6 = 8 + assert_eq!(ListView::::header_padding(), 4); // 4 + 4 = 8 + assert_eq!(ListView::::header_padding(), 0); // 8 is already aligned + + #[repr(C, align(16))] + #[derive(DerivePod, Zeroable, Copy, Clone)] + struct Align16(u128); + + assert_eq!(ListView::::header_padding(), 14); // 2 + 14 = 16 + assert_eq!(ListView::::header_padding(), 12); // 4 + 12 = 16 + assert_eq!(ListView::::header_padding(), 8); // 8 + 8 = 16 + } + + #[test] + fn test_unpack_success_no_padding() { + // T = u32 (align 4), L = PodU32 (size 4, align 4). No padding needed. + let length: u32 = 2; + let capacity: usize = 3; + let item_size = size_of::(); + let len_size = size_of::(); + let buf_size = len_size + capacity * item_size; + let mut buf = vec![0u8; buf_size]; + + let pod_len: PodU32 = length.into(); + buf[0..len_size].copy_from_slice(bytemuck::bytes_of(&pod_len)); + + let data_start = len_size; + let items = [100u32, 200u32]; + let items_bytes = bytemuck::cast_slice(&items); + buf[data_start..(data_start + items_bytes.len())].copy_from_slice(items_bytes); + + let view_ro = ListView::::unpack(&buf).unwrap(); + assert_eq!(view_ro.len(), length as usize); + assert_eq!(view_ro.capacity(), capacity); + assert_eq!(view_ro.as_slice(), &items[..]); + + let view_mut = ListView::::unpack_mut(&mut buf).unwrap(); + assert_eq!(view_mut.len(), length as usize); + assert_eq!(view_mut.capacity(), capacity); + assert_eq!(view_mut.as_slice(), &items[..]); + } + + #[test] + fn test_unpack_success_with_padding() { + // T = u64 (align 8), L = PodU32 (size 4, align 4). Needs 4 bytes padding. + let padding = ListView::::header_padding(); + assert_eq!(padding, 4); + + let length: u32 = 2; + let capacity: usize = 2; + let item_size = size_of::(); + let len_size = size_of::(); + let buf_size = len_size + padding + capacity * item_size; + let mut buf = vec![0u8; buf_size]; + + let pod_len: PodU32 = length.into(); + buf[0..len_size].copy_from_slice(bytemuck::bytes_of(&pod_len)); + + // Data starts after length and padding + let data_start = len_size + padding; + let items = [100u64, 200u64]; + let items_bytes = bytemuck::cast_slice(&items); + buf[data_start..(data_start + items_bytes.len())].copy_from_slice(items_bytes); + + let view_ro = ListView::::unpack(&buf).unwrap(); + assert_eq!(view_ro.len(), length as usize); + assert_eq!(view_ro.capacity(), capacity); + assert_eq!(view_ro.as_slice(), &items[..]); + + let view_mut = ListView::::unpack_mut(&mut buf).unwrap(); + assert_eq!(view_mut.len(), length as usize); + assert_eq!(view_mut.capacity(), capacity); + assert_eq!(view_mut.as_slice(), &items[..]); + } + + #[test] + fn test_unpack_success_zero_length() { + let capacity: usize = 5; + let item_size = size_of::(); + let len_size = size_of::(); + let buf_size = len_size + capacity * item_size; + let mut buf = vec![0u8; buf_size]; + + let pod_len: PodU32 = 0u32.into(); + buf[0..len_size].copy_from_slice(bytemuck::bytes_of(&pod_len)); + + let view_ro = ListView::::unpack(&buf).unwrap(); + assert_eq!(view_ro.len(), 0); + assert_eq!(view_ro.capacity(), capacity); + assert!(view_ro.is_empty()); + assert_eq!(view_ro.as_slice(), &[] as &[u32]); + + let view_mut = ListView::::unpack_mut(&mut buf).unwrap(); + assert_eq!(view_mut.len(), 0); + assert_eq!(view_mut.capacity(), capacity); + assert!(view_mut.is_empty()); + assert_eq!(view_mut.as_slice(), &[] as &[u32]); + } + + #[test] + fn test_unpack_success_full_capacity() { + let length: u64 = 3; + let capacity: usize = 3; + let item_size = size_of::(); + let len_size = size_of::(); + let buf_size = len_size + capacity * item_size; + let mut buf = vec![0u8; buf_size]; + + let pod_len: PodU64 = length.into(); + buf[0..len_size].copy_from_slice(bytemuck::bytes_of(&pod_len)); + + let data_start = len_size; + let items = [1u64, 2u64, 3u64]; + let items_bytes = bytemuck::cast_slice(&items); + buf[data_start..].copy_from_slice(items_bytes); + + let view_ro = ListView::::unpack(&buf).unwrap(); + assert_eq!(view_ro.len(), length as usize); + assert_eq!(view_ro.capacity(), capacity); + assert_eq!(view_ro.as_slice(), &items[..]); + + let view_mut = ListView::::unpack_mut(&mut buf).unwrap(); + assert_eq!(view_mut.len(), length as usize); + assert_eq!(view_mut.capacity(), capacity); + assert_eq!(view_mut.as_slice(), &items[..]); + } + + #[test] + fn test_unpack_success_different_length_type() { + // T = u64 (align 8), L = PodU16 (size 2, align 2). Needs 6 bytes padding. + let padding = ListView::::header_padding(); + assert_eq!(padding, 6); + + let length: u16 = 1; + let capacity: usize = 1; + let item_size = size_of::(); + let len_size = size_of::(); + let buf_size = len_size + padding + capacity * item_size; + let mut buf = vec![0u8; buf_size]; + + let pod_len: PodU16 = length.into(); + buf[0..len_size].copy_from_slice(bytemuck::bytes_of(&pod_len)); + + let data_start = len_size + padding; + let items = [12345u64]; + let items_bytes = bytemuck::cast_slice(&items); + buf[data_start..].copy_from_slice(items_bytes); + + let view_ro = ListView::::unpack(&buf).unwrap(); + assert_eq!(view_ro.len(), length as usize); + assert_eq!(view_ro.capacity(), capacity); + assert_eq!(view_ro.as_slice(), &items[..]); + + let view_mut = ListView::::unpack_mut(&mut buf).unwrap(); + assert_eq!(view_mut.len(), length as usize); + assert_eq!(view_mut.capacity(), capacity); + assert_eq!(view_mut.as_slice(), &items[..]); + } + + #[test] + fn test_unpack_fail_buffer_too_small_for_header() { + // T = u64 (align 8), L = PodU32 (size 4). Header size is 8. + let header_size = ListView::::size_of(0).unwrap(); + assert_eq!(header_size, 8); + + // Provide a buffer smaller than the required header + let mut buf = vec![0u8; header_size - 1]; // 7 bytes + + let err = ListView::::unpack(&buf).unwrap_err(); + assert_eq!(err, PodSliceError::BufferTooSmall.into()); + + let err = ListView::::unpack_mut(&mut buf).unwrap_err(); + assert_eq!(err, PodSliceError::BufferTooSmall.into()); + } + + #[test] + fn test_unpack_fail_declared_length_exceeds_capacity() { + let declared_length: u32 = 4; + let capacity: usize = 3; // buffer can only hold 3 + let item_size = size_of::(); + let len_size = size_of::(); + let buf_size = len_size + capacity * item_size; + let mut buf = vec![0u8; buf_size]; + + // Write a length that is bigger than capacity + let pod_len: PodU32 = declared_length.into(); + buf[0..len_size].copy_from_slice(bytemuck::bytes_of(&pod_len)); + + let err = ListView::::unpack(&buf).unwrap_err(); + assert_eq!(err, PodSliceError::BufferTooSmall.into()); + + let err = ListView::::unpack_mut(&mut buf).unwrap_err(); + assert_eq!(err, PodSliceError::BufferTooSmall.into()); + } + + #[test] + fn test_unpack_fail_data_part_not_multiple_of_item_size() { + let len_size = size_of::(); + + // data part is 5 bytes, not a multiple of item_size (4) + let buf_size = len_size + 5; + let mut buf = vec![0u8; buf_size]; + + // bytemuck::try_cast_slice returns an alignment error, which we map to InvalidArgument + + let err = ListView::::unpack(&buf).unwrap_err(); + assert_eq!(err, ProgramError::InvalidArgument); + + let err = ListView::::unpack_mut(&mut buf).unwrap_err(); + assert_eq!(err, ProgramError::InvalidArgument); + } + + #[test] + fn test_unpack_empty_buffer() { + let mut buf = []; + let err = ListView::::unpack(&buf).unwrap_err(); + assert_eq!(err, PodSliceError::BufferTooSmall.into()); + + let err = ListView::::unpack_mut(&mut buf).unwrap_err(); + assert_eq!(err, PodSliceError::BufferTooSmall.into()); + } + + #[test] + fn test_init_success_no_padding() { + // T = u32 (align 4), L = PodU32 (size 4). No padding needed. + let capacity: usize = 5; + let len_size = size_of::(); + let buf_size = ListView::::size_of(capacity).unwrap(); + let mut buf = vec![0xFFu8; buf_size]; // Pre-fill to ensure init zeroes it + + let view = ListView::::init(&mut buf).unwrap(); + + assert_eq!(view.len(), 0); + assert_eq!(view.capacity(), capacity); + assert!(view.is_empty()); + + // Check that the underlying buffer's length was actually zeroed + let length_bytes = &buf[0..len_size]; + assert_eq!(length_bytes, &[0u8; 4]); + } + + #[test] + fn test_init_success_with_padding() { + // T = u64 (align 8), L = PodU32 (size 4). Needs 4 bytes padding. + let capacity: usize = 3; + let len_size = size_of::(); + let buf_size = ListView::::size_of(capacity).unwrap(); + let mut buf = vec![0xFFu8; buf_size]; // Pre-fill to ensure init zeroes it + + let view = ListView::::init(&mut buf).unwrap(); + + assert_eq!(view.len(), 0); + assert_eq!(view.capacity(), capacity); + assert!(view.is_empty()); + + // Check that the underlying buffer's length was actually zeroed + let length_bytes = &buf[0..len_size]; + assert_eq!(length_bytes, &[0u8; 4]); + // The padding bytes may or may not be zeroed, we don't assert on them. + } + + #[test] + fn test_init_success_zero_capacity() { + // Test initializing a buffer that can only hold the header. + // T = u64 (align 8), L = PodU32 (size 4). Header size is 8. + let buf_size = ListView::::size_of(0).unwrap(); + assert_eq!(buf_size, 8); + let mut buf = vec![0xFFu8; buf_size]; + + let view = ListView::::init(&mut buf).unwrap(); + + assert_eq!(view.len(), 0); + assert_eq!(view.capacity(), 0); + assert!(view.is_empty()); + + // Check that the underlying buffer's length was actually zeroed + let len_size = size_of::(); + let length_bytes = &buf[0..len_size]; + assert_eq!(length_bytes, &[0u8; 4]); + } + + #[test] + fn test_init_fail_buffer_too_small() { + // Header requires 4 bytes (size_of) + let mut buf = vec![0u8; 3]; + let err = ListView::::init(&mut buf).unwrap_err(); + assert_eq!(err, PodSliceError::BufferTooSmall.into()); + + // With padding, header requires 8 bytes (4 for len, 4 for pad) + let mut buf_padded = vec![0u8; 7]; + let err_padded = ListView::::init(&mut buf_padded).unwrap_err(); + assert_eq!(err_padded, PodSliceError::BufferTooSmall.into()); + } + + #[test] + fn test_init_success_default_length_type() { + // This test uses the default L=PodU64 length type by omitting it. + // T = u32 (align 4), L = PodU64 (size 8). No padding needed as 8 % 4 == 0. + let capacity = 5; + let len_size = size_of::(); // Default L is PodU64 + let buf_size = ListView::::size_of(capacity).unwrap(); + let mut buf = vec![0xFFu8; buf_size]; // Pre-fill to ensure init zeroes it + + let view = ListView::::init(&mut buf).unwrap(); + + assert_eq!(view.len(), 0); + assert_eq!(view.capacity(), capacity); + assert!(view.is_empty()); + + // Check that the underlying buffer's length (a u64) was actually zeroed + let length_bytes = &buf[0..len_size]; + assert_eq!(length_bytes, &[0u8; 8]); + } +} diff --git a/pod/src/list/list_view_mut.rs b/pod/src/list/list_view_mut.rs new file mode 100644 index 00000000..8980ea32 --- /dev/null +++ b/pod/src/list/list_view_mut.rs @@ -0,0 +1,306 @@ +//! `ListViewMut`, a mutable, compact, zero-copy array wrapper. + +use { + crate::{ + error::PodSliceError, list::list_viewable::ListViewable, pod_length::PodLength, + primitives::PodU64, + }, + bytemuck::Pod, + solana_program_error::ProgramError, +}; + +#[derive(Debug)] +pub struct ListViewMut<'data, T: Pod, L: PodLength = PodU64> { + pub(crate) length: &'data mut L, + pub(crate) data: &'data mut [T], + pub(crate) capacity: usize, +} + +impl ListViewMut<'_, T, L> { + /// Add another item to the slice + pub fn push(&mut self, item: T) -> Result<(), ProgramError> { + let length = (*self.length).into(); + if length >= self.capacity { + Err(PodSliceError::BufferTooSmall.into()) + } else { + self.data[length] = item; + *self.length = L::try_from(length.saturating_add(1))?; + Ok(()) + } + } + + /// Remove and return the element at `index`, shifting all later + /// elements one position to the left. + pub fn remove(&mut self, index: usize) -> Result { + let len = (*self.length).into(); + if index >= len { + return Err(ProgramError::InvalidArgument); + } + + let removed_item = self.data[index]; + + // Move the tail left by one + let tail_start = index + .checked_add(1) + .ok_or(ProgramError::ArithmeticOverflow)?; + self.data.copy_within(tail_start..len, index); + + // Store the new length (len - 1) + let last = len.saturating_sub(1); + *self.length = L::try_from(last)?; + + Ok(removed_item) + } + + /// Returns a mutable iterator over the current elements + pub fn iter_mut(&mut self) -> std::slice::IterMut { + let len = (*self.length).into(); + self.data[..len].iter_mut() + } +} + +impl ListViewable for ListViewMut<'_, T, L> { + type Item = T; + + fn len(&self) -> usize { + (*self.length).into() + } + + fn capacity(&self) -> usize { + self.capacity + } + + fn as_slice(&self) -> &[Self::Item] { + &self.data[..self.len()] + } +} + +#[cfg(test)] +mod tests { + use { + super::*, + crate::{ + list::{ListView, ListViewable}, + primitives::{PodU32, PodU64}, + }, + bytemuck_derive::{Pod, Zeroable}, + }; + + #[repr(C)] + #[derive(Clone, Copy, Debug, Default, PartialEq, Pod, Zeroable)] + struct TestStruct { + a: u64, + b: u32, + _padding: [u8; 4], + } + + impl TestStruct { + fn new(a: u64, b: u32) -> Self { + Self { + a, + b, + _padding: [0; 4], + } + } + } + + fn init_view_mut( + buffer: &mut Vec, + capacity: usize, + ) -> ListViewMut { + let size = ListView::::size_of(capacity).unwrap(); + buffer.resize(size, 0); + ListView::::init(buffer).unwrap() + } + + #[test] + fn test_push() { + let mut buffer = vec![]; + let mut view = init_view_mut::(&mut buffer, 3); + + assert_eq!(view.len(), 0); + assert!(view.is_empty()); + assert_eq!(view.capacity(), 3); + + // Push first item + let item1 = TestStruct::new(1, 10); + view.push(item1).unwrap(); + assert_eq!(view.len(), 1); + assert!(!view.is_empty()); + assert_eq!(view.as_slice(), &[item1]); + + // Push second item + let item2 = TestStruct::new(2, 20); + view.push(item2).unwrap(); + assert_eq!(view.len(), 2); + assert_eq!(view.as_slice(), &[item1, item2]); + + // Push third item to fill capacity + let item3 = TestStruct::new(3, 30); + view.push(item3).unwrap(); + assert_eq!(view.len(), 3); + assert_eq!(view.as_slice(), &[item1, item2, item3]); + + // Try to push beyond capacity + let item4 = TestStruct::new(4, 40); + let err = view.push(item4).unwrap_err(); + assert_eq!(err, PodSliceError::BufferTooSmall.into()); + + // Ensure state is unchanged + assert_eq!(view.len(), 3); + assert_eq!(view.as_slice(), &[item1, item2, item3]); + } + + #[test] + fn test_remove() { + let mut buffer = vec![]; + let mut view = init_view_mut::(&mut buffer, 4); + + let item1 = TestStruct::new(1, 10); + let item2 = TestStruct::new(2, 20); + let item3 = TestStruct::new(3, 30); + let item4 = TestStruct::new(4, 40); + view.push(item1).unwrap(); + view.push(item2).unwrap(); + view.push(item3).unwrap(); + view.push(item4).unwrap(); + + assert_eq!(view.len(), 4); + assert_eq!(view.as_slice(), &[item1, item2, item3, item4]); + + // Remove from the middle + let removed = view.remove(1).unwrap(); + assert_eq!(removed, item2); + assert_eq!(view.len(), 3); + assert_eq!(view.as_slice(), &[item1, item3, item4]); + + // Remove from the end + let removed = view.remove(2).unwrap(); + assert_eq!(removed, item4); + assert_eq!(view.len(), 2); + assert_eq!(view.as_slice(), &[item1, item3]); + + // Remove from the start + let removed = view.remove(0).unwrap(); + assert_eq!(removed, item1); + assert_eq!(view.len(), 1); + assert_eq!(view.as_slice(), &[item3]); + + // Remove the last element + let removed = view.remove(0).unwrap(); + assert_eq!(removed, item3); + assert_eq!(view.len(), 0); + assert!(view.is_empty()); + assert_eq!(view.as_slice(), &[]); + } + + #[test] + fn test_remove_out_of_bounds() { + let mut buffer = vec![]; + let mut view = init_view_mut::(&mut buffer, 2); + + view.push(TestStruct::new(1, 10)).unwrap(); + view.push(TestStruct::new(2, 20)).unwrap(); + + // Try to remove at index == len + let err = view.remove(2).unwrap_err(); + assert_eq!(err, ProgramError::InvalidArgument); + assert_eq!(view.len(), 2); // Unchanged + + // Try to remove at index > len + let err = view.remove(100).unwrap_err(); + assert_eq!(err, ProgramError::InvalidArgument); + assert_eq!(view.len(), 2); // Unchanged + + // Empty the view + view.remove(1).unwrap(); + view.remove(0).unwrap(); + assert!(view.is_empty()); + + // Try to remove from empty view + let err = view.remove(0).unwrap_err(); + assert_eq!(err, ProgramError::InvalidArgument); + } + + #[test] + fn test_iter_mut() { + let mut buffer = vec![]; + let mut view = init_view_mut::(&mut buffer, 4); + + let item1 = TestStruct::new(1, 10); + let item2 = TestStruct::new(2, 20); + let item3 = TestStruct::new(3, 30); + view.push(item1).unwrap(); + view.push(item2).unwrap(); + view.push(item3).unwrap(); + + assert_eq!(view.len(), 3); + assert_eq!(view.capacity(), 4); + + // Modify items using iter_mut + for item in view.iter_mut() { + item.a *= 10; + } + + let expected_item1 = TestStruct::new(10, 10); + let expected_item2 = TestStruct::new(20, 20); + let expected_item3 = TestStruct::new(30, 30); + + // Check that the underlying data is modified + assert_eq!(view.len(), 3); + assert_eq!( + view.as_slice(), + &[expected_item1, expected_item2, expected_item3] + ); + + // Check that iter_mut only iterates over `len` items, not `capacity` + assert_eq!(view.iter_mut().count(), 3); + } + + #[test] + fn test_iter_mut_empty() { + let mut buffer = vec![]; + let mut view = init_view_mut::(&mut buffer, 5); + + let mut count = 0; + for _ in view.iter_mut() { + count += 1; + } + assert_eq!(count, 0); + assert_eq!(view.iter_mut().next(), None); + } + + #[test] + fn test_zero_capacity() { + let mut buffer = vec![]; + let mut view = init_view_mut::(&mut buffer, 0); + + assert_eq!(view.len(), 0); + assert_eq!(view.capacity(), 0); + assert!(view.is_empty()); + + let err = view.push(TestStruct::new(1, 1)).unwrap_err(); + assert_eq!(err, PodSliceError::BufferTooSmall.into()); + + let err = view.remove(0).unwrap_err(); + assert_eq!(err, ProgramError::InvalidArgument); + } + + #[test] + fn test_default_length_type() { + let capacity = 2; + let mut buffer = vec![]; + let size = ListView::::size_of(capacity).unwrap(); + buffer.resize(size, 0); + + // Initialize the view *without* specifying L. The compiler uses the default. + let view = ListView::::init(&mut buffer).unwrap(); + + // Check that the capacity is correct for a PodU64 length. + assert_eq!(view.capacity(), capacity); + assert_eq!(view.len(), 0); + + // Verify the size of the length field. + assert_eq!(size_of_val(view.length), size_of::()); + } +} diff --git a/pod/src/list/list_view_read_only.rs b/pod/src/list/list_view_read_only.rs new file mode 100644 index 00000000..6fce0bde --- /dev/null +++ b/pod/src/list/list_view_read_only.rs @@ -0,0 +1,156 @@ +//! `ListViewReadOnly`, a read-only, compact, zero-copy array wrapper. + +use { + crate::{list::list_viewable::ListViewable, pod_length::PodLength, primitives::PodU64}, + bytemuck::Pod, +}; + +#[derive(Debug)] +pub struct ListViewReadOnly<'data, T: Pod, L: PodLength = PodU64> { + pub(crate) length: &'data L, + pub(crate) data: &'data [T], + pub(crate) capacity: usize, +} + +impl ListViewable for ListViewReadOnly<'_, T, L> { + type Item = T; + + fn len(&self) -> usize { + (*self.length).into() + } + + fn capacity(&self) -> usize { + self.capacity + } + + fn as_slice(&self) -> &[Self::Item] { + &self.data[..self.len()] + } +} + +#[cfg(test)] +mod tests { + use { + super::*, + crate::{list::ListView, pod_length::PodLength, primitives::PodU32}, + bytemuck_derive::{Pod as DerivePod, Zeroable}, + std::mem::size_of, + }; + + #[repr(C, align(16))] + #[derive(DerivePod, Zeroable, Copy, Clone, Debug, PartialEq)] + struct TestStruct(u128); + + /// Helper to build a byte buffer that conforms to the `ListView` layout. + fn build_test_buffer( + length: usize, + capacity: usize, + items: &[T], + ) -> Vec { + let size = ListView::::size_of(capacity).unwrap(); + let mut buffer = vec![0u8; size]; + + // Write the length prefix + let pod_len = L::try_from(length).unwrap(); + let len_bytes = bytemuck::bytes_of(&pod_len); + buffer[0..size_of::()].copy_from_slice(len_bytes); + + // Write the data items, accounting for padding + if !items.is_empty() { + let data_start = ListView::::size_of(0).unwrap(); + let items_bytes = bytemuck::cast_slice(items); + buffer[data_start..data_start.saturating_add(items_bytes.len())] + .copy_from_slice(items_bytes); + } + + buffer + } + + #[test] + fn test_len_and_capacity() { + let items = [10u32, 20, 30]; + let buffer = build_test_buffer::(items.len(), 5, &items); + let view = ListView::::unpack(&buffer).unwrap(); + + assert_eq!(view.len(), 3); + assert_eq!(view.capacity(), 5); + } + + #[test] + fn test_as_slice() { + let items = [10u32, 20, 30]; + // Buffer has capacity for 5, but we only use 3. + let buffer = build_test_buffer::(items.len(), 5, &items); + let view = ListView::::unpack(&buffer).unwrap(); + + // `as_slice()` should only return the first `len` items. + assert_eq!(view.as_slice(), &items[..]); + assert_eq!(view.as_slice().len(), view.len()); + } + + #[test] + fn test_is_empty() { + // Not empty + let buffer_full = build_test_buffer::(1, 2, &[10]); + let view_full = ListView::::unpack(&buffer_full).unwrap(); + assert!(!view_full.is_empty()); + + // Empty + let buffer_empty = build_test_buffer::(0, 2, &[]); + let view_empty = ListView::::unpack(&buffer_empty).unwrap(); + assert!(view_empty.is_empty()); + } + + #[test] + fn test_iter() { + let items = [TestStruct(1), TestStruct(2)]; + let buffer = build_test_buffer::(items.len(), 3, &items); + let view = ListView::::unpack(&buffer).unwrap(); + + let mut iter = view.iter(); + assert_eq!(iter.next(), Some(&items[0])); + assert_eq!(iter.next(), Some(&items[1])); + assert_eq!(iter.next(), None); + let collected: Vec<_> = view.iter().collect(); + assert_eq!(collected, vec![&items[0], &items[1]]); + } + + #[test] + fn test_iter_on_empty_list() { + let buffer = build_test_buffer::(0, 5, &[]); + let view = ListView::::unpack(&buffer).unwrap(); + + assert_eq!(view.iter().count(), 0); + assert_eq!(view.iter().next(), None); + } + + #[test] + fn test_zero_capacity() { + // Buffer is just big enough for the header (len + padding), no data. + let buffer = build_test_buffer::(0, 0, &[]); + let view = ListView::::unpack(&buffer).unwrap(); + + assert_eq!(view.len(), 0); + assert_eq!(view.capacity(), 0); + assert!(view.is_empty()); + assert_eq!(view.as_slice(), &[]); + } + + #[test] + fn test_with_padding() { + // Test the effect of padding by checking the total header size. + // T=AlignedStruct (align 16), L=PodU32 (size 4). + // The header size should be 16 (4 for len + 12 for padding). + let header_size = ListView::::size_of(0).unwrap(); + assert_eq!(header_size, 16); + + let items = [TestStruct(123), TestStruct(456)]; + let buffer = build_test_buffer::(items.len(), 4, &items); + let view = ListView::::unpack(&buffer).unwrap(); + + // Check if the public API works as expected despite internal padding + assert_eq!(view.len(), 2); + assert_eq!(view.capacity(), 4); + assert_eq!(view.as_slice(), &items[..]); + } +} diff --git a/pod/src/list/list_viewable.rs b/pod/src/list/list_viewable.rs new file mode 100644 index 00000000..b3144a45 --- /dev/null +++ b/pod/src/list/list_viewable.rs @@ -0,0 +1,27 @@ +use {bytemuck::Pod, std::slice::Iter}; + +/// A trait to abstract the shared, read-only behavior +/// between `ListViewReadOnly` and `ListViewMut`. +pub trait ListViewable { + /// The type of the items stored in the list. + type Item: Pod; + + /// Returns the number of items in the list. + fn len(&self) -> usize; + + /// Returns `true` if the list contains no items. + fn is_empty(&self) -> bool { + self.len() == 0 + } + + /// Returns the total number of items that can be stored in the list. + fn capacity(&self) -> usize; + + /// Returns a read-only slice of the items currently in the list. + fn as_slice(&self) -> &[Self::Item]; + + /// Returns a read-only iterator over the list. + fn iter(&self) -> Iter<'_, Self::Item> { + self.as_slice().iter() + } +} diff --git a/pod/src/list/mod.rs b/pod/src/list/mod.rs new file mode 100644 index 00000000..47463684 --- /dev/null +++ b/pod/src/list/mod.rs @@ -0,0 +1,9 @@ +mod list_view; +mod list_view_mut; +mod list_view_read_only; +mod list_viewable; + +pub use { + list_view::ListView, list_view_mut::ListViewMut, list_view_read_only::ListViewReadOnly, + list_viewable::ListViewable, +}; diff --git a/pod/src/pod_length.rs b/pod/src/pod_length.rs index 52e4a7eb..bd418dc9 100644 --- a/pod/src/pod_length.rs +++ b/pod/src/pod_length.rs @@ -9,46 +9,31 @@ use { /// Marker trait for converting to/from Pod `uint`'s and `usize` pub trait PodLength: Pod + Into + TryFrom {} +/// Blanket implementation to automatically implement `PodLength` for any type +/// that satisfies the required bounds. impl PodLength for T where T: Pod + Into + TryFrom {} -impl TryFrom for PodU16 { - type Error = PodSliceError; - - fn try_from(val: usize) -> Result { - Ok(u16::try_from(val)?.into()) - } -} - -impl From for usize { - fn from(pod: PodU16) -> Self { - u16::from(pod) as usize - } -} - -impl TryFrom for PodU32 { - type Error = PodSliceError; - - fn try_from(val: usize) -> Result { - Ok(u32::try_from(val)?.into()) - } +/// Implements the `TryFrom` and `From for usize` conversions for a Pod integer type +macro_rules! impl_pod_length_for { + ($PodType:ty, $PrimitiveType:ty) => { + impl TryFrom for $PodType { + type Error = PodSliceError; + + fn try_from(val: usize) -> Result { + let primitive_val = <$PrimitiveType>::try_from(val)?; + Ok(primitive_val.into()) + } + } + + impl From<$PodType> for usize { + fn from(pod_val: $PodType) -> Self { + let primitive_val = <$PrimitiveType>::from(pod_val); + primitive_val as usize + } + } + }; } -impl From for usize { - fn from(pod: PodU32) -> Self { - u32::from(pod) as usize - } -} - -impl TryFrom for PodU64 { - type Error = PodSliceError; - - fn try_from(val: usize) -> Result { - Ok(u64::try_from(val)?.into()) - } -} - -impl From for usize { - fn from(pod: PodU64) -> Self { - u64::from(pod) as usize - } -} +impl_pod_length_for!(PodU16, u16); +impl_pod_length_for!(PodU32, u32); +impl_pod_length_for!(PodU64, u64); diff --git a/pod/src/slice.rs b/pod/src/slice.rs index 34433466..8f138d35 100644 --- a/pod/src/slice.rs +++ b/pod/src/slice.rs @@ -2,49 +2,45 @@ use { crate::{ - bytemuck::{pod_from_bytes, pod_slice_from_bytes}, - error::PodSliceError, - list::ListView, + list::{ListView, ListViewMut, ListViewReadOnly, ListViewable}, primitives::PodU32, }, bytemuck::Pod, solana_program_error::ProgramError, }; -const LENGTH_SIZE: usize = std::mem::size_of::(); +#[deprecated( + since = "0.6.0", + note = "This struct will be removed in the next major release (1.0.0). \ + Please use `ListView` instead. If using with existing data initialized by PodSlice, \ + you need to specify PodU32 length (e.g. ListView::::unpack(bytes))" +)] /// Special type for using a slice of `Pod`s in a zero-copy way +#[allow(deprecated)] pub struct PodSlice<'data, T: Pod> { - length: &'data PodU32, - data: &'data [T], + inner: ListViewReadOnly<'data, T, PodU32>, } + +#[allow(deprecated)] impl<'data, T: Pod> PodSlice<'data, T> { /// Unpack the buffer into a slice pub fn unpack<'a>(data: &'a [u8]) -> Result where 'a: 'data, { - if data.len() < LENGTH_SIZE { - return Err(PodSliceError::BufferTooSmall.into()); - } - let (length, data) = data.split_at(LENGTH_SIZE); - let length = pod_from_bytes::(length)?; - let _max_length = max_len_for_type::(data.len(), u32::from(*length) as usize)?; - let data = pod_slice_from_bytes(data)?; - Ok(Self { length, data }) + let inner = ListView::::unpack(data)?; + Ok(Self { inner }) } /// Get the slice data pub fn data(&self) -> &[T] { - let length = u32::from(*self.length) as usize; - &self.data[..length] + let len = self.inner.len(); + &self.inner.data[..len] } /// Get the amount of bytes used by `num_items` pub fn size_of(num_items: usize) -> Result { - std::mem::size_of::() - .checked_mul(num_items) - .and_then(|len| len.checked_add(LENGTH_SIZE)) - .ok_or_else(|| PodSliceError::CalculationFailure.into()) + ListView::::size_of(num_items) } } @@ -52,12 +48,12 @@ impl<'data, T: Pod> PodSlice<'data, T> { since = "0.6.0", note = "This struct will be removed in the next major release (1.0.0). \ Please use `ListView` instead. If using with existing data initialized by PodSliceMut, \ - you need to specifiy PodU32 length (e.g. ListView::::init(bytes))" + you need to specify PodU32 length (e.g. ListView::::init(bytes))" )] /// Special type for using a slice of mutable `Pod`s in a zero-copy way. /// Uses `ListView` under the hood. pub struct PodSliceMut<'data, T: Pod> { - inner: ListView<'data, T, PodU32>, + inner: ListViewMut<'data, T, PodU32>, } #[allow(deprecated)] @@ -67,7 +63,7 @@ impl<'data, T: Pod> PodSliceMut<'data, T> { where 'a: 'data, { - let inner = ListView::::unpack(data)?; + let inner = ListView::::unpack_mut(data)?; Ok(Self { inner }) } @@ -87,38 +83,12 @@ impl<'data, T: Pod> PodSliceMut<'data, T> { } } -fn max_len_for_type(data_len: usize, length_val: usize) -> Result { - let item_size = std::mem::size_of::(); - let max_len = data_len - .checked_div(item_size) - .ok_or(PodSliceError::CalculationFailure)?; - - // Make sure the max length that can be stored in the buffer isn't less - // than the length value. - if max_len < length_val { - Err(PodSliceError::BufferTooSmall)? - } - - // Make sure the buffer is cleanly divisible by `size_of::`; not over or - // under allocated. - if max_len.saturating_mul(item_size) != data_len { - if max_len == 0 { - // Size of T is greater than buffer size - Err(PodSliceError::BufferTooSmall)? - } else { - Err(PodSliceError::BufferTooLarge)? - } - } - - Ok(max_len) -} - #[cfg(test)] #[allow(deprecated)] mod tests { use { super::*, - crate::bytemuck::pod_slice_to_bytes, + crate::{bytemuck::pod_slice_to_bytes, error::PodSliceError}, bytemuck_derive::{Pod, Zeroable}, }; @@ -129,6 +99,8 @@ mod tests { test_pubkey: [u8; 32], } + const LENGTH_SIZE: usize = std::mem::size_of::(); + #[test] fn test_pod_slice() { let test_field_bytes = [0]; @@ -149,7 +121,7 @@ mod tests { let pod_slice = PodSlice::::unpack(&pod_slice_bytes).unwrap(); let pod_slice_data = pod_slice.data(); - assert_eq!(*pod_slice.length, PodU32::from(2)); + assert_eq!(pod_slice.inner.len(), 2); assert_eq!(pod_slice_to_bytes(pod_slice.data()), data_bytes); assert_eq!(pod_slice_data[0].test_field, test_field_bytes[0]); assert_eq!(pod_slice_data[0].test_pubkey, test_pubkey_bytes); @@ -166,11 +138,7 @@ mod tests { let err = PodSlice::::unpack(&pod_slice_bytes) .err() .unwrap(); - assert_eq!( - err, - PodSliceError::BufferTooLarge.into(), - "Expected an `PodSliceError::BufferTooLarge` error" - ); + assert!(matches!(err, ProgramError::InvalidArgument)); } #[test] @@ -190,7 +158,7 @@ mod tests { data[..LENGTH_SIZE].copy_from_slice(&length_le); let pod_slice = PodSlice::::unpack(&data).unwrap(); - let pod_slice_len = u32::from(*pod_slice.length); + let pod_slice_len = pod_slice.inner.len() as u32; let data = pod_slice.data(); let data_vec = data.to_vec(); @@ -207,11 +175,7 @@ mod tests { let err = PodSlice::::unpack(&pod_slice_bytes) .err() .unwrap(); - assert_eq!( - err, - PodSliceError::BufferTooSmall.into(), - "Expected an `PodSliceError::BufferTooSmall` error" - ); + assert!(matches!(err, ProgramError::InvalidArgument)); } #[test] @@ -247,33 +211,12 @@ mod tests { let len_bytes = [1, 0, 0, 0]; pod_slice_bytes[0..4].copy_from_slice(&len_bytes); - // Verify initial length - assert_eq!( - u32::from_le_bytes([ - pod_slice_bytes[0], - pod_slice_bytes[1], - pod_slice_bytes[2], - pod_slice_bytes[3] - ]), - 1 - ); - let mut pod_slice = PodSliceMut::::unpack(&mut pod_slice_bytes).unwrap(); - pod_slice.push(TestStruct::default()).unwrap(); - // Check length after push - assert_eq!( - u32::from_le_bytes([ - pod_slice_bytes[0], - pod_slice_bytes[1], - pod_slice_bytes[2], - pod_slice_bytes[3] - ]), - 2 - ); + assert_eq!(pod_slice.inner.len(), 1); + pod_slice.push(TestStruct::default()).unwrap(); + assert_eq!(pod_slice.inner.len(), 2); - // Test that buffer is full - let mut pod_slice = PodSliceMut::::unpack(&mut pod_slice_bytes).unwrap(); let err = pod_slice .push(TestStruct::default()) .expect_err("Expected an `PodSliceError::BufferTooSmall` error"); diff --git a/tlv-account-resolution/src/state.rs b/tlv-account-resolution/src/state.rs index 6bed11d6..8623af3a 100644 --- a/tlv-account-resolution/src/state.rs +++ b/tlv-account-resolution/src/state.rs @@ -7,7 +7,10 @@ use { solana_program_error::ProgramError, solana_pubkey::Pubkey, spl_discriminator::SplDiscriminate, - spl_pod::{list::ListView, primitives::PodU32, slice::PodSlice}, + spl_pod::{ + list::{self, ListView, ListViewable}, + primitives::PodU32, + }, spl_type_length_value::state::{TlvState, TlvStateBorrowed, TlvStateMut}, std::future::Future, }; @@ -170,7 +173,7 @@ impl ExtraAccountMetaList { extra_account_metas: &[ExtraAccountMeta], ) -> Result<(), ProgramError> { let mut state = TlvStateMut::unpack(data).unwrap(); - let tlv_size = PodSlice::::size_of(extra_account_metas.len())?; + let tlv_size = ListView::::size_of(extra_account_metas.len())?; let (bytes, _) = state.alloc::(tlv_size, false)?; let mut validation_data = ListView::::init(bytes)?; for meta in extra_account_metas { @@ -186,7 +189,7 @@ impl ExtraAccountMetaList { extra_account_metas: &[ExtraAccountMeta], ) -> Result<(), ProgramError> { let mut state = TlvStateMut::unpack(data).unwrap(); - let tlv_size = PodSlice::::size_of(extra_account_metas.len())?; + let tlv_size = ListView::::size_of(extra_account_metas.len())?; let bytes = state.realloc_first::(tlv_size)?; let mut validation_data = ListView::::init(bytes)?; for meta in extra_account_metas { @@ -195,22 +198,22 @@ impl ExtraAccountMetaList { Ok(()) } - /// Get the underlying `PodSlice` from an unpacked TLV + /// Get the underlying `ListViewReadOnly` from an unpacked TLV /// /// Due to lifetime annoyances, this function can't just take in the bytes, /// since then we would be returning a reference to a locally created /// `TlvStateBorrowed`. I hope there's a better way to do this! pub fn unpack_with_tlv_state<'a, T: SplDiscriminate>( tlv_state: &'a TlvStateBorrowed, - ) -> Result, ProgramError> { + ) -> Result, ProgramError> { let bytes = tlv_state.get_first_bytes::()?; - PodSlice::::unpack(bytes) + ListView::::unpack(bytes) } /// Get the byte size required to hold `num_items` items pub fn size_of(num_items: usize) -> Result { Ok(TlvStateBorrowed::get_base_len() - .saturating_add(PodSlice::::size_of(num_items)?)) + .saturating_add(ListView::::size_of(num_items)?)) } /// Checks provided account infos against validation data, using @@ -227,7 +230,7 @@ impl ExtraAccountMetaList { ) -> Result<(), ProgramError> { let state = TlvStateBorrowed::unpack(data).unwrap(); let extra_meta_list = ExtraAccountMetaList::unpack_with_tlv_state::(&state)?; - let extra_account_metas = extra_meta_list.data(); + let extra_account_metas = extra_meta_list.as_slice(); let initial_accounts_len = account_infos.len() - extra_account_metas.len(); @@ -281,7 +284,7 @@ impl ExtraAccountMetaList { { let state = TlvStateBorrowed::unpack(data)?; let bytes = state.get_first_bytes::()?; - let extra_account_metas = PodSlice::::unpack(bytes)?; + let extra_account_metas = ListView::::unpack(bytes)?; // Fetch account data for each of the instruction accounts let mut account_key_datas = vec![]; @@ -294,7 +297,7 @@ impl ExtraAccountMetaList { account_key_datas.push((meta.pubkey, account_data)); } - for extra_meta in extra_account_metas.data().iter() { + for extra_meta in extra_account_metas.as_slice().iter() { let mut meta = extra_meta.resolve(&instruction.data, &instruction.program_id, |usize| { account_key_datas @@ -326,9 +329,9 @@ impl ExtraAccountMetaList { ) -> Result<(), ProgramError> { let state = TlvStateBorrowed::unpack(data)?; let bytes = state.get_first_bytes::()?; - let extra_account_metas = PodSlice::::unpack(bytes)?; + let extra_account_metas = ListView::::unpack(bytes)?; - for extra_meta in extra_account_metas.data().iter() { + for extra_meta in extra_account_metas.as_slice().iter() { let mut meta = { // Create a list of `Ref`s so we can reference account data in the // resolution step @@ -1460,7 +1463,7 @@ mod tests { let state = TlvStateBorrowed::unpack(buffer).unwrap(); let unpacked_metas_pod = ExtraAccountMetaList::unpack_with_tlv_state::(&state).unwrap(); - let unpacked_metas = unpacked_metas_pod.data(); + let unpacked_metas = unpacked_metas_pod.as_slice(); assert_eq!( unpacked_metas, updated_metas, "The ExtraAccountMetas in the buffer should match the expected ones." From 376e2bfe908888cae517c0535b9bc990b403a161 Mon Sep 17 00:00:00 2001 From: Gabe Rodriguez Date: Mon, 28 Jul 2025 10:13:15 +0200 Subject: [PATCH 6/7] Add sizing methods to trait --- pod/src/list/list_view_mut.rs | 33 ++++++++++++++++++++++++++++- pod/src/list/list_view_read_only.rs | 16 ++++++++++++++ pod/src/list/list_viewable.rs | 19 ++++++++++++++++- 3 files changed, 66 insertions(+), 2 deletions(-) diff --git a/pod/src/list/list_view_mut.rs b/pod/src/list/list_view_mut.rs index 8980ea32..52902c90 100644 --- a/pod/src/list/list_view_mut.rs +++ b/pod/src/list/list_view_mut.rs @@ -61,6 +61,7 @@ impl ListViewMut<'_, T, L> { impl ListViewable for ListViewMut<'_, T, L> { type Item = T; + type Length = L; fn len(&self) -> usize { (*self.length).into() @@ -81,7 +82,7 @@ mod tests { super::*, crate::{ list::{ListView, ListViewable}, - primitives::{PodU32, PodU64}, + primitives::{PodU16, PodU32, PodU64}, }, bytemuck_derive::{Pod, Zeroable}, }; @@ -303,4 +304,34 @@ mod tests { // Verify the size of the length field. assert_eq!(size_of_val(view.length), size_of::()); } + + #[test] + fn test_bytes_used_and_allocated_mut() { + // capacity 3, start empty + let mut buffer = vec![]; + let mut view = init_view_mut::(&mut buffer, 3); + + // Empty view + assert_eq!( + view.bytes_used().unwrap(), + ListView::::size_of(0).unwrap() + ); + assert_eq!( + view.bytes_allocated().unwrap(), + ListView::::size_of(view.capacity()).unwrap() + ); + + // After pushing elements + view.push(TestStruct::new(1, 2)).unwrap(); + view.push(TestStruct::new(3, 4)).unwrap(); + view.push(TestStruct::new(5, 6)).unwrap(); + assert_eq!( + view.bytes_used().unwrap(), + ListView::::size_of(3).unwrap() + ); + assert_eq!( + view.bytes_allocated().unwrap(), + ListView::::size_of(view.capacity()).unwrap() + ); + } } diff --git a/pod/src/list/list_view_read_only.rs b/pod/src/list/list_view_read_only.rs index 6fce0bde..7ae6d5ce 100644 --- a/pod/src/list/list_view_read_only.rs +++ b/pod/src/list/list_view_read_only.rs @@ -14,6 +14,7 @@ pub struct ListViewReadOnly<'data, T: Pod, L: PodLength = PodU64> { impl ListViewable for ListViewReadOnly<'_, T, L> { type Item = T; + type Length = L; fn len(&self) -> usize { (*self.length).into() @@ -153,4 +154,19 @@ mod tests { assert_eq!(view.capacity(), 4); assert_eq!(view.as_slice(), &items[..]); } + + #[test] + fn test_bytes_used_and_allocated() { + // 3 live elements, capacity 5 + let items = [10u32, 20, 30]; + let capacity = 5; + let buffer = build_test_buffer::(items.len(), capacity, &items); + let view = ListView::::unpack(&buffer).unwrap(); + + let expected_used = ListView::::size_of(view.len()).unwrap(); + let expected_cap = ListView::::size_of(view.capacity()).unwrap(); + + assert_eq!(view.bytes_used().unwrap(), expected_used); + assert_eq!(view.bytes_allocated().unwrap(), expected_cap); + } } diff --git a/pod/src/list/list_viewable.rs b/pod/src/list/list_viewable.rs index b3144a45..2e823423 100644 --- a/pod/src/list/list_viewable.rs +++ b/pod/src/list/list_viewable.rs @@ -1,10 +1,17 @@ -use {bytemuck::Pod, std::slice::Iter}; +use { + crate::{list::ListView, pod_length::PodLength}, + bytemuck::Pod, + solana_program_error::ProgramError, + std::slice::Iter, +}; /// A trait to abstract the shared, read-only behavior /// between `ListViewReadOnly` and `ListViewMut`. pub trait ListViewable { /// The type of the items stored in the list. type Item: Pod; + /// Length prefix type used (`PodU16`, `PodU32`, …). + type Length: PodLength; /// Returns the number of items in the list. fn len(&self) -> usize; @@ -24,4 +31,14 @@ pub trait ListViewable { fn iter(&self) -> Iter<'_, Self::Item> { self.as_slice().iter() } + + /// Returns the number of **bytes currently occupied** by the live elements + fn bytes_used(&self) -> Result { + ListView::::size_of(self.len()) + } + + /// Returns the number of **bytes reserved** by the entire backing buffer. + fn bytes_allocated(&self) -> Result { + ListView::::size_of(self.capacity()) + } } From 5865b7fab9d058ba4198f1cd1a0d8a432721d9ba Mon Sep 17 00:00:00 2001 From: Gabe Rodriguez Date: Thu, 31 Jul 2025 16:41:17 +0200 Subject: [PATCH 7/7] Review updates --- .../list/{list_viewable.rs => list_trait.rs} | 2 +- pod/src/list/list_view.rs | 118 ++++++++++++------ pod/src/list/list_view_mut.rs | 15 ++- pod/src/list/list_view_read_only.rs | 22 ++-- pod/src/list/mod.rs | 6 +- pod/src/pod_length.rs | 3 +- pod/src/slice.rs | 10 +- tlv-account-resolution/src/state.rs | 18 +-- 8 files changed, 117 insertions(+), 77 deletions(-) rename pod/src/list/{list_viewable.rs => list_trait.rs} (98%) diff --git a/pod/src/list/list_viewable.rs b/pod/src/list/list_trait.rs similarity index 98% rename from pod/src/list/list_viewable.rs rename to pod/src/list/list_trait.rs index 2e823423..f3d27b70 100644 --- a/pod/src/list/list_viewable.rs +++ b/pod/src/list/list_trait.rs @@ -7,7 +7,7 @@ use { /// A trait to abstract the shared, read-only behavior /// between `ListViewReadOnly` and `ListViewMut`. -pub trait ListViewable { +pub trait List { /// The type of the items stored in the list. type Item: Pod; /// Length prefix type used (`PodU16`, `PodU32`, …). diff --git a/pod/src/list/list_view.rs b/pod/src/list/list_view.rs index 53be0fdc..1a3419ed 100644 --- a/pod/src/list/list_view.rs +++ b/pod/src/list/list_view.rs @@ -8,7 +8,7 @@ use { error::PodSliceError, list::{list_view_mut::ListViewMut, list_view_read_only::ListViewReadOnly}, pod_length::PodLength, - primitives::PodU64, + primitives::PodU32, }, bytemuck::Pod, solana_program_error::ProgramError, @@ -35,12 +35,13 @@ use { /// The structure assumes the underlying byte buffer is formatted as follows: /// 1. **Length**: A length field of type `L` at the beginning of the buffer, /// indicating the number of currently active elements in the collection. -/// Defaults to `PodU64` so the offset is then compatible with 1, 2, 4 and 8 bytes. +/// Defaults to `PodU32`. The implementation uses padding to ensure that the +/// data is correctly aligned for any `Pod` type. /// 2. **Padding**: Optional padding bytes to ensure proper alignment of the data. /// 3. **Data**: The remaining part of the buffer, which is treated as a slice /// of `T` elements. The capacity of the collection is the number of `T` /// elements that can fit into this data portion. -pub struct ListView(PhantomData<(T, L)>); +pub struct ListView(PhantomData<(T, L)>); struct Layout { length_range: Range, @@ -51,10 +52,11 @@ impl ListView { /// Calculate the total byte size for a `ListView` holding `num_items`. /// This includes the length prefix, padding, and data. pub fn size_of(num_items: usize) -> Result { + let header_padding = Self::header_padding()?; size_of::() .checked_mul(num_items) .and_then(|curr| curr.checked_add(size_of::())) - .and_then(|curr| curr.checked_add(Self::header_padding())) + .and_then(|curr| curr.checked_add(header_padding)) .ok_or_else(|| PodSliceError::CalculationFailure.into()) } @@ -133,7 +135,8 @@ impl ListView { #[inline] fn calculate_layout(buf_len: usize) -> Result { let len_field_end = size_of::(); - let data_start = len_field_end.saturating_add(Self::header_padding()); + let header_padding = Self::header_padding()?; + let data_start = len_field_end.saturating_add(header_padding); if buf_len < data_start { return Err(PodSliceError::BufferTooSmall.into()); @@ -150,13 +153,18 @@ impl ListView { /// The goal is to ensure that the data field `T` starts at a memory offset /// that is a multiple of its alignment requirement. #[inline] - const fn header_padding() -> usize { + fn header_padding() -> Result { + // Enforce that the length prefix type `L` itself does not have alignment requirements + if align_of::() != 1 { + return Err(ProgramError::InvalidArgument); + } + let length_size = size_of::(); let data_align = align_of::(); // No padding is needed for alignments of 0 or 1 if data_align == 0 || data_align == 1 { - return 0; + return Ok(0); } // Find how many bytes `length_size` extends past an alignment boundary @@ -166,9 +174,9 @@ impl ListView { // If already aligned (remainder is 0), no padding is needed. // Otherwise, calculate the distance to the next alignment boundary. if remainder == 0 { - 0 + Ok(0) } else { - data_align.wrapping_sub(remainder) + Ok(data_align.wrapping_sub(remainder)) } } } @@ -178,8 +186,8 @@ mod tests { use { super::*, crate::{ - list::ListViewable, - primitives::{PodU16, PodU32}, + list::List, + primitives::{PodU16, PodU32, PodU64}, }, bytemuck_derive::{Pod as DerivePod, Zeroable}, }; @@ -191,14 +199,14 @@ mod tests { assert_eq!(ListView::::size_of(10).unwrap(), 14); // Case 2: size_of is a multiple of align_of, so no padding needed. - // T = u32 (size 4, align 4), L = PodU64 (size 8). 8 % 4 == 0. - // 10 items * 4 bytes/item + 8 bytes for length = 48 - assert_eq!(ListView::::size_of(10).unwrap(), 48); + // T = u32 (size 4, align 4), L = PodU32 (size 4). 4 % 4 == 0. + // 10 items * 4 bytes/item + 4 bytes for length = 44 + assert_eq!(ListView::::size_of(10).unwrap(), 44); // Case 3: 0 items. Size should just be size_of + padding. // Padding is 0 here. - // 0 items * 4 bytes/item + 8 bytes for length = 8 - assert_eq!(ListView::::size_of(0).unwrap(), 8); + // 0 items * 4 bytes/item + 4 bytes for length = 4 + assert_eq!(ListView::::size_of(0).unwrap(), 4); } #[test] @@ -240,27 +248,59 @@ mod tests { assert_eq!(err, PodSliceError::CalculationFailure.into()); } + #[test] + fn test_fails_with_non_aligned_length_type() { + // A custom `PodLength` type with an alignment of 4 + #[repr(C, align(4))] + #[derive(Debug, Copy, Clone, Zeroable, DerivePod)] + struct TestPodU32(u32); + + // Implement the traits for `PodLength` + impl From for usize { + fn from(val: TestPodU32) -> Self { + val.0 as usize + } + } + impl TryFrom for TestPodU32 { + type Error = PodSliceError; + fn try_from(val: usize) -> Result { + Ok(Self(u32::try_from(val)?)) + } + } + + let mut buf = [0u8; 100]; + + let err_size_of = ListView::::size_of(10).unwrap_err(); + assert_eq!(err_size_of, ProgramError::InvalidArgument); + + let err_unpack = ListView::::unpack(&buf).unwrap_err(); + assert_eq!(err_unpack, ProgramError::InvalidArgument); + + let err_init = ListView::::init(&mut buf).unwrap_err(); + assert_eq!(err_init, ProgramError::InvalidArgument); + } + #[test] fn test_padding_calculation() { // `u8` has an alignment of 1, so no padding is ever needed. - assert_eq!(ListView::::header_padding(), 0); + assert_eq!(ListView::::header_padding().unwrap(), 0); // Zero-Sized Types like `()` have size 0 and align 1, requiring no padding. - assert_eq!(ListView::<(), PodU64>::header_padding(), 0); + assert_eq!(ListView::<(), PodU64>::header_padding().unwrap(), 0); // When length and data have the same alignment. - assert_eq!(ListView::::header_padding(), 0); - assert_eq!(ListView::::header_padding(), 0); - assert_eq!(ListView::::header_padding(), 0); + assert_eq!(ListView::::header_padding().unwrap(), 0); + assert_eq!(ListView::::header_padding().unwrap(), 0); + assert_eq!(ListView::::header_padding().unwrap(), 0); // When data alignment is smaller than or perfectly divides the length size. - assert_eq!(ListView::::header_padding(), 0); // 8 % 2 = 0 - assert_eq!(ListView::::header_padding(), 0); // 8 % 4 = 0 + assert_eq!(ListView::::header_padding().unwrap(), 0); // 8 % 2 = 0 + assert_eq!(ListView::::header_padding().unwrap(), 0); // 8 % 4 = 0 // When padding IS needed. - assert_eq!(ListView::::header_padding(), 2); // size_of is 2. To align to 4, need 2 bytes. - assert_eq!(ListView::::header_padding(), 6); // size_of is 2. To align to 8, need 6 bytes. - assert_eq!(ListView::::header_padding(), 4); // size_of is 4. To align to 8, need 4 bytes. + assert_eq!(ListView::::header_padding().unwrap(), 2); // size_of is 2. To align to 4, need 2 bytes. + assert_eq!(ListView::::header_padding().unwrap(), 6); // size_of is 2. To align to 8, need 6 bytes. + assert_eq!(ListView::::header_padding().unwrap(), 4); // size_of is 4. To align to 8, need 4 bytes. // Test with custom, higher alignments. #[repr(C, align(8))] @@ -268,17 +308,17 @@ mod tests { struct Align8(u64); // Test against different length types - assert_eq!(ListView::::header_padding(), 6); // 2 + 6 = 8 - assert_eq!(ListView::::header_padding(), 4); // 4 + 4 = 8 - assert_eq!(ListView::::header_padding(), 0); // 8 is already aligned + assert_eq!(ListView::::header_padding().unwrap(), 6); // 2 + 6 = 8 + assert_eq!(ListView::::header_padding().unwrap(), 4); // 4 + 4 = 8 + assert_eq!(ListView::::header_padding().unwrap(), 0); // 8 is already aligned #[repr(C, align(16))] #[derive(DerivePod, Zeroable, Copy, Clone)] struct Align16(u128); - assert_eq!(ListView::::header_padding(), 14); // 2 + 14 = 16 - assert_eq!(ListView::::header_padding(), 12); // 4 + 12 = 16 - assert_eq!(ListView::::header_padding(), 8); // 8 + 8 = 16 + assert_eq!(ListView::::header_padding().unwrap(), 14); // 2 + 14 = 16 + assert_eq!(ListView::::header_padding().unwrap(), 12); // 4 + 12 = 16 + assert_eq!(ListView::::header_padding().unwrap(), 8); // 8 + 8 = 16 } #[test] @@ -313,7 +353,7 @@ mod tests { #[test] fn test_unpack_success_with_padding() { // T = u64 (align 8), L = PodU32 (size 4, align 4). Needs 4 bytes padding. - let padding = ListView::::header_padding(); + let padding = ListView::::header_padding().unwrap(); assert_eq!(padding, 4); let length: u32 = 2; @@ -398,7 +438,7 @@ mod tests { #[test] fn test_unpack_success_different_length_type() { // T = u64 (align 8), L = PodU16 (size 2, align 2). Needs 6 bytes padding. - let padding = ListView::::header_padding(); + let padding = ListView::::header_padding().unwrap(); assert_eq!(padding, 6); let length: u16 = 1; @@ -564,10 +604,10 @@ mod tests { #[test] fn test_init_success_default_length_type() { - // This test uses the default L=PodU64 length type by omitting it. - // T = u32 (align 4), L = PodU64 (size 8). No padding needed as 8 % 4 == 0. + // This test uses the default L=PodU32 length type by omitting it. + // T = u32 (align 4), L = PodU32 (size 4). No padding needed as 4 % 4 == 0. let capacity = 5; - let len_size = size_of::(); // Default L is PodU64 + let len_size = size_of::(); // Default L is PodU32 let buf_size = ListView::::size_of(capacity).unwrap(); let mut buf = vec![0xFFu8; buf_size]; // Pre-fill to ensure init zeroes it @@ -577,8 +617,8 @@ mod tests { assert_eq!(view.capacity(), capacity); assert!(view.is_empty()); - // Check that the underlying buffer's length (a u64) was actually zeroed + // Check that the underlying buffer's length (a u32) was actually zeroed let length_bytes = &buf[0..len_size]; - assert_eq!(length_bytes, &[0u8; 8]); + assert_eq!(length_bytes, &[0u8; 4]); } } diff --git a/pod/src/list/list_view_mut.rs b/pod/src/list/list_view_mut.rs index 52902c90..0f6847b6 100644 --- a/pod/src/list/list_view_mut.rs +++ b/pod/src/list/list_view_mut.rs @@ -2,15 +2,14 @@ use { crate::{ - error::PodSliceError, list::list_viewable::ListViewable, pod_length::PodLength, - primitives::PodU64, + error::PodSliceError, list::list_trait::List, pod_length::PodLength, primitives::PodU32, }, bytemuck::Pod, solana_program_error::ProgramError, }; #[derive(Debug)] -pub struct ListViewMut<'data, T: Pod, L: PodLength = PodU64> { +pub struct ListViewMut<'data, T: Pod, L: PodLength = PodU32> { pub(crate) length: &'data mut L, pub(crate) data: &'data mut [T], pub(crate) capacity: usize, @@ -46,8 +45,8 @@ impl ListViewMut<'_, T, L> { self.data.copy_within(tail_start..len, index); // Store the new length (len - 1) - let last = len.saturating_sub(1); - *self.length = L::try_from(last)?; + let new_len = len.checked_sub(1).unwrap(); + *self.length = L::try_from(new_len)?; Ok(removed_item) } @@ -59,7 +58,7 @@ impl ListViewMut<'_, T, L> { } } -impl ListViewable for ListViewMut<'_, T, L> { +impl List for ListViewMut<'_, T, L> { type Item = T; type Length = L; @@ -81,7 +80,7 @@ mod tests { use { super::*, crate::{ - list::{ListView, ListViewable}, + list::{List, ListView}, primitives::{PodU16, PodU32, PodU64}, }, bytemuck_derive::{Pod, Zeroable}, @@ -302,7 +301,7 @@ mod tests { assert_eq!(view.len(), 0); // Verify the size of the length field. - assert_eq!(size_of_val(view.length), size_of::()); + assert_eq!(size_of_val(view.length), size_of::()); } #[test] diff --git a/pod/src/list/list_view_read_only.rs b/pod/src/list/list_view_read_only.rs index 7ae6d5ce..6e392544 100644 --- a/pod/src/list/list_view_read_only.rs +++ b/pod/src/list/list_view_read_only.rs @@ -1,18 +1,18 @@ //! `ListViewReadOnly`, a read-only, compact, zero-copy array wrapper. use { - crate::{list::list_viewable::ListViewable, pod_length::PodLength, primitives::PodU64}, + crate::{list::list_trait::List, pod_length::PodLength, primitives::PodU32}, bytemuck::Pod, }; #[derive(Debug)] -pub struct ListViewReadOnly<'data, T: Pod, L: PodLength = PodU64> { +pub struct ListViewReadOnly<'data, T: Pod, L: PodLength = PodU32> { pub(crate) length: &'data L, pub(crate) data: &'data [T], pub(crate) capacity: usize, } -impl ListViewable for ListViewReadOnly<'_, T, L> { +impl List for ListViewReadOnly<'_, T, L> { type Item = T; type Length = L; @@ -33,7 +33,11 @@ impl ListViewable for ListViewReadOnly<'_, T, L> { mod tests { use { super::*, - crate::{list::ListView, pod_length::PodLength, primitives::PodU32}, + crate::{ + list::ListView, + pod_length::PodLength, + primitives::{PodU32, PodU64}, + }, bytemuck_derive::{Pod as DerivePod, Zeroable}, std::mem::size_of, }; @@ -70,7 +74,7 @@ mod tests { #[test] fn test_len_and_capacity() { let items = [10u32, 20, 30]; - let buffer = build_test_buffer::(items.len(), 5, &items); + let buffer = build_test_buffer::(items.len(), 5, &items); let view = ListView::::unpack(&buffer).unwrap(); assert_eq!(view.len(), 3); @@ -92,12 +96,12 @@ mod tests { #[test] fn test_is_empty() { // Not empty - let buffer_full = build_test_buffer::(1, 2, &[10]); + let buffer_full = build_test_buffer::(1, 2, &[10]); let view_full = ListView::::unpack(&buffer_full).unwrap(); assert!(!view_full.is_empty()); // Empty - let buffer_empty = build_test_buffer::(0, 2, &[]); + let buffer_empty = build_test_buffer::(0, 2, &[]); let view_empty = ListView::::unpack(&buffer_empty).unwrap(); assert!(view_empty.is_empty()); } @@ -146,7 +150,7 @@ mod tests { assert_eq!(header_size, 16); let items = [TestStruct(123), TestStruct(456)]; - let buffer = build_test_buffer::(items.len(), 4, &items); + let buffer = build_test_buffer::(items.len(), 4, &items); let view = ListView::::unpack(&buffer).unwrap(); // Check if the public API works as expected despite internal padding @@ -160,7 +164,7 @@ mod tests { // 3 live elements, capacity 5 let items = [10u32, 20, 30]; let capacity = 5; - let buffer = build_test_buffer::(items.len(), capacity, &items); + let buffer = build_test_buffer::(items.len(), capacity, &items); let view = ListView::::unpack(&buffer).unwrap(); let expected_used = ListView::::size_of(view.len()).unwrap(); diff --git a/pod/src/list/mod.rs b/pod/src/list/mod.rs index 47463684..56062237 100644 --- a/pod/src/list/mod.rs +++ b/pod/src/list/mod.rs @@ -1,9 +1,9 @@ +mod list_trait; mod list_view; mod list_view_mut; mod list_view_read_only; -mod list_viewable; pub use { - list_view::ListView, list_view_mut::ListViewMut, list_view_read_only::ListViewReadOnly, - list_viewable::ListViewable, + list_trait::List, list_view::ListView, list_view_mut::ListViewMut, + list_view_read_only::ListViewReadOnly, }; diff --git a/pod/src/pod_length.rs b/pod/src/pod_length.rs index bd418dc9..756e4b8b 100644 --- a/pod/src/pod_length.rs +++ b/pod/src/pod_length.rs @@ -28,7 +28,8 @@ macro_rules! impl_pod_length_for { impl From<$PodType> for usize { fn from(pod_val: $PodType) -> Self { let primitive_val = <$PrimitiveType>::from(pod_val); - primitive_val as usize + Self::try_from(primitive_val) + .expect("value out of range for usize on this platform") } } }; diff --git a/pod/src/slice.rs b/pod/src/slice.rs index 8f138d35..1810265b 100644 --- a/pod/src/slice.rs +++ b/pod/src/slice.rs @@ -2,7 +2,7 @@ use { crate::{ - list::{ListView, ListViewMut, ListViewReadOnly, ListViewable}, + list::{List, ListView, ListViewMut, ListViewReadOnly}, primitives::PodU32, }, bytemuck::Pod, @@ -11,9 +11,7 @@ use { #[deprecated( since = "0.6.0", - note = "This struct will be removed in the next major release (1.0.0). \ - Please use `ListView` instead. If using with existing data initialized by PodSlice, \ - you need to specify PodU32 length (e.g. ListView::::unpack(bytes))" + note = "This struct will be removed in the next major release (1.0.0). Please use `ListView` instead." )] /// Special type for using a slice of `Pod`s in a zero-copy way #[allow(deprecated)] @@ -46,9 +44,7 @@ impl<'data, T: Pod> PodSlice<'data, T> { #[deprecated( since = "0.6.0", - note = "This struct will be removed in the next major release (1.0.0). \ - Please use `ListView` instead. If using with existing data initialized by PodSliceMut, \ - you need to specify PodU32 length (e.g. ListView::::init(bytes))" + note = "This struct will be removed in the next major release (1.0.0). Please use `ListView` instead." )] /// Special type for using a slice of mutable `Pod`s in a zero-copy way. /// Uses `ListView` under the hood. diff --git a/tlv-account-resolution/src/state.rs b/tlv-account-resolution/src/state.rs index 8623af3a..af01ced9 100644 --- a/tlv-account-resolution/src/state.rs +++ b/tlv-account-resolution/src/state.rs @@ -8,7 +8,7 @@ use { solana_pubkey::Pubkey, spl_discriminator::SplDiscriminate, spl_pod::{ - list::{self, ListView, ListViewable}, + list::{self, List, ListView}, primitives::PodU32, }, spl_type_length_value::state::{TlvState, TlvStateBorrowed, TlvStateMut}, @@ -173,9 +173,9 @@ impl ExtraAccountMetaList { extra_account_metas: &[ExtraAccountMeta], ) -> Result<(), ProgramError> { let mut state = TlvStateMut::unpack(data).unwrap(); - let tlv_size = ListView::::size_of(extra_account_metas.len())?; + let tlv_size = ListView::::size_of(extra_account_metas.len())?; let (bytes, _) = state.alloc::(tlv_size, false)?; - let mut validation_data = ListView::::init(bytes)?; + let mut validation_data = ListView::::init(bytes)?; for meta in extra_account_metas { validation_data.push(*meta)?; } @@ -189,9 +189,9 @@ impl ExtraAccountMetaList { extra_account_metas: &[ExtraAccountMeta], ) -> Result<(), ProgramError> { let mut state = TlvStateMut::unpack(data).unwrap(); - let tlv_size = ListView::::size_of(extra_account_metas.len())?; + let tlv_size = ListView::::size_of(extra_account_metas.len())?; let bytes = state.realloc_first::(tlv_size)?; - let mut validation_data = ListView::::init(bytes)?; + let mut validation_data = ListView::::init(bytes)?; for meta in extra_account_metas { validation_data.push(*meta)?; } @@ -207,13 +207,13 @@ impl ExtraAccountMetaList { tlv_state: &'a TlvStateBorrowed, ) -> Result, ProgramError> { let bytes = tlv_state.get_first_bytes::()?; - ListView::::unpack(bytes) + ListView::::unpack(bytes) } /// Get the byte size required to hold `num_items` items pub fn size_of(num_items: usize) -> Result { Ok(TlvStateBorrowed::get_base_len() - .saturating_add(ListView::::size_of(num_items)?)) + .saturating_add(ListView::::size_of(num_items)?)) } /// Checks provided account infos against validation data, using @@ -284,7 +284,7 @@ impl ExtraAccountMetaList { { let state = TlvStateBorrowed::unpack(data)?; let bytes = state.get_first_bytes::()?; - let extra_account_metas = ListView::::unpack(bytes)?; + let extra_account_metas = ListView::::unpack(bytes)?; // Fetch account data for each of the instruction accounts let mut account_key_datas = vec![]; @@ -329,7 +329,7 @@ impl ExtraAccountMetaList { ) -> Result<(), ProgramError> { let state = TlvStateBorrowed::unpack(data)?; let bytes = state.get_first_bytes::()?; - let extra_account_metas = ListView::::unpack(bytes)?; + let extra_account_metas = ListView::::unpack(bytes)?; for extra_meta in extra_account_metas.as_slice().iter() { let mut meta = {