Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 2 additions & 18 deletions pod/src/list/list_trait.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,36 +2,20 @@ use {
crate::{list::ListView, pod_length::PodLength},
bytemuck::Pod,
solana_program_error::ProgramError,
std::slice::Iter,
std::ops::Deref,
};

/// A trait to abstract the shared, read-only behavior
/// between `ListViewReadOnly` and `ListViewMut`.
pub trait List {
pub trait List: Deref<Target = [Self::Item]> {
/// The type of the items stored in the list.
type Item: Pod;
/// Length prefix type used (`PodU16`, `PodU32`, …).
type Length: PodLength;

/// Returns the number of items in the list.
fn len(&self) -> usize;

/// Returns `true` if the list contains no items.
fn is_empty(&self) -> bool {
self.len() == 0
}

/// Returns the total number of items that can be stored in the list.
fn capacity(&self) -> usize;

/// Returns a read-only slice of the items currently in the list.
fn as_slice(&self) -> &[Self::Item];

/// Returns a read-only iterator over the list.
fn iter(&self) -> Iter<'_, Self::Item> {
self.as_slice().iter()
}

/// Returns the number of **bytes currently occupied** by the live elements
fn bytes_used(&self) -> Result<usize, ProgramError> {
ListView::<Self::Item, Self::Length>::size_of(self.len())
Expand Down
20 changes: 10 additions & 10 deletions pod/src/list/list_view.rs
Original file line number Diff line number Diff line change
Expand Up @@ -342,12 +342,12 @@ mod tests {
let view_ro = ListView::<u32, PodU32>::unpack(&buf).unwrap();
assert_eq!(view_ro.len(), length as usize);
assert_eq!(view_ro.capacity(), capacity);
assert_eq!(view_ro.as_slice(), &items[..]);
assert_eq!(*view_ro, items[..]);

let view_mut = ListView::<u32, PodU32>::unpack_mut(&mut buf).unwrap();
assert_eq!(view_mut.len(), length as usize);
assert_eq!(view_mut.capacity(), capacity);
assert_eq!(view_mut.as_slice(), &items[..]);
assert_eq!(*view_mut, items[..]);
}

#[test]
Expand Down Expand Up @@ -375,12 +375,12 @@ mod tests {
let view_ro = ListView::<u64, PodU32>::unpack(&buf).unwrap();
assert_eq!(view_ro.len(), length as usize);
assert_eq!(view_ro.capacity(), capacity);
assert_eq!(view_ro.as_slice(), &items[..]);
assert_eq!(*view_ro, items[..]);

let view_mut = ListView::<u64, PodU32>::unpack_mut(&mut buf).unwrap();
assert_eq!(view_mut.len(), length as usize);
assert_eq!(view_mut.capacity(), capacity);
assert_eq!(view_mut.as_slice(), &items[..]);
assert_eq!(*view_mut, items[..]);
}

#[test]
Expand All @@ -398,13 +398,13 @@ mod tests {
assert_eq!(view_ro.len(), 0);
assert_eq!(view_ro.capacity(), capacity);
assert!(view_ro.is_empty());
assert_eq!(view_ro.as_slice(), &[] as &[u32]);
assert_eq!(&*view_ro, &[] as &[u32]);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Seeing this kind of syntax gives me flashbacks to doing weird things with Arcs while learning Rust, but it's not the worst thing in the world, since someone can also call .deref() by hand


let view_mut = ListView::<u32, PodU32>::unpack_mut(&mut buf).unwrap();
assert_eq!(view_mut.len(), 0);
assert_eq!(view_mut.capacity(), capacity);
assert!(view_mut.is_empty());
assert_eq!(view_mut.as_slice(), &[] as &[u32]);
assert_eq!(&*view_mut, &[] as &[u32]);
}

#[test]
Expand All @@ -427,12 +427,12 @@ mod tests {
let view_ro = ListView::<u64>::unpack(&buf).unwrap();
assert_eq!(view_ro.len(), length as usize);
assert_eq!(view_ro.capacity(), capacity);
assert_eq!(view_ro.as_slice(), &items[..]);
assert_eq!(*view_ro, items[..]);

let view_mut = ListView::<u64>::unpack_mut(&mut buf).unwrap();
assert_eq!(view_mut.len(), length as usize);
assert_eq!(view_mut.capacity(), capacity);
assert_eq!(view_mut.as_slice(), &items[..]);
assert_eq!(*view_mut, items[..]);
}

#[test]
Expand Down Expand Up @@ -619,14 +619,14 @@ mod tests {
let view_ro = ListView::<T, $LengthType>::unpack(&buf).unwrap();
assert_eq!(view_ro.len(), length_usize);
assert_eq!(view_ro.capacity(), capacity);
assert_eq!(view_ro.as_slice(), &items[..]);
assert_eq!(*view_ro, items[..]);

// Test mutable view
let mut buf_mut = buf.clone();
let view_mut = ListView::<T, $LengthType>::unpack_mut(&mut buf_mut).unwrap();
assert_eq!(view_mut.len(), length_usize);
assert_eq!(view_mut.capacity(), capacity);
assert_eq!(view_mut.as_slice(), &items[..]);
assert_eq!(*view_mut, items[..]);

// Test init
let mut init_buf = vec![0xFFu8; buf_size];
Expand Down
140 changes: 115 additions & 25 deletions pod/src/list/list_view_mut.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ use {
},
bytemuck::Pod,
solana_program_error::ProgramError,
std::ops::{Deref, DerefMut},
};

#[derive(Debug)]
Expand Down Expand Up @@ -50,29 +51,31 @@ impl<T: Pod, L: PodLength> ListViewMut<'_, T, L> {

Ok(removed_item)
}
}

