From 2896cb84cccde7ed340ff1ecda545b026606ce99 Mon Sep 17 00:00:00 2001 From: Weston Pace Date: Tue, 9 Dec 2025 07:05:17 -0800 Subject: [PATCH 1/9] Various btree performance improvements --- Cargo.lock | 1 + rust/lance-arrow/Cargo.toml | 1 + rust/lance-arrow/src/lib.rs | 17 +- rust/lance-core/src/error.rs | 2 +- rust/lance-core/src/utils/mask.rs | 28 +- rust/lance-index/benches/btree.rs | 62 +-- rust/lance-index/src/scalar.rs | 1 - rust/lance-index/src/scalar/btree.rs | 288 ++++++++------ .../src/scalar/{ => btree}/flat.rs | 356 ++++++------------ rust/lance-index/src/scalar/lance_format.rs | 14 +- 10 files changed, 370 insertions(+), 400 deletions(-) rename rust/lance-index/src/scalar/{ => btree}/flat.rs (56%) diff --git a/Cargo.lock b/Cargo.lock index 95923c70bd7..3b496556a4f 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4495,6 +4495,7 @@ dependencies = [ "arrow-buffer", "arrow-cast", "arrow-data", + "arrow-ord", "arrow-schema", "arrow-select", "bytes", diff --git a/rust/lance-arrow/Cargo.toml b/rust/lance-arrow/Cargo.toml index 1de7b234956..0caedbc47d2 100644 --- a/rust/lance-arrow/Cargo.toml +++ b/rust/lance-arrow/Cargo.toml @@ -18,6 +18,7 @@ arrow-array = { workspace = true } arrow-buffer = { workspace = true } arrow-data = { workspace = true } arrow-cast = { workspace = true } +arrow-ord = { workspace = true } arrow-schema = { workspace = true } arrow-select = { workspace = true } bytes = { workspace = true } diff --git a/rust/lance-arrow/src/lib.rs b/rust/lance-arrow/src/lib.rs index f3d3e4af90b..97738938ef2 100644 --- a/rust/lance-arrow/src/lib.rs +++ b/rust/lance-arrow/src/lib.rs @@ -18,7 +18,7 @@ use arrow_array::{ }; use arrow_buffer::MutableBuffer; use arrow_data::ArrayDataBuilder; -use arrow_schema::{ArrowError, DataType, Field, Fields, IntervalUnit, Schema}; +use arrow_schema::{ArrowError, DataType, Field, Fields, IntervalUnit, Schema, SortOptions}; use arrow_select::{interleave::interleave, take::take}; use rand::prelude::*; @@ -604,6 +604,9 @@ pub trait RecordBatchExt { /// Create a new RecordBatch with compacted memory after slicing. fn shrink_to_fit(&self) -> Result; + + /// Helper method to sort the RecordBatch by a column + fn sort_by_column(&self, column: usize, options: Option) -> Result; } impl RecordBatchExt for RecordBatch { @@ -778,6 +781,18 @@ impl RecordBatchExt for RecordBatch { // Deep copy the sliced record batch, instead of whole batch crate::deepcopy::deep_copy_batch_sliced(self) } + + fn sort_by_column(&self, column: usize, options: Option) -> Result { + if column >= self.num_columns() { + return Err(ArrowError::InvalidArgumentError(format!( + "Column index out of bounds: {}", + column + ))); + } + let column = self.column(column); + let sorted = arrow_ord::sort::sort_to_indices(column, options, None)?; + self.take(&sorted) + } } fn project(struct_array: &StructArray, fields: &Fields) -> Result { diff --git a/rust/lance-core/src/error.rs b/rust/lance-core/src/error.rs index 48150db4354..f80dbca4a7b 100644 --- a/rust/lance-core/src/error.rs +++ b/rust/lance-core/src/error.rs @@ -184,7 +184,7 @@ impl LanceOptionExt for Option { } } -trait ToSnafuLocation { +pub trait ToSnafuLocation { fn to_snafu_location(&'static self) -> snafu::Location; } diff --git a/rust/lance-core/src/utils/mask.rs b/rust/lance-core/src/utils/mask.rs index c0d5347026e..901701fdbeb 100644 --- a/rust/lance-core/src/utils/mask.rs +++ b/rust/lance-core/src/utils/mask.rs @@ -12,7 +12,8 @@ use byteorder::{ReadBytesExt, WriteBytesExt}; use deepsize::DeepSizeOf; use roaring::{MultiOps, RoaringBitmap, RoaringTreemap}; -use crate::Result; +use crate::error::ToSnafuLocation; +use crate::{Error, Result}; use super::address::RowAddress; @@ -595,6 +596,31 @@ impl RowAddrTreeMap { }), }) } + + #[track_caller] + pub fn from_sorted_iter(iter: impl IntoIterator) -> Result { + let mut iter = iter.into_iter().peekable(); + let mut inner = BTreeMap::new(); + + while let Some(row_id) = iter.peek() { + let fragment_id = (row_id >> 32) as u32; + let next_bitmap_iter = iter + .by_ref() + .take_while(|row_id| (row_id >> 32) as u32 == fragment_id) + .map(|row_id| row_id as u32); + let Ok(bitmap) = RoaringBitmap::from_sorted_iter(next_bitmap_iter) else { + return Err(Error::Internal { + message: "RowAddrTreeMap::from_sorted_iter called with non-sorted input" + .to_string(), + // Use the caller location since we aren't the one that got it out of order + location: std::panic::Location::caller().to_snafu_location(), + }); + }; + inner.insert(fragment_id, RowAddrSelection::Partial(bitmap)); + } + + Ok(Self { inner }) + } } impl std::ops::BitOr for RowAddrTreeMap { diff --git a/rust/lance-index/benches/btree.rs b/rust/lance-index/benches/btree.rs index bac06ecd449..f5bed193353 100644 --- a/rust/lance-index/benches/btree.rs +++ b/rust/lance-index/benches/btree.rs @@ -19,7 +19,6 @@ use std::{ time::Duration, }; -use arrow_schema::DataType; use common::{LOW_CARDINALITY_COUNT, TOTAL_ROWS}; use criterion::{black_box, criterion_group, criterion_main, BenchmarkId, Criterion}; use datafusion_common::ScalarValue; @@ -27,7 +26,6 @@ use lance_core::cache::LanceCache; use lance_index::metrics::NoOpMetricsCollector; use lance_index::pbold; use lance_index::scalar::btree::{train_btree_index, BTreeIndexPlugin, DEFAULT_BTREE_BATCH_SIZE}; -use lance_index::scalar::flat::FlatIndexMetadata; use lance_index::scalar::lance_format::LanceIndexStore; use lance_index::scalar::registry::ScalarIndexPlugin; use lance_index::scalar::{SargableQuery, ScalarIndex}; @@ -107,17 +105,10 @@ async fn create_int_unique_index( use_cache: bool, ) -> Arc { let stream = common::generate_int_unique_stream(); - let sub_index = FlatIndexMetadata::new(DataType::Int64); - - train_btree_index( - stream, - &sub_index, - store.as_ref(), - DEFAULT_BTREE_BATCH_SIZE, - None, - ) - .await - .unwrap(); + + train_btree_index(stream, store.as_ref(), DEFAULT_BTREE_BATCH_SIZE, None) + .await + .unwrap(); let cache = get_cache(use_cache, "int_unique"); let details = prost_types::Any::from_msg(&pbold::BTreeIndexDetails::default()).unwrap(); @@ -135,17 +126,10 @@ async fn create_int_low_card_index( use_cache: bool, ) -> Arc { let stream = common::generate_int_low_cardinality_stream(); - let sub_index = FlatIndexMetadata::new(DataType::Int64); - - train_btree_index( - stream, - &sub_index, - store.as_ref(), - DEFAULT_BTREE_BATCH_SIZE, - None, - ) - .await - .unwrap(); + + train_btree_index(stream, store.as_ref(), DEFAULT_BTREE_BATCH_SIZE, None) + .await + .unwrap(); let cache = get_cache(use_cache, "int_low_card"); let details = prost_types::Any::from_msg(&pbold::BTreeIndexDetails::default()).unwrap(); @@ -163,17 +147,10 @@ async fn create_string_unique_index( use_cache: bool, ) -> Arc { let stream = common::generate_string_unique_stream(); - let sub_index = FlatIndexMetadata::new(DataType::Utf8); - - train_btree_index( - stream, - &sub_index, - store.as_ref(), - DEFAULT_BTREE_BATCH_SIZE, - None, - ) - .await - .unwrap(); + + train_btree_index(stream, store.as_ref(), DEFAULT_BTREE_BATCH_SIZE, None) + .await + .unwrap(); let cache = get_cache(use_cache, "string_unique"); let details = prost_types::Any::from_msg(&pbold::BTreeIndexDetails::default()).unwrap(); @@ -191,17 +168,10 @@ async fn create_string_low_card_index( use_cache: bool, ) -> Arc { let stream = common::generate_string_low_cardinality_stream(); - let sub_index = FlatIndexMetadata::new(DataType::Utf8); - - train_btree_index( - stream, - &sub_index, - store.as_ref(), - DEFAULT_BTREE_BATCH_SIZE, - None, - ) - .await - .unwrap(); + + train_btree_index(stream, store.as_ref(), DEFAULT_BTREE_BATCH_SIZE, None) + .await + .unwrap(); let cache = get_cache(use_cache, "string_low_card"); let details = prost_types::Any::from_msg(&pbold::BTreeIndexDetails::default()).unwrap(); diff --git a/rust/lance-index/src/scalar.rs b/rust/lance-index/src/scalar.rs index 98ce994890c..c97efd84d54 100644 --- a/rust/lance-index/src/scalar.rs +++ b/rust/lance-index/src/scalar.rs @@ -32,7 +32,6 @@ pub mod bitmap; pub mod bloomfilter; pub mod btree; pub mod expression; -pub mod flat; pub mod inverted; pub mod json; pub mod label_list; diff --git a/rust/lance-index/src/scalar/btree.rs b/rust/lance-index/src/scalar/btree.rs index 6a04dc88d40..e22fd66c733 100644 --- a/rust/lance-index/src/scalar/btree.rs +++ b/rust/lance-index/src/scalar/btree.rs @@ -11,10 +11,9 @@ use std::{ }; use super::{ - flat::FlatIndexMetadata, AnyQuery, BuiltinIndexType, IndexReader, IndexStore, IndexWriter, - MetricsCollector, SargableQuery, ScalarIndex, ScalarIndexParams, SearchResult, + AnyQuery, BuiltinIndexType, IndexReader, IndexStore, IndexWriter, MetricsCollector, + SargableQuery, ScalarIndex, ScalarIndexParams, SearchResult, }; -use crate::pbold; use crate::{ frag_reuse::FragReuseIndex, scalar::{ @@ -24,6 +23,7 @@ use crate::{ }, }; use crate::{metrics::NoOpMetricsCollector, scalar::registry::TrainingCriteria}; +use crate::{pbold, scalar::btree::flat::FlatIndex}; use crate::{Index, IndexType}; use arrow_array::{new_empty_array, Array, RecordBatch, UInt32Array}; use arrow_schema::{DataType, Field, Schema, SortOptions}; @@ -62,6 +62,8 @@ use serde::{Deserialize, Serialize, Serializer}; use snafu::location; use tracing::info; +mod flat; + const BTREE_LOOKUP_NAME: &str = "page_lookup.lance"; const BTREE_PAGES_NAME: &str = "page_data.lance"; pub const DEFAULT_BTREE_BATCH_SIZE: u64 = 4096; @@ -574,13 +576,27 @@ pub struct BTreeLookup { null_pages: Vec, } +enum Matches { + Some(u32), + All(u32), +} + +impl Matches { + fn page_id(&self) -> u32 { + match self { + Matches::Some(page_id) => *page_id, + Matches::All(page_id) => *page_id, + } + } +} + impl BTreeLookup { fn new(tree: BTreeMap>, null_pages: Vec) -> Self { Self { tree, null_pages } } // All pages that could have a value equal to val - fn pages_eq(&self, query: &OrderableScalarValue) -> Vec { + fn pages_eq(&self, query: &OrderableScalarValue) -> Vec { if query.0.is_null() { self.pages_null() } else { @@ -589,10 +605,16 @@ impl BTreeLookup { } // All pages that could have a value equal to one of the values - fn pages_in(&self, values: impl IntoIterator) -> Vec { + fn pages_in(&self, values: impl IntoIterator) -> Vec { + // TODO: Right now we convert all Matches::All into Matches::Some. We could refine this. + // It would improve performance on low cardinality data. let page_lists = values .into_iter() - .map(|val| self.pages_eq(&val)) + .map(|val| { + self.pages_eq(&val) + .into_iter() + .map(|matches| matches.page_id()) + }) .collect::>(); let total_size = page_lists.iter().map(|set| set.len()).sum(); let mut heap = BinaryHeap::with_capacity(total_size); @@ -602,13 +624,16 @@ impl BTreeLookup { let mut all_pages = heap.into_sorted_vec(); all_pages.dedup(); all_pages + .into_iter() + .map(|page_id| Matches::Some(page_id)) + .collect() } // All pages that could have a value in the range fn pages_between( &self, range: (Bound<&OrderableScalarValue>, Bound<&OrderableScalarValue>), - ) -> Vec { + ) -> Vec { // We need to grab a little bit left of the given range because the query might be 7 // and the first page might be something like 5-10. let lower_bound = match range.0 { @@ -662,25 +687,80 @@ impl BTreeLookup { _ => {} } - let candidates = self - .tree - .range((lower_bound, upper_bound)) - .flat_map(|val| val.1); - match lower_bound { - Bound::Unbounded => candidates.map(|val| val.page_number).collect(), - Bound::Included(lower_bound) => candidates - .filter(|val| val.max.cmp(lower_bound) != Ordering::Less) - .map(|val| val.page_number) - .collect(), - Bound::Excluded(lower_bound) => candidates - .filter(|val| val.max.cmp(lower_bound) == Ordering::Greater) - .map(|val| val.page_number) - .collect(), + let mut matches = Vec::new(); + + for (min, page_records) in self.tree.range((lower_bound, upper_bound)) { + for page_record in page_records { + match lower_bound { + Bound::Unbounded => {} + Bound::Included(lower) => { + if page_record.max.cmp(lower) == Ordering::Less { + continue; + } + } + Bound::Excluded(lower) => { + if page_record.max.cmp(lower) != Ordering::Greater { + continue; + } + } + } + // At this point we know the page record matches at least some values. + // We should test to see if ALL values are a match. + + match range.0 { + // range.0 < X therefore if the smallest value is not strictly greater than + // the lower bound we only have partial match + Bound::Excluded(lower) => { + if min.cmp(lower) != Ordering::Greater { + matches.push(Matches::Some(page_record.page_number)); + continue; + } + } + // range.0 <= X therefore if the smallest value is not greater than or equal + // to the lower bound we only have partial match + Bound::Included(lower) => { + if min.cmp(lower) == Ordering::Less { + matches.push(Matches::Some(page_record.page_number)); + continue; + } + } + Bound::Unbounded => {} + } + match range.1 { + // X < range.1 therefore if the largest value is not strictly less than + // the upper bound we only have partial match + Bound::Excluded(upper) => { + if page_record.max.cmp(upper) != Ordering::Less { + matches.push(Matches::Some(page_record.page_number)); + continue; + } + } + // X <= range.1 therefore if the largest value is not less than or equal to + // the upper bound we only have partial match + Bound::Included(upper) => { + if page_record.max.cmp(upper) == Ordering::Greater { + matches.push(Matches::Some(page_record.page_number)); + continue; + } + } + Bound::Unbounded => {} + } + // The min is greater than the lower bound and the max is less than the upper bound + // so we have a full match + matches.push(Matches::All(page_record.page_number)); + } } + + matches } - fn pages_null(&self) -> Vec { - self.null_pages.clone() + fn pages_null(&self) -> Vec { + // TODO: We could keep track of all-null pages and return Matches::All for those. + // This would improve performance on data with lots of nulls. + self.null_pages + .iter() + .map(|page_id| Matches::Some(*page_id)) + .collect() } } @@ -743,7 +823,7 @@ pub struct BTreePageKey { } impl CacheKey for BTreePageKey { - type ValueType = CachedScalarIndex; + type ValueType = FlatIndex; fn key(&self) -> std::borrow::Cow<'_, str> { format!("page-{}", self.page_number).into() @@ -757,7 +837,7 @@ pub struct BTreeIndex { page_lookup: Arc, index_cache: WeakLanceCache, store: Arc, - sub_index: Arc, + data_type: DataType, batch_size: u64, frag_reuse_index: Option>, } @@ -775,8 +855,8 @@ impl BTreeIndex { tree: BTreeMap>, null_pages: Vec, store: Arc, + data_type: DataType, index_cache: WeakLanceCache, - sub_index: Arc, batch_size: u64, frag_reuse_index: Option>, ) -> Self { @@ -784,8 +864,8 @@ impl BTreeIndex { Self { page_lookup, store, + data_type, index_cache, - sub_index, batch_size, frag_reuse_index, } @@ -796,14 +876,12 @@ impl BTreeIndex { page_number: u32, index_reader: LazyIndexReader, metrics: &dyn MetricsCollector, - ) -> Result> { + ) -> Result> { self.index_cache .get_or_insert_with_key(BTreePageKey { page_number }, move || async move { - let result = self.read_page(page_number, index_reader, metrics).await?; - Ok(CachedScalarIndex::new(result)) + self.read_page(page_number, index_reader, metrics).await }) .await - .map(|v| v.as_ref().clone().into_inner()) } async fn read_page( @@ -811,7 +889,7 @@ impl BTreeIndex { page_number: u32, index_reader: LazyIndexReader, metrics: &dyn MetricsCollector, - ) -> Result> { + ) -> Result { metrics.record_part_load(); info!(target: TRACE_IO_EVENTS, r#type=IO_TYPE_LOAD_SCALAR_PART, index_type="btree", part_id=page_number); let index_reader = index_reader.get().await?; @@ -822,28 +900,29 @@ impl BTreeIndex { serialized_page = frag_reuse_index_ref.remap_row_ids_record_batch(serialized_page, 1)?; } - let result = self.sub_index.load_subindex(serialized_page).await?; - Ok(result) + Ok(FlatIndex::try_new(serialized_page)?) } async fn search_page( &self, query: &SargableQuery, - page_number: u32, + matches: Matches, index_reader: LazyIndexReader, metrics: &dyn MetricsCollector, ) -> Result { - let subindex = self.lookup_page(page_number, index_reader, metrics).await?; - // TODO: If this is an IN query we can perhaps simplify the subindex query by restricting it to the - // values that might be in the page. E.g. if we are searching for X IN [5, 3, 7] and five is in pages - // 1 and 2 and three is in page 2 and seven is in pages 8 and 9, then when searching page 2 we only need - // to search for X IN [5, 3] - match subindex.search(query, metrics).await? { - SearchResult::Exact(map) => Ok(map), - _ => Err(Error::Internal { - message: "BTree sub-indices need to return exact results".to_string(), - location: location!(), - }), + let subindex = self + .lookup_page(matches.page_id(), index_reader, metrics) + .await?; + + match matches { + Matches::Some(_) => { + // TODO: If this is an IN query we can perhaps simplify the subindex query by restricting it to the + // values that might be in the page. E.g. if we are searching for X IN [5, 3, 7] and five is in pages + // 1 and 2 and three is in page 2 and seven is in pages 8 and 9, then when searching page 2 we only need + // to search for X IN [5, 3] + subindex.search(query, metrics) + } + Matches::All(_) => Ok(subindex.all()), } } @@ -859,13 +938,12 @@ impl BTreeIndex { if data.num_rows() == 0 { let data_type = data.column(0).data_type().clone(); - let sub_index = Arc::new(FlatIndexMetadata::new(data_type)); return Ok(Self::new( map, null_pages, store, + data_type, WeakLanceCache::from(index_cache), - sub_index, batch_size, frag_reuse_index, )); @@ -907,15 +985,12 @@ impl BTreeIndex { let data_type = mins.data_type(); - // TODO: Support other page types? - let sub_index = Arc::new(FlatIndexMetadata::new(data_type.clone())); - Ok(Self::new( map, null_pages, store, + data_type.clone(), WeakLanceCache::from(index_cache), - sub_index, batch_size, frag_reuse_index, )) @@ -946,13 +1021,24 @@ impl BTreeIndex { )?)) } + // For legacy reasons a btree index expects the training input to use value/_rowid + fn train_schema(&self) -> Schema { + let value_field = Field::new(VALUE_COLUMN_NAME, self.data_type.clone(), true); + let row_id_field = Field::new(ROW_ID, DataType::UInt64, false); + Schema::new(vec![value_field, row_id_field]) + } + + // For legacy reasons a btree index expects the serialized schema to be values/ids + fn flat_schema(&self) -> Schema { + let value_field = Field::new(BTREE_VALUES_COLUMN, self.data_type.clone(), true); + let row_id_field = Field::new(BTREE_IDS_COLUMN, DataType::UInt64, false); + Schema::new(vec![value_field, row_id_field]) + } + /// Create a stream of all the data in the index, in the same format used to train the index async fn into_data_stream(self) -> Result { let reader = self.store.open_index_file(BTREE_PAGES_NAME).await?; - let schema = self.sub_index.schema().clone(); - let value_field = schema.field(0).clone().with_name(VALUE_COLUMN_NAME); - let row_id_field = schema.field(1).clone().with_name(ROW_ID); - let new_schema = Arc::new(Schema::new(vec![value_field, row_id_field])); + let new_schema = Arc::new(self.train_schema()); let new_schema_clone = new_schema.clone(); let reader_stream = IndexReaderStream::new(reader, self.batch_size).await; let batches = reader_stream @@ -1085,7 +1171,7 @@ impl Index for BTreeIndex { &BTreePageKey { page_number: page_idx, }, - Arc::new(CachedScalarIndex::new(page)), + Arc::new(page), ) .await; @@ -1131,8 +1217,8 @@ impl Index for BTreeIndex { .await .buffered(self.store.io_parallelism()); while let Some(serialized) = reader_stream.try_next().await? { - let page = self.sub_index.load_subindex(serialized).await?; - frag_ids |= page.calculate_included_frags().await?; + let page = FlatIndex::try_new(serialized)?; + frag_ids |= page.calculate_included_frags()?; } Ok(frag_ids) @@ -1197,16 +1283,15 @@ impl ScalarIndex for BTreeIndex { dest_store: &dyn IndexStore, ) -> Result { // Remap and write the pages - let mut sub_index_file = dest_store - .new_index_file(BTREE_PAGES_NAME, self.sub_index.schema().clone()) - .await?; + let schema = Arc::new(self.flat_schema()); + let mut sub_index_file = dest_store.new_index_file(BTREE_PAGES_NAME, schema).await?; let sub_index_reader = self.store.open_index_file(BTREE_PAGES_NAME).await?; let mut reader_stream = IndexReaderStream::new(sub_index_reader, self.batch_size) .await .buffered(self.store.io_parallelism()); while let Some(serialized) = reader_stream.try_next().await? { - let remapped = self.sub_index.remap_subindex(serialized, mapping).await?; + let remapped = FlatIndex::remap_batch(serialized, mapping)?; sub_index_file.write_record_batch(remapped).await?; } @@ -1234,14 +1319,7 @@ impl ScalarIndex for BTreeIndex { .clone() .combine_old_new(new_data, self.batch_size) .await?; - train_btree_index( - merged_data_source, - self.sub_index.as_ref(), - dest_store, - self.batch_size, - None, - ) - .await?; + train_btree_index(merged_data_source, dest_store, self.batch_size, None).await?; Ok(CreatedIndex { index_details: prost_types::Any::from_msg(&pbold::BTreeIndexDetails::default()) @@ -1329,11 +1407,20 @@ struct EncodedBatch { async fn train_btree_page( batch: RecordBatch, batch_idx: u32, - sub_index_trainer: &dyn BTreeSubIndex, writer: &mut dyn IndexWriter, + schema: Arc, ) -> Result { let stats = analyze_batch(&batch)?; - let trained = sub_index_trainer.train(batch).await?; + + // Renames from value/_rowid to values/ids + let trained = RecordBatch::try_new( + schema.clone(), + vec![ + batch.column_by_name(VALUE_COLUMN_NAME).expect_ok()?.clone(), + batch.column_by_name(ROW_ID).expect_ok()?.clone(), + ], + )?; + writer.write_record_batch(trained).await?; Ok(EncodedBatch { stats, @@ -1380,7 +1467,6 @@ fn btree_stats_as_batch(stats: Vec, value_type: &DataType) -> Resu /// a work in progress pub async fn train_btree_index( batches_source: SendableRecordBatchStream, - sub_index_trainer: &dyn BTreeSubIndex, index_store: &dyn IndexStore, batch_size: u64, fragment_ids: Option>, @@ -1396,16 +1482,25 @@ pub async fn train_btree_index( } }); + let flat_schema = Arc::new(Schema::new(vec![ + Field::new( + BTREE_VALUES_COLUMN, + batches_source.schema().field(0).data_type().clone(), + true, + ), + Field::new(BTREE_IDS_COLUMN, DataType::UInt64, false), + ])); + let mut sub_index_file; if fragment_mask.is_none() { sub_index_file = index_store - .new_index_file(BTREE_PAGES_NAME, sub_index_trainer.schema().clone()) + .new_index_file(BTREE_PAGES_NAME, flat_schema.clone()) .await?; } else { sub_index_file = index_store .new_index_file( part_page_data_file_path(fragment_mask.unwrap()).as_str(), - sub_index_trainer.schema().clone(), + flat_schema.clone(), ) .await?; } @@ -1423,7 +1518,13 @@ pub async fn train_btree_index( while let Some(batch) = batches_source.try_next().await? { encoded_batches.push( - train_btree_page(batch, batch_idx, sub_index_trainer, sub_index_file.as_mut()).await?, + train_btree_page( + batch, + batch_idx, + sub_index_file.as_mut(), + flat_schema.clone(), + ) + .await?, ); batch_idx += 1; } @@ -1968,15 +2069,8 @@ impl ScalarIndexPlugin for BTreeIndexPlugin { .as_any() .downcast_ref::() .unwrap(); - let value_type = data - .schema() - .field_with_name(VALUE_COLUMN_NAME)? - .data_type() - .clone(); - let flat_index_trainer = FlatIndexMetadata::new(value_type); train_btree_index( data, - &flat_index_trainer, index_store, request .parameters @@ -2010,7 +2104,6 @@ mod tests { use arrow::datatypes::{Float32Type, Float64Type, Int32Type, UInt64Type}; use arrow_array::{record_batch, FixedSizeListArray}; - use arrow_schema::DataType; use datafusion::{ execution::{SendableRecordBatchStream, TaskContext}, physical_plan::{sorts::sort::SortExec, stream::RecordBatchStreamAdapter, ExecutionPlan}, @@ -2032,7 +2125,6 @@ mod tests { metrics::NoOpMetricsCollector, scalar::{ btree::{BTreeIndex, BTREE_PAGES_NAME}, - flat::FlatIndexMetadata, lance_format::LanceIndexStore, IndexStore, SargableQuery, ScalarIndex, SearchResult, }, @@ -2075,9 +2167,8 @@ mod tests { ) .col("_rowid", array::step::()) .into_df_stream(RowCount::from(5000), BatchCount::from(10)); - let sub_index_trainer = FlatIndexMetadata::new(DataType::Float32); - train_btree_index(stream, &sub_index_trainer, test_store.as_ref(), 5000, None) + train_btree_index(stream, test_store.as_ref(), 5000, None) .await .unwrap(); @@ -2158,9 +2249,7 @@ mod tests { let stream = Box::pin(RecordBatchStreamAdapter::new(schema, stream)) as SendableRecordBatchStream; - let sub_index_trainer = FlatIndexMetadata::new(DataType::Float64); - - train_btree_index(stream, &sub_index_trainer, test_store.as_ref(), 64, None) + train_btree_index(stream, test_store.as_ref(), 64, None) .await .unwrap(); @@ -2199,9 +2288,8 @@ mod tests { let stream = stream.map_err(DataFusionError::from); let stream = Box::pin(RecordBatchStreamAdapter::new(schema, stream)) as SendableRecordBatchStream; - let sub_index_trainer = FlatIndexMetadata::new(DataType::Float32); - train_btree_index(stream, &sub_index_trainer, test_store.as_ref(), 64, None) + train_btree_index(stream, test_store.as_ref(), 64, None) .await .unwrap(); @@ -2235,8 +2323,6 @@ mod tests { Arc::new(LanceCache::no_cache()), )); - let sub_index_trainer = FlatIndexMetadata::new(DataType::Int32); - // Method 1: Build complete index directly using the same data // Create deterministic data for comparison - use 2 * DEFAULT_BTREE_BATCH_SIZE for testing let total_count = 2 * DEFAULT_BTREE_BATCH_SIZE; @@ -2251,7 +2337,6 @@ mod tests { train_btree_index( full_data_source, - &sub_index_trainer, full_store.as_ref(), DEFAULT_BTREE_BATCH_SIZE, None, @@ -2273,7 +2358,6 @@ mod tests { train_btree_index( fragment1_data_source, - &sub_index_trainer, fragment_store.as_ref(), DEFAULT_BTREE_BATCH_SIZE, Some(vec![1]), // fragment_id = 1 @@ -2297,7 +2381,6 @@ mod tests { train_btree_index( fragment2_data_source, - &sub_index_trainer, fragment_store.as_ref(), DEFAULT_BTREE_BATCH_SIZE, Some(vec![2]), // fragment_id = 2 @@ -2419,8 +2502,6 @@ mod tests { Arc::new(LanceCache::no_cache()), )); - let sub_index_trainer = FlatIndexMetadata::new(DataType::Int32); - // Use 3 * DEFAULT_BTREE_BATCH_SIZE for more comprehensive boundary testing let total_count = 3 * DEFAULT_BTREE_BATCH_SIZE; @@ -2436,7 +2517,6 @@ mod tests { train_btree_index( full_data_source, - &sub_index_trainer, full_store.as_ref(), DEFAULT_BTREE_BATCH_SIZE, None, @@ -2458,7 +2538,6 @@ mod tests { train_btree_index( fragment1_data_source, - &sub_index_trainer, fragment_store.as_ref(), DEFAULT_BTREE_BATCH_SIZE, Some(vec![1]), @@ -2482,7 +2561,6 @@ mod tests { train_btree_index( fragment2_data_source, - &sub_index_trainer, fragment_store.as_ref(), DEFAULT_BTREE_BATCH_SIZE, Some(vec![2]), @@ -2506,7 +2584,6 @@ mod tests { train_btree_index( fragment3_data_source, - &sub_index_trainer, fragment_store.as_ref(), DEFAULT_BTREE_BATCH_SIZE, Some(vec![3]), @@ -2900,8 +2977,7 @@ mod tests { let stream = Box::pin(RecordBatchStreamAdapter::new(batch.schema(), stream)); // Train the btree index with FlatIndexMetadata as sub-index - let sub_index_trainer = super::FlatIndexMetadata::new(DataType::Int32); - super::train_btree_index(stream, &sub_index_trainer, store.as_ref(), 256, None) + super::train_btree_index(stream, store.as_ref(), 256, None) .await .unwrap(); diff --git a/rust/lance-index/src/scalar/flat.rs b/rust/lance-index/src/scalar/btree/flat.rs similarity index 56% rename from rust/lance-index/src/scalar/flat.rs rename to rust/lance-index/src/scalar/btree/flat.rs index ff6c0bcd11c..78b139e0002 100644 --- a/rust/lance-index/src/scalar/flat.rs +++ b/rust/lance-index/src/scalar/btree/flat.rs @@ -2,30 +2,23 @@ // SPDX-FileCopyrightText: Copyright The Lance Authors use std::collections::HashMap; -use std::{any::Any, ops::Bound, sync::Arc}; +use std::{ops::Bound, sync::Arc}; use arrow_array::{ cast::AsArray, types::UInt64Type, ArrayRef, BooleanArray, RecordBatch, UInt64Array, }; -use arrow_schema::{DataType, Field, Schema}; -use async_trait::async_trait; -use datafusion::physical_plan::SendableRecordBatchStream; use datafusion_physical_expr::expressions::{in_list, lit, Column}; use deepsize::DeepSizeOf; -use lance_core::error::LanceOptionExt; +use lance_arrow::RecordBatchExt; use lance_core::utils::address::RowAddress; use lance_core::utils::mask::{NullableRowAddrSet, RowAddrTreeMap}; -use lance_core::{Error, Result, ROW_ID}; +use lance_core::{Error, Result}; use roaring::RoaringBitmap; use snafu::location; -use super::{btree::BTreeSubIndex, IndexStore, ScalarIndex}; -use super::{AnyQuery, MetricsCollector, SargableQuery, SearchResult}; -use crate::scalar::btree::{BTREE_IDS_COLUMN, BTREE_VALUES_COLUMN}; -use crate::scalar::registry::VALUE_COLUMN_NAME; -use crate::scalar::{CreatedIndex, UpdateCriteria}; -use crate::{Index, IndexType}; +use crate::metrics::MetricsCollector; +use crate::scalar::{AnyQuery, SargableQuery}; /// A flat index is just a batch of value/row-id pairs /// @@ -36,6 +29,8 @@ use crate::{Index, IndexType}; #[derive(Debug)] pub struct FlatIndex { data: Arc, + all_addrs_map: RowAddrTreeMap, + null_addrs_map: RowAddrTreeMap, has_nulls: bool, } @@ -46,184 +41,117 @@ impl DeepSizeOf for FlatIndex { } impl FlatIndex { - fn values(&self) -> &ArrayRef { - self.data.column(0) - } - - fn ids(&self) -> &ArrayRef { - self.data.column(1) - } -} - -fn remap_batch(batch: RecordBatch, mapping: &HashMap>) -> Result { - let row_ids = batch.column(1).as_primitive::(); - let val_idx_and_new_id = row_ids - .values() - .iter() - .enumerate() - .filter_map(|(idx, old_id)| { - mapping - .get(old_id) - .copied() - .unwrap_or(Some(*old_id)) - .map(|new_id| (idx, new_id)) - }) - .collect::>(); - let new_ids = Arc::new(UInt64Array::from_iter_values( - val_idx_and_new_id.iter().copied().map(|(_, new_id)| new_id), - )); - let new_val_indices = UInt64Array::from_iter_values( - val_idx_and_new_id - .into_iter() - .map(|(val_idx, _)| val_idx as u64), - ); - let new_vals = arrow_select::take::take(batch.column(0), &new_val_indices, None)?; - Ok(RecordBatch::try_new( - batch.schema(), - vec![new_vals, new_ids], - )?) -} - -/// Trains a flat index from a record batch of values & ids by simply storing the batch -/// -/// This allows the flat index to be used as a sub-index -#[derive(Debug)] -pub struct FlatIndexMetadata { - schema: Arc, -} - -impl DeepSizeOf for FlatIndexMetadata { - fn deep_size_of_children(&self, context: &mut deepsize::Context) -> usize { - self.schema.metadata.deep_size_of_children(context) - + self - .schema - .fields + pub fn try_new(data: RecordBatch) -> Result { + // Sort by row id to make bitmap construction more efficient + let data = data.sort_by_column(1, None)?; + let has_nulls = data.column(1).null_count() > 0; + let all_addrs_map = RowAddrTreeMap::from_sorted_iter( + data.column(1) + .as_primitive::() + .values() .iter() - // This undercounts slightly because it doesn't account for the size of the - // field data types - .map(|f| { - std::mem::size_of::() - + f.name().deep_size_of_children(context) - + f.metadata().deep_size_of_children(context) - }) - .sum::() - } -} + .copied(), + )?; -impl FlatIndexMetadata { - pub fn new(value_type: DataType) -> Self { - let schema = Arc::new(Schema::new(vec![ - Field::new(BTREE_VALUES_COLUMN, value_type, true), - Field::new(BTREE_IDS_COLUMN, DataType::UInt64, true), - ])); - Self { schema } - } -} - -#[async_trait] -impl BTreeSubIndex for FlatIndexMetadata { - fn schema(&self) -> &Arc { - &self.schema - } - - async fn train(&self, batch: RecordBatch) -> Result { - // The data source may not call the columns "values" and "row_ids" so we need to replace - // the schema - Ok(RecordBatch::try_new( - self.schema.clone(), - vec![ - batch.column_by_name(VALUE_COLUMN_NAME).expect_ok()?.clone(), - batch.column_by_name(ROW_ID).expect_ok()?.clone(), - ], - )?) - } + let null_addrs_map = if has_nulls { + Self::get_null_addrs(&data)? + } else { + RowAddrTreeMap::default() + }; - async fn load_subindex(&self, serialized: RecordBatch) -> Result> { - let has_nulls = serialized.column(0).null_count() > 0; - Ok(Arc::new(FlatIndex { - data: Arc::new(serialized), + Ok(Self { + data: Arc::new(data), + all_addrs_map, + null_addrs_map, has_nulls, - })) - } - - async fn remap_subindex( - &self, - serialized: RecordBatch, - mapping: &HashMap>, - ) -> Result { - remap_batch(serialized, mapping) - } - - async fn retrieve_data(&self, serialized: RecordBatch) -> Result { - Ok(serialized) - } -} - -#[async_trait] -impl Index for FlatIndex { - fn as_any(&self) -> &dyn Any { - self - } - - fn as_index(self: Arc) -> Arc { - self - } - - fn as_vector_index(self: Arc) -> Result> { - Err(Error::NotSupported { - source: "FlatIndex is not vector index".into(), - location: location!(), }) } - fn index_type(&self) -> IndexType { - IndexType::Scalar + fn values(&self) -> &ArrayRef { + self.data.column(0) } - async fn prewarm(&self) -> Result<()> { - // There is nothing to pre-warm - Ok(()) + fn ids(&self) -> &ArrayRef { + self.data.column(1) } - fn statistics(&self) -> Result { - Ok(serde_json::json!({ - "num_values": self.data.num_rows(), - })) + pub fn all(&self) -> NullableRowAddrSet { + // Some rows will be in both sets but that is ok, null trumps true + NullableRowAddrSet::new(self.all_addrs_map.clone(), self.null_addrs_map.clone()) } - async fn calculate_included_frags(&self) -> Result { - let mut frag_ids = self - .ids() - .as_primitive::() + pub fn remap_batch( + batch: RecordBatch, + mapping: &HashMap>, + ) -> Result { + let row_ids = batch.column(1).as_primitive::(); + let val_idx_and_new_id = row_ids + .values() .iter() - .map(|row_id| RowAddress::from(row_id.unwrap()).fragment_id()) + .enumerate() + .filter_map(|(idx, old_id)| { + mapping + .get(old_id) + .copied() + .unwrap_or(Some(*old_id)) + .map(|new_id| (idx, new_id)) + }) .collect::>(); - frag_ids.sort(); - frag_ids.dedup(); - Ok(RoaringBitmap::from_sorted_iter(frag_ids).unwrap()) + let new_ids = Arc::new(UInt64Array::from_iter_values( + val_idx_and_new_id.iter().copied().map(|(_, new_id)| new_id), + )); + let new_val_indices = UInt64Array::from_iter_values( + val_idx_and_new_id + .into_iter() + .map(|(val_idx, _)| val_idx as u64), + ); + let new_vals = arrow_select::take::take(batch.column(0), &new_val_indices, None)?; + Ok(RecordBatch::try_new( + batch.schema(), + vec![new_vals, new_ids], + )?) } -} -#[async_trait] -impl ScalarIndex for FlatIndex { - async fn search( + fn get_null_addrs(sorted_batch: &RecordBatch) -> Result { + let null_mask = arrow::compute::is_null(sorted_batch.column(0))?; + let null_ids = arrow_select::filter::filter(sorted_batch.column(1), &null_mask)?; + let null_ids = null_ids + .as_any() + .downcast_ref::() + .expect("Result of arrow_select::filter::filter did not match input type"); + RowAddrTreeMap::from_sorted_iter(null_ids.values().iter().copied()) + } + + pub fn search( &self, query: &dyn AnyQuery, metrics: &dyn MetricsCollector, - ) -> Result { + ) -> Result { metrics.record_comparisons(self.data.num_rows()); let query = query.as_any().downcast_ref::().unwrap(); // Since we have all the values in memory we can use basic arrow-rs compute // functions to satisfy scalar queries. + + let mut null_is_true = false; let mut predicate = match query { SargableQuery::Equals(value) => { if value.is_null() { - arrow::compute::is_null(self.values())? + // Query is x = NULL, correct SQL behavior is to return all ids as NULL + // We differ a little and return them all as true right now. + return Ok(NullableRowAddrSet::new( + self.null_addrs_map.clone(), + Default::default(), + )); } else { arrow_ord::cmp::eq(self.values(), &value.to_scalar()?)? } } - SargableQuery::IsNull() => arrow::compute::is_null(self.values())?, + SargableQuery::IsNull() => { + return Ok(NullableRowAddrSet::new( + self.null_addrs_map.clone(), + Default::default(), + )); + } SargableQuery::IsIn(values) => { let mut has_null = false; let choices = values @@ -247,6 +175,10 @@ impl ScalarIndex for FlatIndex { .expect("InList evaluation should return boolean array") .clone(); + // If the IN query has nulls, then don't treat the nulls as null. This is a little different + // than SQL behavior. + null_is_true = has_null; + // Arrow's in_list does not handle nulls so we need to join them in here if user asked for them if has_null && self.has_nulls { let nulls = arrow::compute::is_null(self.values())?; @@ -303,23 +235,10 @@ impl ScalarIndex for FlatIndex { // Track null row IDs for Kleene logic // When querying FOR nulls (IS NULL or Equals(null)), don't track them as "null results" // because they are the TRUE result of the query - let null_row_ids = if self.has_nulls - && !matches!(query, SargableQuery::IsNull()) - && !matches!(query, SargableQuery::Equals(val) if val.is_null()) - { - let null_mask = arrow::compute::is_null(self.values())?; - let null_ids = arrow_select::filter::filter(self.ids(), &null_mask)?; - let null_ids = null_ids - .as_any() - .downcast_ref::() - .expect("Result of arrow_select::filter::filter did not match input type"); - if null_ids.is_empty() { - None - } else { - Some(RowAddrTreeMap::from_iter(null_ids.values())) - } + let null_row_ids = if null_is_true { + self.null_addrs_map.clone() } else { - None + Default::default() }; let matching_ids = arrow_select::filter::filter(self.ids(), &predicate)?; @@ -327,40 +246,20 @@ impl ScalarIndex for FlatIndex { .as_any() .downcast_ref::() .expect("Result of arrow_select::filter::filter did not match input type"); - let selected = RowAddrTreeMap::from_iter(matching_ids.values()); - let selection = NullableRowAddrSet::new(selected, null_row_ids.unwrap_or_default()); - Ok(SearchResult::Exact(selection)) - } - - fn can_remap(&self) -> bool { - true - } - - // Same as above, this is dead code at the moment but should work - async fn remap( - &self, - _mapping: &HashMap>, - _dest_store: &dyn IndexStore, - ) -> Result { - unimplemented!() - } - - async fn update( - &self, - _new_data: SendableRecordBatchStream, - _dest_store: &dyn IndexStore, - ) -> Result { - // If this was desired, then you would need to merge new_data and data and write it back out - unimplemented!() - } - - fn update_criteria(&self) -> UpdateCriteria { - unimplemented!() + let selected = RowAddrTreeMap::from_sorted_iter(matching_ids.values().iter().copied())?; + Ok(NullableRowAddrSet::new(selected, null_row_ids)) } - fn derive_index_params(&self) -> Result { - // FlatIndex is used internally and doesn't have user-configurable parameters - unimplemented!("FlatIndex is an internal index type and cannot be recreated") + pub fn calculate_included_frags(&self) -> Result { + let mut frag_ids = self + .ids() + .as_primitive::() + .iter() + .map(|row_id| RowAddress::from(row_id.unwrap()).fragment_id()) + .collect::>(); + frag_ids.sort(); + frag_ids.dedup(); + Ok(RoaringBitmap::from_sorted_iter(frag_ids).unwrap()) } } @@ -383,21 +282,15 @@ mod tests { .into_batch_rows(RowCount::from(4)) .unwrap(); - FlatIndex { - data: Arc::new(batch), - has_nulls: false, - } + FlatIndex::try_new(batch).unwrap() } async fn check_index(query: &SargableQuery, expected: &[u64]) { let index = example_index(); - let actual = index.search(query, &NoOpMetricsCollector).await.unwrap(); - let SearchResult::Exact(actual_row_ids) = actual else { - panic! {"Expected exact search result"} - }; + let actual = index.search(query, &NoOpMetricsCollector).unwrap(); let expected = NullableRowAddrSet::new(RowAddrTreeMap::from_iter(expected), Default::default()); - assert_eq!(actual_row_ids, expected); + assert_eq!(actual, expected); } #[tokio::test] @@ -454,18 +347,19 @@ mod tests { // 3 -> delete // Keep remaining as is let mapping = HashMap::>::from_iter(vec![(0, Some(2000)), (3, None)]); - let metadata = FlatIndexMetadata::new(DataType::Int32); - let remapped = metadata - .remap_subindex((*index.data).clone(), &mapping) - .await - .unwrap(); - - let expected = gen_batch() - .col("values", array::cycle::(vec![10, 100, 1234])) - .col("ids", array::cycle::(vec![5, 2000, 100])) - .into_batch_rows(RowCount::from(3)) - .unwrap(); - assert_eq!(remapped, expected); + let remapped = + FlatIndex::try_new(FlatIndex::remap_batch((*index.data).clone(), &mapping).unwrap()) + .unwrap(); + + let expected = FlatIndex::try_new( + gen_batch() + .col("values", array::cycle::(vec![10, 100, 1234])) + .col("ids", array::cycle::(vec![5, 2000, 100])) + .into_batch_rows(RowCount::from(3)) + .unwrap(), + ) + .unwrap(); + assert_eq!(remapped.data, expected.data); } // It's possible, during compaction, that an entire page of values is deleted. We just serialize @@ -479,11 +373,7 @@ mod tests { (3, None), (100, None), ]); - let metadata = FlatIndexMetadata::new(DataType::Int32); - let remapped = metadata - .remap_subindex((*index.data).clone(), &mapping) - .await - .unwrap(); + let remapped = FlatIndex::remap_batch((*index.data).clone(), &mapping).unwrap(); assert_eq!(remapped.num_rows(), 0); } } diff --git a/rust/lance-index/src/scalar/lance_format.rs b/rust/lance-index/src/scalar/lance_format.rs index 463953ee801..4c8f3cd267e 100644 --- a/rust/lance-index/src/scalar/lance_format.rs +++ b/rust/lance-index/src/scalar/lance_format.rs @@ -312,7 +312,6 @@ pub mod tests { use crate::scalar::{ bitmap::BitmapIndex, btree::{train_btree_index, DEFAULT_BTREE_BATCH_SIZE}, - flat::FlatIndexMetadata, LabelListQuery, SargableQuery, ScalarIndex, SearchResult, }; @@ -855,17 +854,10 @@ pub mod tests { ])); let data = RecordBatchIterator::new(batches, schema); let data = lance_datafusion::utils::reader_to_stream(Box::new(data)); - let sub_index_trainer = FlatIndexMetadata::new(DataType::Utf8); - train_btree_index( - data, - &sub_index_trainer, - index_store.as_ref(), - DEFAULT_BTREE_BATCH_SIZE, - None, - ) - .await - .unwrap(); + train_btree_index(data, index_store.as_ref(), DEFAULT_BTREE_BATCH_SIZE, None) + .await + .unwrap(); let index = BTreeIndexPlugin .load_index( From e16494465cd1f9f00e9f965ccf0a8e16f479acd5 Mon Sep 17 00:00:00 2001 From: Weston Pace Date: Tue, 9 Dec 2025 10:26:19 -0800 Subject: [PATCH 2/9] Fix bug revealed by python tests with null handling --- python/Cargo.lock | 1 + rust/lance-index/src/scalar/btree.rs | 6 ++++++ 2 files changed, 7 insertions(+) diff --git a/python/Cargo.lock b/python/Cargo.lock index 84b271c5d2b..f5d78aa41aa 100644 --- a/python/Cargo.lock +++ b/python/Cargo.lock @@ -3929,6 +3929,7 @@ dependencies = [ "arrow-buffer", "arrow-cast", "arrow-data", + "arrow-ord", "arrow-schema", "arrow-select", "bytes", diff --git a/rust/lance-index/src/scalar/btree.rs b/rust/lance-index/src/scalar/btree.rs index e22fd66c733..32d6e10967c 100644 --- a/rust/lance-index/src/scalar/btree.rs +++ b/rust/lance-index/src/scalar/btree.rs @@ -707,6 +707,12 @@ impl BTreeLookup { // At this point we know the page record matches at least some values. // We should test to see if ALL values are a match. + if min.0.is_null() || page_record.max.0.is_null() { + // If there are nulls then we just use Matches::Some + matches.push(Matches::Some(page_record.page_number)); + continue; + } + match range.0 { // range.0 < X therefore if the smallest value is not strictly greater than // the lower bound we only have partial match From 6bf1832bacc492bfb112635d2c737c15c4ae4e94 Mon Sep 17 00:00:00 2001 From: Weston Pace Date: Tue, 9 Dec 2025 10:41:57 -0800 Subject: [PATCH 3/9] Fix benchmark to new code --- rust/lance/benches/scalar_index.rs | 4 ---- 1 file changed, 4 deletions(-) diff --git a/rust/lance/benches/scalar_index.rs b/rust/lance/benches/scalar_index.rs index f24816710bd..4b2098c5f39 100644 --- a/rust/lance/benches/scalar_index.rs +++ b/rust/lance/benches/scalar_index.rs @@ -17,7 +17,6 @@ use lance_datafusion::utils::reader_to_stream; use lance_datagen::{array, gen_batch, BatchCount, RowCount}; use lance_index::scalar::{ btree::{train_btree_index, DEFAULT_BTREE_BATCH_SIZE}, - flat::FlatIndexMetadata, lance_format::LanceIndexStore, registry::ScalarIndexPlugin, IndexStore, SargableQuery, ScalarIndex, SearchResult, @@ -63,11 +62,8 @@ impl BenchmarkFixture { } async fn train_scalar_index(index_store: &Arc) { - let sub_index_trainer = FlatIndexMetadata::new(arrow_schema::DataType::UInt32); - train_btree_index( test_data_stream(), - &sub_index_trainer, index_store.as_ref(), DEFAULT_BTREE_BATCH_SIZE, None, From 7621d0bc6bc13a06ac4704207739c5919bb39c1d Mon Sep 17 00:00:00 2001 From: Weston Pace Date: Tue, 9 Dec 2025 15:07:13 -0800 Subject: [PATCH 4/9] Fix bug in RowAddrTreeMap::from_sorted_iter. Add test for this fn too. --- Cargo.lock | 1 + rust/lance-core/Cargo.toml | 1 + rust/lance-core/src/utils/mask.rs | 15 +++++++++++++-- rust/lance-index/src/scalar/btree.rs | 1 + 4 files changed, 16 insertions(+), 2 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 3b496556a4f..50186efc17e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4530,6 +4530,7 @@ dependencies = [ "datafusion-sql", "deepsize", "futures", + "itertools 0.13.0", "lance-arrow", "lance-testing", "libc", diff --git a/rust/lance-core/Cargo.toml b/rust/lance-core/Cargo.toml index dd3bfbc5b39..aa1c3f6aebe 100644 --- a/rust/lance-core/Cargo.toml +++ b/rust/lance-core/Cargo.toml @@ -24,6 +24,7 @@ datafusion-common = { workspace = true, optional = true } datafusion-sql = { workspace = true, optional = true } deepsize.workspace = true futures.workspace = true +itertools.workspace = true libc.workspace = true mock_instant.workspace = true moka.workspace = true diff --git a/rust/lance-core/src/utils/mask.rs b/rust/lance-core/src/utils/mask.rs index 901701fdbeb..3a0a73cfe01 100644 --- a/rust/lance-core/src/utils/mask.rs +++ b/rust/lance-core/src/utils/mask.rs @@ -10,6 +10,7 @@ use arrow_array::{Array, BinaryArray, GenericBinaryArray}; use arrow_buffer::{Buffer, NullBuffer, OffsetBuffer}; use byteorder::{ReadBytesExt, WriteBytesExt}; use deepsize::DeepSizeOf; +use itertools::Itertools; use roaring::{MultiOps, RoaringBitmap, RoaringTreemap}; use crate::error::ToSnafuLocation; @@ -605,8 +606,7 @@ impl RowAddrTreeMap { while let Some(row_id) = iter.peek() { let fragment_id = (row_id >> 32) as u32; let next_bitmap_iter = iter - .by_ref() - .take_while(|row_id| (row_id >> 32) as u32 == fragment_id) + .peeking_take_while(|row_id| (row_id >> 32) as u32 == fragment_id) .map(|row_id| row_id as u32); let Ok(bitmap) = RoaringBitmap::from_sorted_iter(next_bitmap_iter) else { return Err(Error::Internal { @@ -1530,6 +1530,17 @@ mod tests { prop_assert_eq!(expected, left); } + #[test] + fn test_from_sorted_iter( + mut rows in proptest::collection::vec(0..u64::MAX, 0..1000) + ) { + rows.sort(); + let num_rows = rows.len(); + let mask = RowAddrTreeMap::from_sorted_iter(rows).unwrap(); + prop_assert_eq!(mask.len(), Some(num_rows as u64)); + } + + } #[test] diff --git a/rust/lance-index/src/scalar/btree.rs b/rust/lance-index/src/scalar/btree.rs index 32d6e10967c..6e1e1a03985 100644 --- a/rust/lance-index/src/scalar/btree.rs +++ b/rust/lance-index/src/scalar/btree.rs @@ -576,6 +576,7 @@ pub struct BTreeLookup { null_pages: Vec, } +#[derive(Debug, Copy, Clone)] enum Matches { Some(u32), All(u32), From 14561d5d062230dfe6c4d97abbde97df7930f599 Mon Sep 17 00:00:00 2001 From: Weston Pace Date: Tue, 9 Dec 2025 15:09:51 -0800 Subject: [PATCH 5/9] Address clippy suggestions --- rust/lance-index/src/scalar/btree.rs | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/rust/lance-index/src/scalar/btree.rs b/rust/lance-index/src/scalar/btree.rs index 6e1e1a03985..a7a8e29dc0b 100644 --- a/rust/lance-index/src/scalar/btree.rs +++ b/rust/lance-index/src/scalar/btree.rs @@ -585,8 +585,8 @@ enum Matches { impl Matches { fn page_id(&self) -> u32 { match self { - Matches::Some(page_id) => *page_id, - Matches::All(page_id) => *page_id, + Self::Some(page_id) => *page_id, + Self::All(page_id) => *page_id, } } } @@ -624,10 +624,7 @@ impl BTreeLookup { } let mut all_pages = heap.into_sorted_vec(); all_pages.dedup(); - all_pages - .into_iter() - .map(|page_id| Matches::Some(page_id)) - .collect() + all_pages.into_iter().map(Matches::Some).collect() } // All pages that could have a value in the range @@ -907,7 +904,7 @@ impl BTreeIndex { serialized_page = frag_reuse_index_ref.remap_row_ids_record_batch(serialized_page, 1)?; } - Ok(FlatIndex::try_new(serialized_page)?) + FlatIndex::try_new(serialized_page) } async fn search_page( From ee17acbc1f4b06758c1c167c031117a2b201a4ed Mon Sep 17 00:00:00 2001 From: Weston Pace Date: Tue, 9 Dec 2025 15:22:52 -0800 Subject: [PATCH 6/9] Update python lockfile --- python/Cargo.lock | 1 + 1 file changed, 1 insertion(+) diff --git a/python/Cargo.lock b/python/Cargo.lock index f5d78aa41aa..8a0fa929db8 100644 --- a/python/Cargo.lock +++ b/python/Cargo.lock @@ -3964,6 +3964,7 @@ dependencies = [ "datafusion-sql", "deepsize", "futures", + "itertools 0.13.0", "lance-arrow", "libc", "log", From 6b219b4b0e8951cf3908f9d06ec5e50771c25b71 Mon Sep 17 00:00:00 2001 From: Weston Pace Date: Tue, 9 Dec 2025 15:40:39 -0800 Subject: [PATCH 7/9] Add null_pages / all_null_pages distinction --- rust/lance-index/src/scalar/btree.rs | 33 ++++++++++++++++++++++------ 1 file changed, 26 insertions(+), 7 deletions(-) diff --git a/rust/lance-index/src/scalar/btree.rs b/rust/lance-index/src/scalar/btree.rs index a7a8e29dc0b..fbaa3b8c739 100644 --- a/rust/lance-index/src/scalar/btree.rs +++ b/rust/lance-index/src/scalar/btree.rs @@ -572,8 +572,10 @@ impl BTreeMapExt for BTreeMap { #[derive(Debug, DeepSizeOf, PartialEq, Eq)] pub struct BTreeLookup { tree: BTreeMap>, - /// Pages where the value may be null + /// Pages where the value may be null (does not include all_null_pages) null_pages: Vec, + /// Pages that are entirely null + all_null_pages: Vec, } #[derive(Debug, Copy, Clone)] @@ -592,8 +594,16 @@ impl Matches { } impl BTreeLookup { - fn new(tree: BTreeMap>, null_pages: Vec) -> Self { - Self { tree, null_pages } + fn new( + tree: BTreeMap>, + null_pages: Vec, + all_null_pages: Vec, + ) -> Self { + Self { + tree, + null_pages, + all_null_pages, + } } // All pages that could have a value equal to val @@ -759,11 +769,10 @@ impl BTreeLookup { } fn pages_null(&self) -> Vec { - // TODO: We could keep track of all-null pages and return Matches::All for those. - // This would improve performance on data with lots of nulls. self.null_pages .iter() .map(|page_id| Matches::Some(*page_id)) + .chain(self.all_null_pages.iter().copied().map(Matches::All)) .collect() } } @@ -858,13 +867,14 @@ impl BTreeIndex { fn new( tree: BTreeMap>, null_pages: Vec, + all_null_pages: Vec, store: Arc, data_type: DataType, index_cache: WeakLanceCache, batch_size: u64, frag_reuse_index: Option>, ) -> Self { - let page_lookup = Arc::new(BTreeLookup::new(tree, null_pages)); + let page_lookup = Arc::new(BTreeLookup::new(tree, null_pages, all_null_pages)); Self { page_lookup, store, @@ -938,13 +948,17 @@ impl BTreeIndex { frag_reuse_index: Option>, ) -> Result { let mut map = BTreeMap::>::new(); + // Pages that have at least one null value let mut null_pages = Vec::::new(); + // Pages that are entirely null + let mut all_null_pages = Vec::::new(); if data.num_rows() == 0 { let data_type = data.column(0).data_type().clone(); return Ok(Self::new( map, null_pages, + all_null_pages, store, data_type, WeakLanceCache::from(index_cache), @@ -973,7 +987,11 @@ impl BTreeIndex { let page_number = page_numbers.values()[idx]; // If the page is entirely null don't even bother putting it in the tree - if !max.0.is_null() { + if max.0.is_null() { + all_null_pages.push(page_number); + // continue so we don't add it to the null_pages + continue; + } else { map.entry(min) .or_default() .push(PageRecord { max, page_number }); @@ -992,6 +1010,7 @@ impl BTreeIndex { Ok(Self::new( map, null_pages, + all_null_pages, store, data_type.clone(), WeakLanceCache::from(index_cache), From 0fc6f6f4702463bc135d87f1df0e2dd3ef8d7e02 Mon Sep 17 00:00:00 2001 From: Weston Pace Date: Wed, 10 Dec 2025 05:33:16 -0800 Subject: [PATCH 8/9] Address clippy suggestions --- rust/lance-index/src/scalar/btree.rs | 26 ++++++++++++++++---------- 1 file changed, 16 insertions(+), 10 deletions(-) diff --git a/rust/lance-index/src/scalar/btree.rs b/rust/lance-index/src/scalar/btree.rs index fbaa3b8c739..49517e6ae69 100644 --- a/rust/lance-index/src/scalar/btree.rs +++ b/rust/lance-index/src/scalar/btree.rs @@ -578,6 +578,16 @@ pub struct BTreeLookup { all_null_pages: Vec, } +impl BTreeLookup { + fn empty() -> Self { + Self { + tree: BTreeMap::new(), + null_pages: Vec::new(), + all_null_pages: Vec::new(), + } + } +} + #[derive(Debug, Copy, Clone)] enum Matches { Some(u32), @@ -865,16 +875,13 @@ impl DeepSizeOf for BTreeIndex { impl BTreeIndex { fn new( - tree: BTreeMap>, - null_pages: Vec, - all_null_pages: Vec, + page_lookup: Arc, store: Arc, data_type: DataType, index_cache: WeakLanceCache, batch_size: u64, frag_reuse_index: Option>, ) -> Self { - let page_lookup = Arc::new(BTreeLookup::new(tree, null_pages, all_null_pages)); Self { page_lookup, store, @@ -955,10 +962,9 @@ impl BTreeIndex { if data.num_rows() == 0 { let data_type = data.column(0).data_type().clone(); + let page_lookup = Arc::new(BTreeLookup::empty()); return Ok(Self::new( - map, - null_pages, - all_null_pages, + page_lookup, store, data_type, WeakLanceCache::from(index_cache), @@ -1007,10 +1013,10 @@ impl BTreeIndex { let data_type = mins.data_type(); + let page_lookup = Arc::new(BTreeLookup::new(map, null_pages, all_null_pages)); + Ok(Self::new( - map, - null_pages, - all_null_pages, + page_lookup, store, data_type.clone(), WeakLanceCache::from(index_cache), From 8b4c11e6a28a9a573f03131d34a145ed183a2858 Mon Sep 17 00:00:00 2001 From: Weston Pace Date: Wed, 10 Dec 2025 06:54:58 -0800 Subject: [PATCH 9/9] Use DF for expr eval --- rust/lance-index/src/scalar/btree.rs | 6 +- rust/lance-index/src/scalar/btree/flat.rs | 260 +++++++++++++--------- 2 files changed, 157 insertions(+), 109 deletions(-) diff --git a/rust/lance-index/src/scalar/btree.rs b/rust/lance-index/src/scalar/btree.rs index 49517e6ae69..2a34a864e78 100644 --- a/rust/lance-index/src/scalar/btree.rs +++ b/rust/lance-index/src/scalar/btree.rs @@ -943,7 +943,11 @@ impl BTreeIndex { // to search for X IN [5, 3] subindex.search(query, metrics) } - Matches::All(_) => Ok(subindex.all()), + Matches::All(_) => Ok(match query { + // This means we hit an all-null page so just grab all row ids as true + SargableQuery::IsNull() => subindex.all_ignore_nulls(), + _ => subindex.all(), + }), } } diff --git a/rust/lance-index/src/scalar/btree/flat.rs b/rust/lance-index/src/scalar/btree/flat.rs index 78b139e0002..1da09425528 100644 --- a/rust/lance-index/src/scalar/btree/flat.rs +++ b/rust/lance-index/src/scalar/btree/flat.rs @@ -4,22 +4,27 @@ use std::collections::HashMap; use std::{ops::Bound, sync::Arc}; +use arrow_array::Array; use arrow_array::{ cast::AsArray, types::UInt64Type, ArrayRef, BooleanArray, RecordBatch, UInt64Array, }; -use datafusion_physical_expr::expressions::{in_list, lit, Column}; +use datafusion_common::DFSchema; +use datafusion_expr::execution_props::ExecutionProps; +use datafusion_physical_expr::create_physical_expr; use deepsize::DeepSizeOf; use lance_arrow::RecordBatchExt; use lance_core::utils::address::RowAddress; use lance_core::utils::mask::{NullableRowAddrSet, RowAddrTreeMap}; -use lance_core::{Error, Result}; +use lance_core::Result; use roaring::RoaringBitmap; -use snafu::location; use crate::metrics::MetricsCollector; +use crate::scalar::btree::BTREE_VALUES_COLUMN; use crate::scalar::{AnyQuery, SargableQuery}; +const VALUES_COL_IDX: usize = 0; +const IDS_COL_IDX: usize = 1; /// A flat index is just a batch of value/row-id pairs /// /// The batch always has two columns. The first column "values" contains @@ -31,7 +36,7 @@ pub struct FlatIndex { data: Arc, all_addrs_map: RowAddrTreeMap, null_addrs_map: RowAddrTreeMap, - has_nulls: bool, + df_schema: DFSchema, } impl DeepSizeOf for FlatIndex { @@ -43,10 +48,11 @@ impl DeepSizeOf for FlatIndex { impl FlatIndex { pub fn try_new(data: RecordBatch) -> Result { // Sort by row id to make bitmap construction more efficient - let data = data.sort_by_column(1, None)?; - let has_nulls = data.column(1).null_count() > 0; + let data = data.sort_by_column(IDS_COL_IDX, None)?; + + let has_nulls = data.column(VALUES_COL_IDX).null_count() > 0; let all_addrs_map = RowAddrTreeMap::from_sorted_iter( - data.column(1) + data.column(IDS_COL_IDX) .as_primitive::() .values() .iter() @@ -59,20 +65,18 @@ impl FlatIndex { RowAddrTreeMap::default() }; + let df_schema = DFSchema::try_from(data.schema())?; + Ok(Self { data: Arc::new(data), all_addrs_map, null_addrs_map, - has_nulls, + df_schema, }) } - fn values(&self) -> &ArrayRef { - self.data.column(0) - } - fn ids(&self) -> &ArrayRef { - self.data.column(1) + self.data.column(IDS_COL_IDX) } pub fn all(&self) -> NullableRowAddrSet { @@ -80,11 +84,15 @@ impl FlatIndex { NullableRowAddrSet::new(self.all_addrs_map.clone(), self.null_addrs_map.clone()) } + pub fn all_ignore_nulls(&self) -> NullableRowAddrSet { + NullableRowAddrSet::new(self.all_addrs_map.clone(), Default::default()) + } + pub fn remap_batch( batch: RecordBatch, mapping: &HashMap>, ) -> Result { - let row_ids = batch.column(1).as_primitive::(); + let row_ids = batch.column(IDS_COL_IDX).as_primitive::(); let val_idx_and_new_id = row_ids .values() .iter() @@ -105,7 +113,8 @@ impl FlatIndex { .into_iter() .map(|(val_idx, _)| val_idx as u64), ); - let new_vals = arrow_select::take::take(batch.column(0), &new_val_indices, None)?; + let new_vals = + arrow_select::take::take(batch.column(VALUES_COL_IDX), &new_val_indices, None)?; Ok(RecordBatch::try_new( batch.schema(), vec![new_vals, new_ids], @@ -113,8 +122,8 @@ impl FlatIndex { } fn get_null_addrs(sorted_batch: &RecordBatch) -> Result { - let null_mask = arrow::compute::is_null(sorted_batch.column(0))?; - let null_ids = arrow_select::filter::filter(sorted_batch.column(1), &null_mask)?; + let null_mask = arrow::compute::is_null(sorted_batch.column(VALUES_COL_IDX))?; + let null_ids = arrow_select::filter::filter(sorted_batch.column(IDS_COL_IDX), &null_mask)?; let null_ids = null_ids .as_any() .downcast_ref::() @@ -132,121 +141,80 @@ impl FlatIndex { // Since we have all the values in memory we can use basic arrow-rs compute // functions to satisfy scalar queries. - let mut null_is_true = false; - let mut predicate = match query { + // Shortcuts for simple cases where we can re-use computed values + match query { + // x = NULL means all rows are NULL SargableQuery::Equals(value) => { if value.is_null() { - // Query is x = NULL, correct SQL behavior is to return all ids as NULL - // We differ a little and return them all as true right now. + // if we have x = NULL then the correct SQL behavior is to return all NULLs return Ok(NullableRowAddrSet::new( - self.null_addrs_map.clone(), Default::default(), + self.all_addrs_map.clone(), )); - } else { - arrow_ord::cmp::eq(self.values(), &value.to_scalar()?)? } } + // x IS NULL we can use pre-computed nulls SargableQuery::IsNull() => { return Ok(NullableRowAddrSet::new( self.null_addrs_map.clone(), Default::default(), )); } - SargableQuery::IsIn(values) => { - let mut has_null = false; - let choices = values - .iter() - .map(|val| { - has_null |= val.is_null(); - lit(val.clone()) - }) - .collect::>(); - let in_list_expr = in_list( - Arc::new(Column::new("values", 0)), - choices, - &false, - &self.data.schema(), - )?; - let result_col = in_list_expr.evaluate(&self.data)?; - let predicate = result_col - .into_array(self.data.num_rows())? - .as_any() - .downcast_ref::() - .expect("InList evaluation should return boolean array") - .clone(); - - // If the IN query has nulls, then don't treat the nulls as null. This is a little different - // than SQL behavior. - null_is_true = has_null; - - // Arrow's in_list does not handle nulls so we need to join them in here if user asked for them - if has_null && self.has_nulls { - let nulls = arrow::compute::is_null(self.values())?; - arrow::compute::or(&predicate, &nulls)? - } else { - predicate - } - } + // x < NULL or x > NULL means all rows are NULL SargableQuery::Range(lower_bound, upper_bound) => match (lower_bound, upper_bound) { (Bound::Unbounded, Bound::Unbounded) => { - panic!("Scalar range query received with no upper or lower bound") - } - (Bound::Unbounded, Bound::Included(upper)) => { - arrow_ord::cmp::lt_eq(self.values(), &upper.to_scalar()?)? - } - (Bound::Unbounded, Bound::Excluded(upper)) => { - arrow_ord::cmp::lt(self.values(), &upper.to_scalar()?)? + return Ok(NullableRowAddrSet::new( + self.all_addrs_map.clone(), + Default::default(), + )); } - (Bound::Included(lower), Bound::Unbounded) => { - arrow_ord::cmp::gt_eq(self.values(), &lower.to_scalar()?)? + (Bound::Unbounded, Bound::Included(upper) | Bound::Excluded(upper)) => { + if upper.is_null() { + return Ok(NullableRowAddrSet::new( + Default::default(), + self.all_addrs_map.clone(), + )); + } } - (Bound::Included(lower), Bound::Included(upper)) => arrow::compute::and( - &arrow_ord::cmp::gt_eq(self.values(), &lower.to_scalar()?)?, - &arrow_ord::cmp::lt_eq(self.values(), &upper.to_scalar()?)?, - )?, - (Bound::Included(lower), Bound::Excluded(upper)) => arrow::compute::and( - &arrow_ord::cmp::gt_eq(self.values(), &lower.to_scalar()?)?, - &arrow_ord::cmp::lt(self.values(), &upper.to_scalar()?)?, - )?, - (Bound::Excluded(lower), Bound::Unbounded) => { - arrow_ord::cmp::gt(self.values(), &lower.to_scalar()?)? + (Bound::Included(lower) | Bound::Excluded(lower), Bound::Unbounded) => { + if lower.is_null() { + return Ok(NullableRowAddrSet::new( + Default::default(), + self.all_addrs_map.clone(), + )); + } } - (Bound::Excluded(lower), Bound::Included(upper)) => arrow::compute::and( - &arrow_ord::cmp::gt(self.values(), &lower.to_scalar()?)?, - &arrow_ord::cmp::lt_eq(self.values(), &upper.to_scalar()?)?, - )?, - (Bound::Excluded(lower), Bound::Excluded(upper)) => arrow::compute::and( - &arrow_ord::cmp::gt(self.values(), &lower.to_scalar()?)?, - &arrow_ord::cmp::lt(self.values(), &upper.to_scalar()?)?, - )?, + _ => {} }, - SargableQuery::FullTextSearch(_) => return Err(Error::invalid_input( - "full text search is not supported for flat index, build a inverted index for it", - location!(), - )), - }; - if self.has_nulls && matches!(query, SargableQuery::Range(_, _)) { - // Arrow's comparison kernels do not return false for nulls. They consider nulls to - // be less than any value. So we need to filter out the nulls manually. - let valid_values = arrow::compute::is_not_null(self.values())?; - predicate = arrow::compute::and(&valid_values, &predicate)?; - } - - // Track null row IDs for Kleene logic - // When querying FOR nulls (IS NULL or Equals(null)), don't track them as "null results" - // because they are the TRUE result of the query - let null_row_ids = if null_is_true { - self.null_addrs_map.clone() - } else { - Default::default() + _ => {} }; - let matching_ids = arrow_select::filter::filter(self.ids(), &predicate)?; + // No shortcut possible, need to actually evaluate the query + let expr = query.to_expr(BTREE_VALUES_COLUMN.to_string()); + let expr = create_physical_expr(&expr, &self.df_schema, &ExecutionProps::default())?; + + let predicate = expr.evaluate(&self.data)?; + let predicate = predicate.into_array(self.data.num_rows())?; + let predicate = predicate + .as_any() + .downcast_ref::() + .expect("Predicate should return boolean array"); + let nulls = arrow::compute::is_null(&predicate)?; + + let matching_ids = arrow_select::filter::filter(self.ids(), predicate)?; let matching_ids = matching_ids .as_any() .downcast_ref::() .expect("Result of arrow_select::filter::filter did not match input type"); let selected = RowAddrTreeMap::from_sorted_iter(matching_ids.values().iter().copied())?; + + let null_row_ids = arrow_select::filter::filter(self.ids(), &nulls)?; + let null_row_ids = null_row_ids + .as_any() + .downcast_ref::() + .expect("Result of arrow_select::filter::filter did not match input type"); + let null_row_ids = RowAddrTreeMap::from_sorted_iter(null_row_ids.values().iter().copied())?; + Ok(NullableRowAddrSet::new(selected, null_row_ids)) } @@ -265,10 +233,13 @@ impl FlatIndex { #[cfg(test)] mod tests { - use crate::metrics::NoOpMetricsCollector; + use crate::{ + metrics::NoOpMetricsCollector, + scalar::btree::{BTREE_IDS_COLUMN, BTREE_VALUES_COLUMN}, + }; use super::*; - use arrow_array::types::Int32Type; + use arrow_array::{record_batch, types::Int32Type}; use datafusion_common::ScalarValue; use lance_datagen::{array, gen_batch, RowCount}; @@ -376,4 +347,77 @@ mod tests { let remapped = FlatIndex::remap_batch((*index.data).clone(), &mapping).unwrap(); assert_eq!(remapped.num_rows(), 0); } + + #[test] + fn test_null_handling() { + // [null, 0, 5] + let batch = record_batch!( + (BTREE_VALUES_COLUMN, Int32, [None, Some(0), Some(5)]), + (BTREE_IDS_COLUMN, UInt64, [0, 1, 2]) + ) + .unwrap(); + let index = FlatIndex::try_new(batch).unwrap(); + + let check = |query: SargableQuery, true_ids: &[u64], null_ids: &[u64]| { + let actual = index.search(&query, &NoOpMetricsCollector).unwrap(); + let expected = NullableRowAddrSet::new( + RowAddrTreeMap::from_iter(true_ids), + RowAddrTreeMap::from_iter(null_ids), + ); + assert_eq!(actual, expected, "query: {:?}", query); + }; + + let null = ScalarValue::Int32(None); + let zero = ScalarValue::Int32(Some(0)); + let three = ScalarValue::Int32(Some(3)); + + check(SargableQuery::Equals(zero.clone()), &[1], &[0]); + // x = NULL returns all rows as NULL and nothing as TRUE + check(SargableQuery::Equals(null.clone()), &[], &[0, 1, 2]); + + check(SargableQuery::IsIn(vec![zero.clone()]), &[1], &[0]); + // x IN (0, NULL) promotes all FALSE to NULL + check(SargableQuery::IsIn(vec![zero, null.clone()]), &[1], &[0, 2]); + + check(SargableQuery::IsNull(), &[0], &[]); + + check( + SargableQuery::Range(Bound::Included(three.clone()), Bound::Unbounded), + &[2], + &[0], + ); + + // x < NULL or x > NULL returns everything as NULL + check( + SargableQuery::Range(Bound::Unbounded, Bound::Included(null.clone())), + &[], + &[0, 1, 2], + ); + + check( + SargableQuery::Range(Bound::Excluded(null.clone()), Bound::Unbounded), + &[], + &[0, 1, 2], + ); + + // x BETWEEN 3 AND NULL returns everything as NULL unless we know it is FALSE + check( + SargableQuery::Range( + Bound::Included(three.clone()), + Bound::Included(null.clone()), + ), + &[], + &[0, 2], + ); + check( + SargableQuery::Range(Bound::Included(null.clone()), Bound::Included(three)), + &[], + &[0, 1], + ); + check( + SargableQuery::Range(Bound::Included(null.clone()), Bound::Included(null)), + &[], + &[0, 1, 2], + ); + } }