diff --git a/rust/lance-core/src/utils/mask.rs b/rust/lance-core/src/utils/mask.rs index be4934ac55d..a36f579f08b 100644 --- a/rust/lance-core/src/utils/mask.rs +++ b/rust/lance-core/src/utils/mask.rs @@ -962,10 +962,185 @@ pub fn bitmap_to_ranges(bitmap: &RoaringBitmap) -> Vec> { ranges } +/// A set of stable row ids backed by a 64-bit Roaring bitmap. +/// +/// This is a thin wrapper around [`RoaringTreemap`]. It represents a +/// collection of unique row ids and provides the common row-set +/// operations defined by [`RowSetOps`]. +#[derive(Clone, Debug, Default, PartialEq)] +pub struct RowIdSet { + inner: RoaringTreemap, +} + +impl RowIdSet { + /// Creates an empty set of row ids. + pub fn new() -> Self { + Self::default() + } + /// Returns an iterator over the contained row ids in ascending order. + pub fn iter(&self) -> impl Iterator + '_ { + self.inner.iter() + } + /// Returns the union of `self` and `other`. + pub fn union(mut self, other: &Self) -> Self { + self.inner |= &other.inner; + self + } + /// Returns the set difference `self \\ other`. + pub fn difference(mut self, other: &Self) -> Self { + self.inner -= &other.inner; + self + } +} + +impl RowSetOps for RowIdSet { + type Row = u64; + fn is_empty(&self) -> bool { + self.inner.is_empty() + } + fn len(&self) -> Option { + Some(self.inner.len()) + } + fn remove(&mut self, row: Self::Row) -> bool { + self.inner.remove(row) + } + fn contains(&self, row: Self::Row) -> bool { + self.inner.contains(row) + } + fn union_all(other: &[&Self]) -> Self { + let mut result = other + .first() + .map_or(Self::default(), |&first| first.clone()); + for set in other { + result.inner |= &set.inner; + } + result + } + #[track_caller] + fn from_sorted_iter(iter: I) -> Result + where + I: IntoIterator, + { + let mut inner = RoaringTreemap::new(); + let mut last: Option = None; + for value in iter { + if let Some(prev) = last { + if value < prev { + return Err(Error::Internal { + message: "RowIdSet::from_sorted_iter called with non-sorted input" + .to_string(), + // Use the caller location since we aren't the one that got it out of order + location: std::panic::Location::caller().to_snafu_location(), + }); + } + } + inner.insert(value); + last = Some(value); + } + Ok(Self { inner }) + } +} + +/// A mask over stable row ids based on an allow-list or block-list. +/// +/// The semantics mirror [`RowAddrMask`], but operate on stable +/// row ids instead of physical row addresses. +#[derive(Clone, Debug, PartialEq)] +pub enum RowIdMask { + /// Only the ids in the set are selected. + AllowList(RowIdSet), + /// All ids are selected except those in the set. + BlockList(RowIdSet), +} + +impl Default for RowIdMask { + fn default() -> Self { + // Empty block list means all rows are allowed + Self::BlockList(RowIdSet::default()) + } +} +impl RowIdMask { + /// Create a mask allowing all rows, this is an alias for [`Default`]. + pub fn all_rows() -> Self { + Self::default() + } + /// Create a mask that doesn't allow any row id. + pub fn allow_nothing() -> Self { + Self::AllowList(RowIdSet::default()) + } + /// Create a mask from an allow list. + pub fn from_allowed(allow_list: RowIdSet) -> Self { + Self::AllowList(allow_list) + } + /// Create a mask from a block list. + pub fn from_block(block_list: RowIdSet) -> Self { + Self::BlockList(block_list) + } + /// True if the row id is selected by the mask, false otherwise. + pub fn selected(&self, row_id: u64) -> bool { + match self { + Self::AllowList(allow_list) => allow_list.contains(row_id), + Self::BlockList(block_list) => !block_list.contains(row_id), + } + } + /// Return the indices of the input row ids that are selected by the mask. + pub fn selected_indices<'a>(&self, row_ids: impl Iterator + 'a) -> Vec { + row_ids + .enumerate() + .filter_map(|(idx, row_id)| { + if self.selected(*row_id) { + Some(idx as u64) + } else { + None + } + }) + .collect() + } + /// Also block the given ids. + /// + /// * `AllowList(a)` -> `AllowList(a \\ block_list)` + /// * `BlockList(b)` -> `BlockList(b union block_list)` + pub fn also_block(self, block_list: RowIdSet) -> Self { + match self { + Self::AllowList(allow_list) => Self::AllowList(allow_list.difference(&block_list)), + Self::BlockList(existing) => Self::BlockList(existing.union(&block_list)), + } + } + /// Also allow the given ids. + /// + /// * `AllowList(a)` -> `AllowList(a union allow_list)` + /// * `BlockList(b)` -> `BlockList(b \\ allow_list)` + pub fn also_allow(self, allow_list: RowIdSet) -> Self { + match self { + Self::AllowList(existing) => Self::AllowList(existing.union(&allow_list)), + Self::BlockList(block_list) => Self::BlockList(block_list.difference(&allow_list)), + } + } + /// Return the maximum number of row ids that could be selected by this mask. + /// + /// Will be `None` if this is a `BlockList` (unbounded). + pub fn max_len(&self) -> Option { + match self { + Self::AllowList(selection) => selection.len(), + Self::BlockList(_) => None, + } + } + /// Iterate over the row ids that are selected by the mask. + /// + /// This is only possible if this is an `AllowList`. For a `BlockList` + /// the domain of possible row ids is unbounded. + pub fn iter_ids(&self) -> Option + '_>> { + match self { + Self::AllowList(allow_list) => Some(Box::new(allow_list.iter())), + Self::BlockList(_) => None, + } + } +} + #[cfg(test)] mod tests { use super::*; - use proptest::prop_assert_eq; + use proptest::{prop_assert, prop_assert_eq}; fn rows(ids: &[u64]) -> RowAddrTreeMap { RowAddrTreeMap::from_iter(ids) @@ -1860,4 +2035,358 @@ mod tests { assert!(map.contains(id)); } } + + // ============================================================================ + // Tests for RowIdSet + // ============================================================================ + + fn row_ids(ids: &[u64]) -> RowIdSet { + let mut set = RowIdSet::new(); + for &id in ids { + set.inner.insert(id); + } + set + } + + #[test] + fn test_row_id_set_construction() { + let set = RowIdSet::new(); + assert!(set.is_empty()); + assert_eq!(set.len(), Some(0)); + + let set = row_ids(&[10, 20, 30]); + assert!(!set.is_empty()); + assert_eq!(set.len(), Some(3)); + assert!(set.contains(10)); + assert!(set.contains(20)); + assert!(set.contains(30)); + assert!(!set.contains(15)); + } + + #[test] + fn test_row_id_set_remove() { + let mut set = row_ids(&[10, 20, 30]); + + assert!(!set.remove(15)); // Not present + assert_eq!(set.len(), Some(3)); + + assert!(set.remove(20)); // Present + assert_eq!(set.len(), Some(2)); + assert!(!set.contains(20)); + assert!(set.contains(10)); + assert!(set.contains(30)); + + assert!(!set.remove(20)); // Already removed + } + + #[test] + fn test_row_id_set_union() { + let set1 = row_ids(&[10, 20, 30]); + let set2 = row_ids(&[20, 30, 40]); + + let result = set1.union(&set2); + assert_eq!(result.len(), Some(4)); + for id in [10, 20, 30, 40] { + assert!(result.contains(id)); + } + } + + #[test] + fn test_row_id_set_difference() { + let set1 = row_ids(&[10, 20, 30, 40]); + let set2 = row_ids(&[20, 40]); + + let result = set1.difference(&set2); + assert_eq!(result.len(), Some(2)); + assert!(result.contains(10)); + assert!(result.contains(30)); + assert!(!result.contains(20)); + assert!(!result.contains(40)); + } + + #[test] + fn test_row_id_set_union_all() { + let set1 = row_ids(&[10, 20]); + let set2 = row_ids(&[20, 30]); + let set3 = row_ids(&[30, 40]); + + let result = RowIdSet::union_all(&[&set1, &set2, &set3]); + assert_eq!(result.len(), Some(4)); + for id in [10, 20, 30, 40] { + assert!(result.contains(id)); + } + + // Empty slice should return empty set + let result = RowIdSet::union_all(&[]); + assert!(result.is_empty()); + } + + #[test] + fn test_row_id_set_iter() { + let set = row_ids(&[10, 20, 30]); + let collected: Vec = set.iter().collect(); + assert_eq!(collected, vec![10, 20, 30]); + + let empty = RowIdSet::new(); + assert_eq!(empty.iter().count(), 0); + } + + #[test] + fn test_row_id_set_from_sorted_iter() { + // Valid sorted input + let set = RowIdSet::from_sorted_iter([10, 20, 30, 40]).unwrap(); + assert_eq!(set.len(), Some(4)); + for id in [10, 20, 30, 40] { + assert!(set.contains(id)); + } + + // Empty iterator + let set = RowIdSet::from_sorted_iter(std::iter::empty()).unwrap(); + assert!(set.is_empty()); + + // Single element + let set = RowIdSet::from_sorted_iter([42]).unwrap(); + assert_eq!(set.len(), Some(1)); + assert!(set.contains(42)); + } + + #[test] + fn test_row_id_set_from_sorted_iter_unsorted() { + // Non-sorted input should return error + let result = RowIdSet::from_sorted_iter([30, 10, 20]); + assert!(result.is_err()); + assert!(result.unwrap_err().to_string().contains("non-sorted")); + } + + #[test] + fn test_row_id_set_large_values() { + // Test with large u64 values + let large_ids = [u64::MAX - 10, u64::MAX - 5, u64::MAX - 1]; + let set = row_ids(&large_ids); + + for &id in &large_ids { + assert!(set.contains(id)); + } + assert!(!set.contains(u64::MAX)); + assert_eq!(set.len(), Some(3)); + } + + // ============================================================================ + // Tests for RowIdMask + // ============================================================================ + + fn assert_row_id_mask_selects(mask: &RowIdMask, selected: &[u64], not_selected: &[u64]) { + for &id in selected { + assert!(mask.selected(id), "Expected row id {} to be selected", id); + } + for &id in not_selected { + assert!( + !mask.selected(id), + "Expected row id {} to NOT be selected", + id + ); + } + } + + #[test] + fn test_row_id_mask_construction() { + let full_mask = RowIdMask::all_rows(); + assert_eq!(full_mask.max_len(), None); + assert_row_id_mask_selects(&full_mask, &[0, 1, 100, u64::MAX - 1], &[]); + + let empty_mask = RowIdMask::allow_nothing(); + assert_eq!(empty_mask.max_len(), Some(0)); + assert_row_id_mask_selects(&empty_mask, &[], &[0, 1, 100]); + + let allow_list = RowIdMask::from_allowed(row_ids(&[10, 20, 30])); + assert_eq!(allow_list.max_len(), Some(3)); + assert_row_id_mask_selects(&allow_list, &[10, 20, 30], &[0, 15, 25, 40]); + + let block_list = RowIdMask::from_block(row_ids(&[10, 20, 30])); + assert_eq!(block_list.max_len(), None); + assert_row_id_mask_selects(&block_list, &[0, 15, 25, 40], &[10, 20, 30]); + } + + #[test] + fn test_row_id_mask_selected_indices() { + // Allow list + let mask = RowIdMask::from_allowed(row_ids(&[10, 20, 40])); + assert!(mask.selected_indices(std::iter::empty()).is_empty()); + assert_eq!(mask.selected_indices([25, 20, 14, 10].iter()), &[1, 3]); + + // Block list + let mask = RowIdMask::from_block(row_ids(&[10, 20, 40])); + assert!(mask.selected_indices(std::iter::empty()).is_empty()); + assert_eq!(mask.selected_indices([25, 20, 14, 10].iter()), &[0, 2]); + } + + #[test] + fn test_row_id_mask_also_allow() { + // Allow list + let mask = RowIdMask::from_allowed(row_ids(&[10, 20])); + let new_mask = mask.also_allow(row_ids(&[20, 30, 40])); + assert_eq!( + new_mask, + RowIdMask::from_allowed(row_ids(&[10, 20, 30, 40])) + ); + + // Block list + let mask = RowIdMask::from_block(row_ids(&[10, 20, 30])); + let new_mask = mask.also_allow(row_ids(&[20, 40])); + assert_eq!(new_mask, RowIdMask::from_block(row_ids(&[10, 30]))); + } + + #[test] + fn test_row_id_mask_also_block() { + // Allow list + let mask = RowIdMask::from_allowed(row_ids(&[10, 20, 30])); + let new_mask = mask.also_block(row_ids(&[20, 40])); + assert_eq!(new_mask, RowIdMask::from_allowed(row_ids(&[10, 30]))); + + // Block list + let mask = RowIdMask::from_block(row_ids(&[10, 20])); + let new_mask = mask.also_block(row_ids(&[20, 30, 40])); + assert_eq!(new_mask, RowIdMask::from_block(row_ids(&[10, 20, 30, 40]))); + } + + #[test] + fn test_row_id_mask_iter_ids() { + // Allow list + let mask = RowIdMask::from_allowed(row_ids(&[10, 20, 30])); + let ids: Vec = mask.iter_ids().unwrap().collect(); + assert_eq!(ids, vec![10, 20, 30]); + + // Empty allow list + let mask = RowIdMask::allow_nothing(); + let iter = mask.iter_ids(); + assert!(iter.is_some()); + assert_eq!(iter.unwrap().count(), 0); + + // Block list + let mask = RowIdMask::from_block(row_ids(&[10, 20, 30])); + assert!(mask.iter_ids().is_none()); + } + + #[test] + fn test_row_id_mask_default() { + let mask = RowIdMask::default(); + // Default should be BlockList with empty set (all rows allowed) + assert_row_id_mask_selects(&mask, &[0, 1, 100, 1000], &[]); + assert_eq!(mask.max_len(), None); + } + + #[test] + fn test_row_id_mask_ops() { + let mask = RowIdMask::default(); + assert_row_id_mask_selects(&mask, &[1, 5, 100], &[]); + + let block_list = mask.also_block(row_ids(&[0, 5, 15])); + assert_row_id_mask_selects(&block_list, &[1, 100], &[5]); + + let allow_list = RowIdMask::from_allowed(row_ids(&[0, 2, 5])); + assert_row_id_mask_selects(&allow_list, &[5], &[1, 100]); + } + + #[test] + fn test_row_id_mask_combined_ops() { + // Test combining allow and block operations + let mask = RowIdMask::from_allowed(row_ids(&[10, 20, 30, 40, 50])); + let mask = mask.also_block(row_ids(&[20, 40])); + assert_row_id_mask_selects(&mask, &[10, 30, 50], &[20, 40]); + + let mask = mask.also_allow(row_ids(&[20, 60])); + assert_row_id_mask_selects(&mask, &[10, 20, 30, 50, 60], &[40]); + } + + #[test] + fn test_row_id_mask_with_large_values() { + let large_ids = [u64::MAX - 10, u64::MAX - 5, u64::MAX - 1]; + + // Allow list with large values + let mask = RowIdMask::from_allowed(row_ids(&large_ids)); + for &id in &large_ids { + assert!(mask.selected(id)); + } + assert!(!mask.selected(u64::MAX)); + assert!(!mask.selected(0)); + + // Block list with large values + let mask = RowIdMask::from_block(row_ids(&large_ids)); + for &id in &large_ids { + assert!(!mask.selected(id)); + } + assert!(mask.selected(u64::MAX)); + assert!(mask.selected(0)); + } + + proptest::proptest! { + #[test] + fn test_row_id_set_from_sorted_iter_proptest( + mut row_ids in proptest::collection::vec(0..u64::MAX, 0..1000) + ) { + row_ids.sort(); + row_ids.dedup(); + let num_rows = row_ids.len(); + let set = RowIdSet::from_sorted_iter(row_ids.clone()).unwrap(); + prop_assert_eq!(set.len(), Some(num_rows as u64)); + for id in row_ids { + prop_assert!(set.contains(id)); + } + } + + #[test] + fn test_row_id_set_union_proptest( + ids1 in proptest::collection::vec(0..u64::MAX, 0..500), + ids2 in proptest::collection::vec(0..u64::MAX, 0..500), + ) { + let set1 = row_ids(&ids1); + let set2 = row_ids(&ids2); + + let result = set1.union(&set2); + + // All ids from both sets should be in result + for id in ids1.iter().chain(ids2.iter()) { + prop_assert!(result.contains(*id)); + } + + // Result size should be union size + let expected_size = ids1.iter().chain(ids2.iter()).collect::>().len(); + prop_assert_eq!(result.len(), Some(expected_size as u64)); + } + + #[test] + fn test_row_id_set_difference_proptest( + ids1 in proptest::collection::vec(0..u64::MAX, 0..500), + ids2 in proptest::collection::vec(0..u64::MAX, 0..500), + ) { + let set1 = row_ids(&ids1); + let set2 = row_ids(&ids2); + + let result = set1.difference(&set2); + + // Items in ids1 but not in ids2 should be in result + for id in &ids1 { + if !ids2.contains(id) { + prop_assert!(result.contains(*id)); + } else { + prop_assert!(!result.contains(*id)); + } + } + } + + #[test] + fn test_row_id_mask_allow_block_proptest( + allow_ids in proptest::collection::vec(0..10000u64, 0..100), + block_ids in proptest::collection::vec(0..10000u64, 0..100), + test_ids in proptest::collection::vec(0..10000u64, 0..50), + ) { + let mask = RowIdMask::from_allowed(row_ids(&allow_ids)) + .also_block(row_ids(&block_ids)); + + for id in test_ids { + let expected = allow_ids.contains(&id) && !block_ids.contains(&id); + prop_assert_eq!(mask.selected(id), expected); + } + } + } }