impl<T: Pod, L: PodLength> Deref for ListViewMut<'_, T, L> {
type Target = [T];

fn deref(&self) -> &Self::Target {
let len = (*self.length).into();
&self.data[..len]
}
}

/// Returns a mutable iterator over the current elements
pub fn iter_mut(&mut self) -> std::slice::IterMut<T> {
impl<T: Pod, L: PodLength> DerefMut for ListViewMut<'_, T, L> {
fn deref_mut(&mut self) -> &mut Self::Target {
let len = (*self.length).into();
self.data[..len].iter_mut()
&mut self.data[..len]
}
}

impl<T: Pod, L: PodLength> List for ListViewMut<'_, T, L> {
type Item = T;
type Length = L;

fn len(&self) -> usize {
(*self.length).into()
}

fn capacity(&self) -> usize {
self.capacity
}

fn as_slice(&self) -> &[Self::Item] {
&self.data[..self.len()]
}
}

#[cfg(test)]
Expand All @@ -87,7 +90,7 @@ mod tests {
};

#[repr(C)]
#[derive(Clone, Copy, Debug, Default, PartialEq, Pod, Zeroable)]
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, PartialOrd, Ord, Pod, Zeroable)]
struct TestStruct {
a: u64,
b: u32,
Expand Down Expand Up @@ -127,19 +130,19 @@ mod tests {
view.push(item1).unwrap();
assert_eq!(view.len(), 1);
assert!(!view.is_empty());
assert_eq!(view.as_slice(), &[item1]);
assert_eq!(*view, [item1]);

// Push second item
let item2 = TestStruct::new(2, 20);
view.push(item2).unwrap();
assert_eq!(view.len(), 2);
assert_eq!(view.as_slice(), &[item1, item2]);
assert_eq!(*view, [item1, item2]);

// Push third item to fill capacity
let item3 = TestStruct::new(3, 30);
view.push(item3).unwrap();
assert_eq!(view.len(), 3);
assert_eq!(view.as_slice(), &[item1, item2, item3]);
assert_eq!(*view, [item1, item2, item3]);

// Try to push beyond capacity
let item4 = TestStruct::new(4, 40);
Expand All @@ -148,7 +151,7 @@ mod tests {

// Ensure state is unchanged
assert_eq!(view.len(), 3);
assert_eq!(view.as_slice(), &[item1, item2, item3]);
assert_eq!(*view, [item1, item2, item3]);
}

#[test]
Expand All @@ -166,32 +169,32 @@ mod tests {
view.push(item4).unwrap();

assert_eq!(view.len(), 4);
assert_eq!(view.as_slice(), &[item1, item2, item3, item4]);
assert_eq!(*view, [item1, item2, item3, item4]);

// Remove from the middle
let removed = view.remove(1).unwrap();
assert_eq!(removed, item2);
assert_eq!(view.len(), 3);
assert_eq!(view.as_slice(), &[item1, item3, item4]);
assert_eq!(*view, [item1, item3, item4]);

// Remove from the end
let removed = view.remove(2).unwrap();
assert_eq!(removed, item4);
assert_eq!(view.len(), 2);
assert_eq!(view.as_slice(), &[item1, item3]);
assert_eq!(*view, [item1, item3]);

// Remove from the start
let removed = view.remove(0).unwrap();
assert_eq!(removed, item1);
assert_eq!(view.len(), 1);
assert_eq!(view.as_slice(), &[item3]);
assert_eq!(*view, [item3]);

// Remove the last element
let removed = view.remove(0).unwrap();
assert_eq!(removed, item3);
assert_eq!(view.len(), 0);
assert!(view.is_empty());
assert_eq!(view.as_slice(), &[]);
assert_eq!(*view, []);
}

#[test]
Expand Down Expand Up @@ -248,10 +251,7 @@ mod tests {

// Check that the underlying data is modified
assert_eq!(view.len(), 3);
assert_eq!(
view.as_slice(),
&[expected_item1, expected_item2, expected_item3]
);
assert_eq!(*view, [expected_item1, expected_item2, expected_item3]);

// Check that iter_mut only iterates over `len` items, not `capacity`
assert_eq!(view.iter_mut().count(), 3);
Expand Down Expand Up @@ -333,4 +333,94 @@ mod tests {
ListView::<TestStruct, PodU32>::size_of(view.capacity()).unwrap()
);
}
#[test]
fn test_get_and_get_mut() {
let mut buffer = vec![];
let mut view = init_view_mut::<TestStruct, PodU32>(&mut buffer, 3);

let item0 = TestStruct::new(1, 10);
let item1 = TestStruct::new(2, 20);
view.push(item0).unwrap();
view.push(item1).unwrap();

// Test get()
assert_eq!(view.first(), Some(&item0));
assert_eq!(view.get(1), Some(&item1));
assert_eq!(view.get(2), None); // out of bounds
assert_eq!(view.get(100), None); // way out of bounds

// Test get_mut() to modify an item
let modified_item0 = TestStruct::new(111, 110);
let item_ref = view.get_mut(0).unwrap();
*item_ref = modified_item0;

// Verify the modification
assert_eq!(view.first(), Some(&modified_item0));
assert_eq!(*view, [modified_item0, item1]);

// Test get_mut() out of bounds
assert_eq!(view.get_mut(2), None);
}

#[test]
fn test_mutable_access_via_indexing() {
let mut buffer = vec![];
let mut view = init_view_mut::<TestStruct, PodU32>(&mut buffer, 3);

let item0 = TestStruct::new(1, 10);
let item1 = TestStruct::new(2, 20);
view.push(item0).unwrap();
view.push(item1).unwrap();

assert_eq!(view.len(), 2);

// Modify via the mutable slice
view[0].a = 99;

let expected_item0 = TestStruct::new(99, 10);
assert_eq!(view.first(), Some(&expected_item0));
assert_eq!(*view, [expected_item0, item1]);
}

#[test]
fn test_sort_by() {
let mut buffer = vec![];
let mut view = init_view_mut::<TestStruct, PodU32>(&mut buffer, 5);

let item0 = TestStruct::new(5, 1);
let item1 = TestStruct::new(2, 2);
let item2 = TestStruct::new(5, 3);
let item3 = TestStruct::new(1, 4);
let item4 = TestStruct::new(2, 5);

view.push(item0).unwrap();
view.push(item1).unwrap();
view.push(item2).unwrap();
view.push(item3).unwrap();
view.push(item4).unwrap();

// Sort by `b` field in descending order.
view.sort_by(|a, b| b.b.cmp(&a.b));
let expected_order_by_b_desc = [
item4, // b: 5
item3, // b: 4
item2, // b: 3
item1, // b: 2
item0, // b: 1
];
assert_eq!(*view, expected_order_by_b_desc);

// Now, sort by `a` in ascending order. A stable sort preserves the relative
// order of equal elements from the previous state of the list.
view.sort_by(|x, y| x.a.cmp(&y.a));

let expected_order_by_a_stable = [
item3, // a: 1
item4, // a: 2 (was before item1 in the previous state)
item1, // a: 2
item2, // a: 5 (was before item0 in the previous state)
item0, // a: 5
];
assert_eq!(*view, expected_order_by_a_stable);
}
}
Loading