diff --git a/pod/src/list/list_trait.rs b/pod/src/list/list_trait.rs index f3d27b70..c7ebf374 100644 --- a/pod/src/list/list_trait.rs +++ b/pod/src/list/list_trait.rs @@ -2,36 +2,20 @@ use { crate::{list::ListView, pod_length::PodLength}, bytemuck::Pod, solana_program_error::ProgramError, - std::slice::Iter, + std::ops::Deref, }; /// A trait to abstract the shared, read-only behavior /// between `ListViewReadOnly` and `ListViewMut`. -pub trait List { +pub trait List: Deref { /// The type of the items stored in the list. type Item: Pod; /// Length prefix type used (`PodU16`, `PodU32`, …). type Length: PodLength; - /// Returns the number of items in the list. - fn len(&self) -> usize; - - /// Returns `true` if the list contains no items. - fn is_empty(&self) -> bool { - self.len() == 0 - } - /// Returns the total number of items that can be stored in the list. fn capacity(&self) -> usize; - /// Returns a read-only slice of the items currently in the list. - fn as_slice(&self) -> &[Self::Item]; - - /// Returns a read-only iterator over the list. - fn iter(&self) -> Iter<'_, Self::Item> { - self.as_slice().iter() - } - /// Returns the number of **bytes currently occupied** by the live elements fn bytes_used(&self) -> Result { ListView::::size_of(self.len()) diff --git a/pod/src/list/list_view.rs b/pod/src/list/list_view.rs index a63d758b..23a64391 100644 --- a/pod/src/list/list_view.rs +++ b/pod/src/list/list_view.rs @@ -342,12 +342,12 @@ mod tests { let view_ro = ListView::::unpack(&buf).unwrap(); assert_eq!(view_ro.len(), length as usize); assert_eq!(view_ro.capacity(), capacity); - assert_eq!(view_ro.as_slice(), &items[..]); + assert_eq!(*view_ro, items[..]); let view_mut = ListView::::unpack_mut(&mut buf).unwrap(); assert_eq!(view_mut.len(), length as usize); assert_eq!(view_mut.capacity(), capacity); - assert_eq!(view_mut.as_slice(), &items[..]); + assert_eq!(*view_mut, items[..]); } #[test] @@ -375,12 +375,12 @@ mod tests { let view_ro = ListView::::unpack(&buf).unwrap(); assert_eq!(view_ro.len(), length as usize); assert_eq!(view_ro.capacity(), capacity); - assert_eq!(view_ro.as_slice(), &items[..]); + assert_eq!(*view_ro, items[..]); let view_mut = ListView::::unpack_mut(&mut buf).unwrap(); assert_eq!(view_mut.len(), length as usize); assert_eq!(view_mut.capacity(), capacity); - assert_eq!(view_mut.as_slice(), &items[..]); + assert_eq!(*view_mut, items[..]); } #[test] @@ -398,13 +398,13 @@ mod tests { assert_eq!(view_ro.len(), 0); assert_eq!(view_ro.capacity(), capacity); assert!(view_ro.is_empty()); - assert_eq!(view_ro.as_slice(), &[] as &[u32]); + assert_eq!(&*view_ro, &[] as &[u32]); let view_mut = ListView::::unpack_mut(&mut buf).unwrap(); assert_eq!(view_mut.len(), 0); assert_eq!(view_mut.capacity(), capacity); assert!(view_mut.is_empty()); - assert_eq!(view_mut.as_slice(), &[] as &[u32]); + assert_eq!(&*view_mut, &[] as &[u32]); } #[test] @@ -427,12 +427,12 @@ mod tests { let view_ro = ListView::::unpack(&buf).unwrap(); assert_eq!(view_ro.len(), length as usize); assert_eq!(view_ro.capacity(), capacity); - assert_eq!(view_ro.as_slice(), &items[..]); + assert_eq!(*view_ro, items[..]); let view_mut = ListView::::unpack_mut(&mut buf).unwrap(); assert_eq!(view_mut.len(), length as usize); assert_eq!(view_mut.capacity(), capacity); - assert_eq!(view_mut.as_slice(), &items[..]); + assert_eq!(*view_mut, items[..]); } #[test] @@ -619,14 +619,14 @@ mod tests { let view_ro = ListView::::unpack(&buf).unwrap(); assert_eq!(view_ro.len(), length_usize); assert_eq!(view_ro.capacity(), capacity); - assert_eq!(view_ro.as_slice(), &items[..]); + assert_eq!(*view_ro, items[..]); // Test mutable view let mut buf_mut = buf.clone(); let view_mut = ListView::::unpack_mut(&mut buf_mut).unwrap(); assert_eq!(view_mut.len(), length_usize); assert_eq!(view_mut.capacity(), capacity); - assert_eq!(view_mut.as_slice(), &items[..]); + assert_eq!(*view_mut, items[..]); // Test init let mut init_buf = vec![0xFFu8; buf_size]; diff --git a/pod/src/list/list_view_mut.rs b/pod/src/list/list_view_mut.rs index 0f6847b6..4f0ca49f 100644 --- a/pod/src/list/list_view_mut.rs +++ b/pod/src/list/list_view_mut.rs @@ -6,6 +6,7 @@ use { }, bytemuck::Pod, solana_program_error::ProgramError, + std::ops::{Deref, DerefMut}, }; #[derive(Debug)] @@ -50,11 +51,21 @@ impl ListViewMut<'_, T, L> { Ok(removed_item) } +} + +impl Deref for ListViewMut<'_, T, L> { + type Target = [T]; + + fn deref(&self) -> &Self::Target { + let len = (*self.length).into(); + &self.data[..len] + } +} - /// Returns a mutable iterator over the current elements - pub fn iter_mut(&mut self) -> std::slice::IterMut { +impl DerefMut for ListViewMut<'_, T, L> { + fn deref_mut(&mut self) -> &mut Self::Target { let len = (*self.length).into(); - self.data[..len].iter_mut() + &mut self.data[..len] } } @@ -62,17 +73,9 @@ impl List for ListViewMut<'_, T, L> { type Item = T; type Length = L; - fn len(&self) -> usize { - (*self.length).into() - } - fn capacity(&self) -> usize { self.capacity } - - fn as_slice(&self) -> &[Self::Item] { - &self.data[..self.len()] - } } #[cfg(test)] @@ -87,7 +90,7 @@ mod tests { }; #[repr(C)] - #[derive(Clone, Copy, Debug, Default, PartialEq, Pod, Zeroable)] + #[derive(Clone, Copy, Debug, Default, PartialEq, Eq, PartialOrd, Ord, Pod, Zeroable)] struct TestStruct { a: u64, b: u32, @@ -127,19 +130,19 @@ mod tests { view.push(item1).unwrap(); assert_eq!(view.len(), 1); assert!(!view.is_empty()); - assert_eq!(view.as_slice(), &[item1]); + assert_eq!(*view, [item1]); // Push second item let item2 = TestStruct::new(2, 20); view.push(item2).unwrap(); assert_eq!(view.len(), 2); - assert_eq!(view.as_slice(), &[item1, item2]); + assert_eq!(*view, [item1, item2]); // Push third item to fill capacity let item3 = TestStruct::new(3, 30); view.push(item3).unwrap(); assert_eq!(view.len(), 3); - assert_eq!(view.as_slice(), &[item1, item2, item3]); + assert_eq!(*view, [item1, item2, item3]); // Try to push beyond capacity let item4 = TestStruct::new(4, 40); @@ -148,7 +151,7 @@ mod tests { // Ensure state is unchanged assert_eq!(view.len(), 3); - assert_eq!(view.as_slice(), &[item1, item2, item3]); + assert_eq!(*view, [item1, item2, item3]); } #[test] @@ -166,32 +169,32 @@ mod tests { view.push(item4).unwrap(); assert_eq!(view.len(), 4); - assert_eq!(view.as_slice(), &[item1, item2, item3, item4]); + assert_eq!(*view, [item1, item2, item3, item4]); // Remove from the middle let removed = view.remove(1).unwrap(); assert_eq!(removed, item2); assert_eq!(view.len(), 3); - assert_eq!(view.as_slice(), &[item1, item3, item4]); + assert_eq!(*view, [item1, item3, item4]); // Remove from the end let removed = view.remove(2).unwrap(); assert_eq!(removed, item4); assert_eq!(view.len(), 2); - assert_eq!(view.as_slice(), &[item1, item3]); + assert_eq!(*view, [item1, item3]); // Remove from the start let removed = view.remove(0).unwrap(); assert_eq!(removed, item1); assert_eq!(view.len(), 1); - assert_eq!(view.as_slice(), &[item3]); + assert_eq!(*view, [item3]); // Remove the last element let removed = view.remove(0).unwrap(); assert_eq!(removed, item3); assert_eq!(view.len(), 0); assert!(view.is_empty()); - assert_eq!(view.as_slice(), &[]); + assert_eq!(*view, []); } #[test] @@ -248,10 +251,7 @@ mod tests { // Check that the underlying data is modified assert_eq!(view.len(), 3); - assert_eq!( - view.as_slice(), - &[expected_item1, expected_item2, expected_item3] - ); + assert_eq!(*view, [expected_item1, expected_item2, expected_item3]); // Check that iter_mut only iterates over `len` items, not `capacity` assert_eq!(view.iter_mut().count(), 3); @@ -333,4 +333,94 @@ mod tests { ListView::::size_of(view.capacity()).unwrap() ); } + #[test] + fn test_get_and_get_mut() { + let mut buffer = vec![]; + let mut view = init_view_mut::(&mut buffer, 3); + + let item0 = TestStruct::new(1, 10); + let item1 = TestStruct::new(2, 20); + view.push(item0).unwrap(); + view.push(item1).unwrap(); + + // Test get() + assert_eq!(view.first(), Some(&item0)); + assert_eq!(view.get(1), Some(&item1)); + assert_eq!(view.get(2), None); // out of bounds + assert_eq!(view.get(100), None); // way out of bounds + + // Test get_mut() to modify an item + let modified_item0 = TestStruct::new(111, 110); + let item_ref = view.get_mut(0).unwrap(); + *item_ref = modified_item0; + + // Verify the modification + assert_eq!(view.first(), Some(&modified_item0)); + assert_eq!(*view, [modified_item0, item1]); + + // Test get_mut() out of bounds + assert_eq!(view.get_mut(2), None); + } + + #[test] + fn test_mutable_access_via_indexing() { + let mut buffer = vec![]; + let mut view = init_view_mut::(&mut buffer, 3); + + let item0 = TestStruct::new(1, 10); + let item1 = TestStruct::new(2, 20); + view.push(item0).unwrap(); + view.push(item1).unwrap(); + + assert_eq!(view.len(), 2); + + // Modify via the mutable slice + view[0].a = 99; + + let expected_item0 = TestStruct::new(99, 10); + assert_eq!(view.first(), Some(&expected_item0)); + assert_eq!(*view, [expected_item0, item1]); + } + + #[test] + fn test_sort_by() { + let mut buffer = vec![]; + let mut view = init_view_mut::(&mut buffer, 5); + + let item0 = TestStruct::new(5, 1); + let item1 = TestStruct::new(2, 2); + let item2 = TestStruct::new(5, 3); + let item3 = TestStruct::new(1, 4); + let item4 = TestStruct::new(2, 5); + + view.push(item0).unwrap(); + view.push(item1).unwrap(); + view.push(item2).unwrap(); + view.push(item3).unwrap(); + view.push(item4).unwrap(); + + // Sort by `b` field in descending order. + view.sort_by(|a, b| b.b.cmp(&a.b)); + let expected_order_by_b_desc = [ + item4, // b: 5 + item3, // b: 4 + item2, // b: 3 + item1, // b: 2 + item0, // b: 1 + ]; + assert_eq!(*view, expected_order_by_b_desc); + + // Now, sort by `a` in ascending order. A stable sort preserves the relative + // order of equal elements from the previous state of the list. + view.sort_by(|x, y| x.a.cmp(&y.a)); + + let expected_order_by_a_stable = [ + item3, // a: 1 + item4, // a: 2 (was before item1 in the previous state) + item1, // a: 2 + item2, // a: 5 (was before item0 in the previous state) + item0, // a: 5 + ]; + assert_eq!(*view, expected_order_by_a_stable); + } } diff --git a/pod/src/list/list_view_read_only.rs b/pod/src/list/list_view_read_only.rs index 6e392544..6d44379a 100644 --- a/pod/src/list/list_view_read_only.rs +++ b/pod/src/list/list_view_read_only.rs @@ -3,6 +3,7 @@ use { crate::{list::list_trait::List, pod_length::PodLength, primitives::PodU32}, bytemuck::Pod, + std::ops::Deref, }; #[derive(Debug)] @@ -16,16 +17,17 @@ impl List for ListViewReadOnly<'_, T, L> { type Item = T; type Length = L; - fn len(&self) -> usize { - (*self.length).into() - } - fn capacity(&self) -> usize { self.capacity } +} - fn as_slice(&self) -> &[Self::Item] { - &self.data[..self.len()] +impl Deref for ListViewReadOnly<'_, T, L> { + type Target = [T]; + + fn deref(&self) -> &Self::Target { + let len = (*self.length).into(); + &self.data[..len] } } @@ -89,8 +91,7 @@ mod tests { let view = ListView::::unpack(&buffer).unwrap(); // `as_slice()` should only return the first `len` items. - assert_eq!(view.as_slice(), &items[..]); - assert_eq!(view.as_slice().len(), view.len()); + assert_eq!(*view, items[..]); } #[test] @@ -138,7 +139,7 @@ mod tests { assert_eq!(view.len(), 0); assert_eq!(view.capacity(), 0); assert!(view.is_empty()); - assert_eq!(view.as_slice(), &[]); + assert_eq!(*view, []); } #[test] @@ -156,7 +157,7 @@ mod tests { // Check if the public API works as expected despite internal padding assert_eq!(view.len(), 2); assert_eq!(view.capacity(), 4); - assert_eq!(view.as_slice(), &items[..]); + assert_eq!(*view, items[..]); } #[test] @@ -173,4 +174,29 @@ mod tests { assert_eq!(view.bytes_used().unwrap(), expected_used); assert_eq!(view.bytes_allocated().unwrap(), expected_cap); } + + #[test] + fn test_get() { + let items = [10u32, 20, 30]; + let buffer = build_test_buffer::(items.len(), 5, &items); + let view = ListView::::unpack(&buffer).unwrap(); + + // Get in-bounds elements + assert_eq!(view.first(), Some(&10u32)); + assert_eq!(view.get(1), Some(&20u32)); + assert_eq!(view.get(2), Some(&30u32)); + + // Get out-of-bounds element (index == len) + assert_eq!(view.get(3), None); + + // Get way out-of-bounds + assert_eq!(view.get(100), None); + } + + #[test] + fn test_get_on_empty_list() { + let buffer = build_test_buffer::(0, 5, &[]); + let view = ListView::::unpack(&buffer).unwrap(); + assert_eq!(view.first(), None); + } } diff --git a/pod/src/slice.rs b/pod/src/slice.rs index 1810265b..a5b01e77 100644 --- a/pod/src/slice.rs +++ b/pod/src/slice.rs @@ -2,7 +2,7 @@ use { crate::{ - list::{List, ListView, ListViewMut, ListViewReadOnly}, + list::{ListView, ListViewMut, ListViewReadOnly}, primitives::PodU32, }, bytemuck::Pod, diff --git a/tlv-account-resolution/src/state.rs b/tlv-account-resolution/src/state.rs index e2bd783d..c4041ba7 100644 --- a/tlv-account-resolution/src/state.rs +++ b/tlv-account-resolution/src/state.rs @@ -8,7 +8,7 @@ use { solana_pubkey::Pubkey, spl_discriminator::SplDiscriminate, spl_pod::{ - list::{self, List, ListView}, + list::{self, ListView}, primitives::PodU32, }, spl_type_length_value::state::{TlvState, TlvStateBorrowed, TlvStateMut}, @@ -219,9 +219,8 @@ impl ExtraAccountMetaList { ) -> Result<(), ProgramError> { let state = TlvStateBorrowed::unpack(data).unwrap(); let extra_meta_list = ExtraAccountMetaList::unpack_with_tlv_state::(&state)?; - let extra_account_metas = extra_meta_list.as_slice(); - let initial_accounts_len = account_infos.len() - extra_account_metas.len(); + let initial_accounts_len = account_infos.len() - extra_meta_list.len(); // Convert to `AccountMeta` to check resolved metas let provided_metas = account_infos @@ -229,7 +228,7 @@ impl ExtraAccountMetaList { .map(account_info_to_meta) .collect::>(); - for (i, config) in extra_account_metas.iter().enumerate() { + for (i, config) in extra_meta_list.iter().enumerate() { let meta = { // Create a list of `Ref`s so we can reference account data in the // resolution step @@ -286,7 +285,7 @@ impl ExtraAccountMetaList { account_key_datas.push((meta.pubkey, account_data)); } - for extra_meta in extra_account_metas.as_slice().iter() { + for extra_meta in extra_account_metas.iter() { let mut meta = extra_meta.resolve(&instruction.data, &instruction.program_id, |usize| { account_key_datas @@ -320,7 +319,7 @@ impl ExtraAccountMetaList { let bytes = state.get_first_bytes::()?; let extra_account_metas = ListView::::unpack(bytes)?; - for extra_meta in extra_account_metas.as_slice().iter() { + for extra_meta in extra_account_metas.iter() { let mut meta = { // Create a list of `Ref`s so we can reference account data in the // resolution step @@ -1451,9 +1450,8 @@ mod tests { let state = TlvStateBorrowed::unpack(buffer).unwrap(); let unpacked_metas_pod = ExtraAccountMetaList::unpack_with_tlv_state::(&state).unwrap(); - let unpacked_metas = unpacked_metas_pod.as_slice(); assert_eq!( - unpacked_metas, updated_metas, + &*unpacked_metas_pod, updated_metas, "The ExtraAccountMetas in the buffer should match the expected ones." );