diff --git a/rust/lance-index/src/scalar.rs b/rust/lance-index/src/scalar.rs index bc9e7efaf9a..484cbb5cb2a 100644 --- a/rust/lance-index/src/scalar.rs +++ b/rust/lance-index/src/scalar.rs @@ -39,6 +39,7 @@ pub mod label_list; pub mod lance_format; pub mod ngram; pub mod registry; +pub mod zoned; pub mod zonemap; use crate::frag_reuse::FragReuseIndex; diff --git a/rust/lance-index/src/scalar/bloomfilter.rs b/rust/lance-index/src/scalar/bloomfilter.rs index f29d18a6095..e1ca463143e 100644 --- a/rust/lance-index/src/scalar/bloomfilter.rs +++ b/rust/lance-index/src/scalar/bloomfilter.rs @@ -17,14 +17,9 @@ use crate::scalar::{ }; use crate::{pb, Any}; use arrow_array::{Array, UInt64Array}; -use lance_core::utils::address::RowAddress; -use lance_core::utils::mask::RowAddrTreeMap; -use lance_core::ROW_ADDR; -use lance_datafusion::chunker::chunk_concat_stream; mod as_bytes; mod sbbf; use arrow_schema::{DataType, Field}; -use futures::TryStreamExt; use serde::{Deserialize, Serialize}; use std::sync::LazyLock; @@ -45,34 +40,18 @@ use lance_core::Result; use roaring::RoaringBitmap; use snafu::location; +use super::zoned::{rebuild_zones, search_zones, ZoneBound, ZoneProcessor, ZoneTrainer}; + const BLOOMFILTER_FILENAME: &str = "bloomfilter.lance"; const BLOOMFILTER_ITEM_META_KEY: &str = "bloomfilter_item"; const BLOOMFILTER_PROBABILITY_META_KEY: &str = "bloomfilter_probability"; const BLOOMFILTER_INDEX_VERSION: u32 = 0; -// -// Example: Suppose we have two fragments, each with 4 rows. -// Fragment 0: zone_start = 0, zone_length = 4 // covers rows 0, 1, 2, 3 in fragment 0 -// The row addresses for fragment 0 are: 0, 1, 2, 3 -// Fragment 1: zone_start = 0, zone_length = 4 // covers rows 0, 1, 2, 3 in fragment 1 -// The row addresses for fragment 1 are: 32>>1, 32>>1 + 1, 32>>1 + 2, 32>>1 + 3 -// -// Deletion is 0 index based. 
We delete the 0th and 1st row in fragment 0, -// and the 1st and 2nd row in fragment 1, -// Fragment 0: zone_start = 2, zone_length = 2 // covers rows 2, 3 in fragment 0 -// The row addresses for fragment 0 are: 2, 3 -// Fragment 1: zone_start = 0, zone_length = 4 // covers rows 0, 3 in fragment 1 -// The row addresses for fragment 1 are: 32>>1, 32>>1 + 3 #[derive(Debug, Clone)] struct BloomFilterStatistics { - fragment_id: u64, - // zone_start is start row of the zone in the fragment, also known - // as the local offset. To get the actual first row address, - // you can do `fragment_id << 32 + zone_start` - zone_start: u64, - // zone_length is the `row offset span` between the first and the last row in the current SBBF block - // calculated as: (last_row_offset - first_row_offset + 1) - zone_length: usize, + // Bound of this zone within the fragment. Persisted as three separate columns + // (fragment_id, zone_start, zone_length) in the index file. + bound: ZoneBound, // Whether this zone contains any null values has_null: bool, // The actual bloom filter (SBBF) for efficient querying @@ -88,6 +67,12 @@ impl DeepSizeOf for BloomFilterStatistics { } } +impl AsRef for BloomFilterStatistics { + fn as_ref(&self) -> &ZoneBound { + &self.bound + } +} + #[derive(Debug, Clone)] pub struct BloomFilterIndex { zones: Vec, @@ -246,9 +231,11 @@ impl BloomFilterIndex { })?; blocks.push(BloomFilterStatistics { - fragment_id: fragment_id_col.value(i), - zone_start: zone_start_col.value(i), - zone_length: zone_length_col.value(i) as usize, + bound: ZoneBound { + fragment_id: fragment_id_col.value(i), + start: zone_start_col.value(i), + length: zone_length_col.value(i) as usize, + }, has_null: has_null_col.value(i), bloom_filter, }); @@ -464,7 +451,7 @@ impl Index for BloomFilterIndex { // Loop through zones and add unique fragment IDs to the bitmap for block in &self.zones { - frag_ids.insert(block.fragment_id as u32); + frag_ids.insert(block.bound.fragment_id as u32); } 
Ok(frag_ids) @@ -478,23 +465,10 @@ impl ScalarIndex for BloomFilterIndex { query: &dyn AnyQuery, metrics: &dyn MetricsCollector, ) -> Result { - metrics.record_comparisons(self.zones.len()); let query = query.as_any().downcast_ref::().unwrap(); - - let mut row_addr_tree_map = RowAddrTreeMap::new(); - - // For each zone, check if it might contain the queried value - for block in self.zones.iter() { - if self.evaluate_block_against_query(block, query)? { - let zone_start_addr = (block.fragment_id << 32) + block.zone_start; - let zone_end_addr = zone_start_addr + block.zone_length as u64; - - // Add all row addresses in this zone to the result - row_addr_tree_map.insert_range(zone_start_addr..zone_end_addr); - } - } - - Ok(SearchResult::AtMost(row_addr_tree_map)) + search_zones(&self.zones, metrics, |block| { + self.evaluate_block_against_query(block, query) + }) } fn can_remap(&self) -> bool { @@ -517,33 +491,20 @@ impl ScalarIndex for BloomFilterIndex { new_data: SendableRecordBatchStream, dest_store: &dyn IndexStore, ) -> Result { - // 1. 
Prepare the builder for new bloom filters - let batches_source = new_data; - - let mut builder = BloomFilterIndexBuilder::try_new(BloomFilterIndexBuilderParams { + // Re-train bloom filters for the appended data using the shared trainer + let params = BloomFilterIndexBuilderParams { number_of_items: self.number_of_items, probability: self.probability, - })?; - - builder.train(batches_source).await?; - - // Get the new blocks from the builder - let new_blocks = builder.blocks; - - // Combine existing zones with new zones - let mut all_blocks = self.zones.clone(); - all_blocks.extend(new_blocks); + }; - // Create a new builder with all blocks to write them out - let mut combined_builder = - BloomFilterIndexBuilder::try_new(BloomFilterIndexBuilderParams { - number_of_items: self.number_of_items, - probability: self.probability, - })?; - combined_builder.blocks = all_blocks; + let processor = BloomFilterProcessor::new(params.clone())?; + let trainer = ZoneTrainer::new(processor, params.number_of_items)?; + let updated_blocks = rebuild_zones(&self.zones, trainer, new_data).await?; - // Write the updated index to dest_store - combined_builder.write_index(dest_store).await?; + // Write the combined zones back to storage + let mut builder = BloomFilterIndexBuilder::try_new(params)?; + builder.blocks = updated_blocks; + builder.write_index(dest_store).await?; Ok(CreatedIndex { index_details: prost_types::Any::from_msg(&pb::BloomFilterIndexDetails::default()) @@ -631,38 +592,129 @@ impl BloomFilterIndexBuilderParams { pub struct BloomFilterIndexBuilder { params: BloomFilterIndexBuilderParams, blocks: Vec, - // The local offset within the current zones - cur_zone_offset: usize, - cur_fragment_id: u32, - // Track the actual first and last row offsets in the current zone - // This handles non-contiguous offsets after deletions - cur_zone_first_row_offset: Option, - cur_zone_last_row_offset: Option, - cur_zone_has_null: bool, - sbbf: Option, } impl BloomFilterIndexBuilder { pub 
fn try_new(params: BloomFilterIndexBuilderParams) -> Result { - let sbbf = SbbfBuilder::new() + Ok(Self { + params, + blocks: Vec::new(), + }) + } + + /// Train the builder using the shared ZoneTrainer. The input stream is expected to + /// contain the value column followed by `_rowaddr`, matching the order emitted by + /// the scalar index training pipeline. + pub async fn train(&mut self, batches_source: SendableRecordBatchStream) -> Result<()> { + let processor = BloomFilterProcessor::new(self.params.clone())?; + let trainer = ZoneTrainer::new(processor, self.params.number_of_items)?; + self.blocks = trainer.train(batches_source).await?; + Ok(()) + } + + fn bloomfilter_stats_as_batch(&self) -> Result { + let fragment_ids = + UInt64Array::from_iter_values(self.blocks.iter().map(|block| block.bound.fragment_id)); + + let zone_starts = + UInt64Array::from_iter_values(self.blocks.iter().map(|block| block.bound.start)); + + let zone_lengths = UInt64Array::from_iter_values( + self.blocks.iter().map(|block| block.bound.length as u64), + ); + + let has_nulls = arrow_array::BooleanArray::from( + self.blocks + .iter() + .map(|block| block.has_null) + .collect::>(), + ); + + // Convert bloom filters to binary data for serialization + let bloom_filter_data = if self.blocks.is_empty() { + Arc::new(arrow_array::BinaryArray::new_null(0)) as ArrayRef + } else { + let binary_data: Vec> = self + .blocks + .iter() + .map(|block| block.bloom_filter.to_bytes()) + .collect(); + let binary_refs: Vec> = binary_data + .iter() + .map(|bytes| Some(bytes.as_slice())) + .collect(); + Arc::new(arrow_array::BinaryArray::from_opt_vec(binary_refs)) as ArrayRef + }; + + let schema = Arc::new(arrow_schema::Schema::new(vec![ + Field::new("fragment_id", DataType::UInt64, false), + Field::new("zone_start", DataType::UInt64, false), + Field::new("zone_length", DataType::UInt64, false), + Field::new("has_null", DataType::Boolean, false), + Field::new("bloom_filter_data", DataType::Binary, false), + 
])); + + let columns: Vec = vec![ + Arc::new(fragment_ids) as ArrayRef, + Arc::new(zone_starts) as ArrayRef, + Arc::new(zone_lengths) as ArrayRef, + Arc::new(has_nulls) as ArrayRef, + bloom_filter_data, + ]; + + Ok(RecordBatch::try_new(schema, columns)?) + } + + pub async fn write_index(self, index_store: &dyn IndexStore) -> Result<()> { + let record_batch = self.bloomfilter_stats_as_batch()?; + + let mut file_schema = record_batch.schema().as_ref().clone(); + file_schema.metadata.insert( + BLOOMFILTER_ITEM_META_KEY.to_string(), + self.params.number_of_items.to_string(), + ); + + file_schema.metadata.insert( + BLOOMFILTER_PROBABILITY_META_KEY.to_string(), + self.params.probability.to_string(), + ); + + let mut index_file = index_store + .new_index_file(BLOOMFILTER_FILENAME, Arc::new(file_schema)) + .await?; + index_file.write_record_batch(record_batch).await?; + index_file.finish().await?; + Ok(()) + } +} + +/// Index-specific processor that inserts values into the split block Bloom filter. 
+struct BloomFilterProcessor { + params: BloomFilterIndexBuilderParams, + sbbf: Option, + cur_zone_has_null: bool, +} + +impl BloomFilterProcessor { + fn new(params: BloomFilterIndexBuilderParams) -> Result { + let mut processor = Self { + params, + sbbf: None, + cur_zone_has_null: false, + }; + processor.reset()?; + Ok(processor) + } + + fn build_filter(params: &BloomFilterIndexBuilderParams) -> Result { + SbbfBuilder::new() .expected_items(params.number_of_items) .false_positive_probability(params.probability) .build() .map_err(|e| Error::InvalidInput { source: format!("Failed to build SBBF: {:?}", e).into(), location: location!(), - })?; - - Ok(Self { - params, - blocks: Vec::new(), - cur_zone_offset: 0, - cur_fragment_id: 0, - cur_zone_first_row_offset: None, - cur_zone_last_row_offset: None, - cur_zone_has_null: false, - sbbf: Some(sbbf), - }) + }) } fn process_primitive_array(sbbf: &mut Sbbf, array: &arrow_array::PrimitiveArray) -> bool @@ -728,446 +780,245 @@ impl BloomFilterIndexBuilder { } has_null } +} - fn update_stats(&mut self, array: &ArrayRef) -> Result<()> { - if let Some(ref mut sbbf) = self.sbbf { - let has_null = match array.data_type() { - // Signed integers - DataType::Int8 => { - let typed_array = array - .as_any() - .downcast_ref::() - .unwrap(); - Self::process_primitive_array(sbbf, typed_array) - } - DataType::Int16 => { - let typed_array = array - .as_any() - .downcast_ref::() - .unwrap(); - Self::process_primitive_array(sbbf, typed_array) - } - DataType::Int32 => { +impl ZoneProcessor for BloomFilterProcessor { + type ZoneStatistics = BloomFilterStatistics; + + fn process_chunk(&mut self, array: &ArrayRef) -> Result<()> { + let sbbf = self.sbbf.as_mut().ok_or_else(|| { + Error::invalid_input( + "BloomFilterProcessor did not initialize bloom filter", + location!(), + ) + })?; + + let has_null = match array.data_type() { + // Signed integers + DataType::Int8 => { + let typed_array = array + .as_any() + .downcast_ref::() + .unwrap(); + 
Self::process_primitive_array(sbbf, typed_array) + } + DataType::Int16 => { + let typed_array = array + .as_any() + .downcast_ref::() + .unwrap(); + Self::process_primitive_array(sbbf, typed_array) + } + DataType::Int32 => { + let typed_array = array + .as_any() + .downcast_ref::() + .unwrap(); + Self::process_primitive_array(sbbf, typed_array) + } + DataType::Int64 => { + let typed_array = array + .as_any() + .downcast_ref::() + .unwrap(); + Self::process_primitive_array(sbbf, typed_array) + } + // Unsigned integers + DataType::UInt8 => { + let typed_array = array + .as_any() + .downcast_ref::() + .unwrap(); + Self::process_primitive_array(sbbf, typed_array) + } + DataType::UInt16 => { + let typed_array = array + .as_any() + .downcast_ref::() + .unwrap(); + Self::process_primitive_array(sbbf, typed_array) + } + DataType::UInt32 => { + let typed_array = array + .as_any() + .downcast_ref::() + .unwrap(); + Self::process_primitive_array(sbbf, typed_array) + } + DataType::UInt64 => { + let typed_array = array + .as_any() + .downcast_ref::() + .unwrap(); + Self::process_primitive_array(sbbf, typed_array) + } + // Floating point numbers + DataType::Float32 => { + let typed_array = array + .as_any() + .downcast_ref::() + .unwrap(); + Self::process_primitive_array(sbbf, typed_array) + } + DataType::Float64 => { + let typed_array = array + .as_any() + .downcast_ref::() + .unwrap(); + Self::process_primitive_array(sbbf, typed_array) + } + // Date and time types (stored as i32 internally) + DataType::Date32 => { + let typed_array = array + .as_any() + .downcast_ref::() + .unwrap(); + Self::process_primitive_array(sbbf, typed_array) + } + DataType::Time32(time_unit) => match time_unit { + arrow_schema::TimeUnit::Second => { let typed_array = array .as_any() - .downcast_ref::() + .downcast_ref::() .unwrap(); Self::process_primitive_array(sbbf, typed_array) } - DataType::Int64 => { + arrow_schema::TimeUnit::Millisecond => { let typed_array = array .as_any() - .downcast_ref::() 
+ .downcast_ref::() .unwrap(); Self::process_primitive_array(sbbf, typed_array) } - // Unsigned integers - DataType::UInt8 => { - let typed_array = array - .as_any() - .downcast_ref::() - .unwrap(); - Self::process_primitive_array(sbbf, typed_array) + _ => { + return Err(Error::InvalidInput { + source: format!("Unsupported Time32 unit: {:?}", time_unit).into(), + location: location!(), + }); } - DataType::UInt16 => { + }, + // Date and time types (stored as i64 internally) + DataType::Date64 => { + let typed_array = array + .as_any() + .downcast_ref::() + .unwrap(); + Self::process_primitive_array(sbbf, typed_array) + } + DataType::Time64(time_unit) => match time_unit { + arrow_schema::TimeUnit::Microsecond => { let typed_array = array .as_any() - .downcast_ref::() + .downcast_ref::() .unwrap(); Self::process_primitive_array(sbbf, typed_array) } - DataType::UInt32 => { + arrow_schema::TimeUnit::Nanosecond => { let typed_array = array .as_any() - .downcast_ref::() + .downcast_ref::() .unwrap(); Self::process_primitive_array(sbbf, typed_array) } - DataType::UInt64 => { - let typed_array = array - .as_any() - .downcast_ref::() - .unwrap(); - Self::process_primitive_array(sbbf, typed_array) + _ => { + return Err(Error::InvalidInput { + source: format!("Unsupported Time64 unit: {:?}", time_unit).into(), + location: location!(), + }); } - // Floating point numbers - DataType::Float32 => { + }, + DataType::Timestamp(time_unit, _) => match time_unit { + arrow_schema::TimeUnit::Second => { let typed_array = array .as_any() - .downcast_ref::() + .downcast_ref::() .unwrap(); Self::process_primitive_array(sbbf, typed_array) } - DataType::Float64 => { + arrow_schema::TimeUnit::Millisecond => { let typed_array = array .as_any() - .downcast_ref::() + .downcast_ref::() .unwrap(); Self::process_primitive_array(sbbf, typed_array) } - // Date and time types (stored as i32 internally) - DataType::Date32 => { + arrow_schema::TimeUnit::Microsecond => { let typed_array = array .as_any() 
- .downcast_ref::() + .downcast_ref::() .unwrap(); Self::process_primitive_array(sbbf, typed_array) } - DataType::Time32(time_unit) => match time_unit { - arrow_schema::TimeUnit::Second => { - let typed_array = array - .as_any() - .downcast_ref::() - .unwrap(); - Self::process_primitive_array(sbbf, typed_array) - } - arrow_schema::TimeUnit::Millisecond => { - let typed_array = array - .as_any() - .downcast_ref::() - .unwrap(); - Self::process_primitive_array(sbbf, typed_array) - } - _ => { - return Err(Error::InvalidInput { - source: format!("Unsupported Time32 unit: {:?}", time_unit).into(), - location: location!(), - }); - } - }, - // Date and time types (stored as i64 internally) - DataType::Date64 => { + arrow_schema::TimeUnit::Nanosecond => { let typed_array = array .as_any() - .downcast_ref::() + .downcast_ref::() .unwrap(); Self::process_primitive_array(sbbf, typed_array) } - DataType::Time64(time_unit) => match time_unit { - arrow_schema::TimeUnit::Microsecond => { - let typed_array = array - .as_any() - .downcast_ref::() - .unwrap(); - Self::process_primitive_array(sbbf, typed_array) - } - arrow_schema::TimeUnit::Nanosecond => { - let typed_array = array - .as_any() - .downcast_ref::() - .unwrap(); - Self::process_primitive_array(sbbf, typed_array) - } - _ => { - return Err(Error::InvalidInput { - source: format!("Unsupported Time64 unit: {:?}", time_unit).into(), - location: location!(), - }); - } - }, - DataType::Timestamp(time_unit, _) => match time_unit { - arrow_schema::TimeUnit::Second => { - let typed_array = array - .as_any() - .downcast_ref::() - .unwrap(); - Self::process_primitive_array(sbbf, typed_array) - } - arrow_schema::TimeUnit::Millisecond => { - let typed_array = array - .as_any() - .downcast_ref::() - .unwrap(); - Self::process_primitive_array(sbbf, typed_array) - } - arrow_schema::TimeUnit::Microsecond => { - let typed_array = array - .as_any() - .downcast_ref::() - .unwrap(); - Self::process_primitive_array(sbbf, typed_array) - } - 
arrow_schema::TimeUnit::Nanosecond => { - let typed_array = array - .as_any() - .downcast_ref::() - .unwrap(); - Self::process_primitive_array(sbbf, typed_array) - } - }, - DataType::Utf8 => { - let typed_array = array - .as_any() - .downcast_ref::() - .unwrap(); - Self::process_string_array(sbbf, typed_array) - } - DataType::LargeUtf8 => { - let typed_array = array - .as_any() - .downcast_ref::() - .unwrap(); - Self::process_large_string_array(sbbf, typed_array) - } - DataType::Binary => { - let typed_array = array - .as_any() - .downcast_ref::() - .unwrap(); - Self::process_binary_array(sbbf, typed_array) - } - DataType::LargeBinary => { - let typed_array = array - .as_any() - .downcast_ref::() - .unwrap(); - Self::process_large_binary_array(sbbf, typed_array) - } - _ => { - return Err(Error::InvalidInput { - source: format!( - "Bloom filter does not support data type: {:?}", - array.data_type() - ) - .into(), - location: location!(), - }); - } - }; - - // Update the current zone's null tracking - self.cur_zone_has_null = self.cur_zone_has_null || has_null; - } - - Ok(()) - } - - fn new_block(&mut self, fragment_id: u32) -> Result<()> { - let zone_start = self.cur_zone_first_row_offset.unwrap_or(0) as u64; - let zone_length = self - .cur_zone_last_row_offset - .map(|last_row_offset| { - (last_row_offset - self.cur_zone_first_row_offset.unwrap_or(0) + 1) as usize - }) - .unwrap_or(self.cur_zone_offset); - - // Store the current bloom filter directly - let bloom_filter = if let Some(ref sbbf) = self.sbbf { - sbbf.clone() - } else { - // Create a default empty bloom filter - SbbfBuilder::new() - .expected_items(self.params.number_of_items) - .false_positive_probability(self.params.probability) - .build() - .map_err(|e| Error::InvalidInput { - source: format!("Failed to build default SBBF: {:?}", e).into(), - location: location!(), - })? 
- }; - - let new_block = BloomFilterStatistics { - fragment_id: fragment_id as u64, - zone_start, - zone_length, - has_null: self.cur_zone_has_null, - bloom_filter, - }; - - self.blocks.push(new_block); - self.cur_zone_offset = 0; - self.cur_zone_first_row_offset = None; - self.cur_zone_last_row_offset = None; - self.cur_zone_has_null = false; - - // Reset sbbf for the next block - self.sbbf = Some( - SbbfBuilder::new() - .expected_items(self.params.number_of_items) - .false_positive_probability(self.params.probability) - .build() - .map_err(|e| Error::InvalidInput { - source: format!("Failed to build SBBF: {:?}", e).into(), - location: location!(), - })?, - ); - - Ok(()) - } - - pub async fn train(&mut self, batches_source: SendableRecordBatchStream) -> Result<()> { - assert!(batches_source.schema().field_with_name(ROW_ADDR).is_ok()); - - let mut batches_source = - chunk_concat_stream(batches_source, self.params.number_of_items as usize); - - while let Some(batch) = batches_source.try_next().await? 
{ - if batch.num_rows() == 0 { - continue; + }, + DataType::Utf8 => { + let typed_array = array + .as_any() + .downcast_ref::() + .unwrap(); + Self::process_string_array(sbbf, typed_array) } - - let data_array: &arrow_array::ArrayRef = batch.column(0); - let row_addrs_array = batch - .column_by_name(ROW_ADDR) - .unwrap() - .as_any() - .downcast_ref::() - .unwrap(); - - let mut remaining = batch.num_rows(); - let mut array_offset: usize = 0; - - // Initialize cur_fragment_id from the first row address if this is the first batch - if self.blocks.is_empty() && self.cur_zone_offset == 0 { - let first_row_addr = row_addrs_array.value(0); - self.cur_fragment_id = (first_row_addr >> 32) as u32; + DataType::LargeUtf8 => { + let typed_array = array + .as_any() + .downcast_ref::() + .unwrap(); + Self::process_large_string_array(sbbf, typed_array) } - - while remaining > 0 { - // Find the next fragment boundary in this batch - let next_fragment_index = (array_offset..row_addrs_array.len()).find(|&i| { - let row_addr = row_addrs_array.value(i); - let fragment_id = (row_addr >> 32) as u32; - fragment_id == self.cur_fragment_id + 1 - }); - let empty_rows_left_in_cur_zone: usize = - (self.params.number_of_items - self.cur_zone_offset as u64) as usize; - - // Check if there is enough data from the current fragment to fill the current zone - let desired = if let Some(idx) = next_fragment_index { - self.cur_fragment_id = (row_addrs_array.value(idx) >> 32) as u32; - // Take the minimum between distance to boundary and space left in zone - // to ensure we don't exceed the zone size limit - std::cmp::min(idx - array_offset, empty_rows_left_in_cur_zone) - } else { - empty_rows_left_in_cur_zone - }; - - if desired > remaining { - // Not enough data to fill a map, just increment counts - self.update_stats(&data_array.slice(array_offset, remaining))?; - - let first_row_offset = - RowAddress::new_from_u64(row_addrs_array.value(array_offset)).row_offset(); - let last_row_offset = 
RowAddress::new_from_u64( - row_addrs_array.value(array_offset + remaining - 1), + DataType::Binary => { + let typed_array = array + .as_any() + .downcast_ref::() + .unwrap(); + Self::process_binary_array(sbbf, typed_array) + } + DataType::LargeBinary => { + let typed_array = array + .as_any() + .downcast_ref::() + .unwrap(); + Self::process_large_binary_array(sbbf, typed_array) + } + _ => { + return Err(Error::InvalidInput { + source: format!( + "Bloom filter does not support data type: {:?}", + array.data_type() ) - .row_offset(); - if self.cur_zone_first_row_offset.is_none() { - self.cur_zone_first_row_offset = Some(first_row_offset); - } - self.cur_zone_last_row_offset = Some(last_row_offset); - - self.cur_zone_offset += remaining; - break; - } else if desired > 0 { - // There is enough data, create a new zone - self.update_stats(&data_array.slice(array_offset, desired))?; - - let first_row_offset = - RowAddress::new_from_u64(row_addrs_array.value(array_offset)).row_offset(); - let last_row_offset = - RowAddress::new_from_u64(row_addrs_array.value(array_offset + desired - 1)) - .row_offset(); - if self.cur_zone_first_row_offset.is_none() { - self.cur_zone_first_row_offset = Some(first_row_offset); - } - self.cur_zone_last_row_offset = Some(last_row_offset); - - self.cur_zone_offset += desired; - self.new_block((row_addrs_array.value(array_offset) >> 32) as u32)?; - } else if desired == 0 { - // The new batch starts with a new fragment. 
Flush the current zone if it's not empty - if self.cur_zone_offset > 0 { - self.new_block(self.cur_fragment_id.wrapping_sub(1))?; - } - // Let the loop run again - // to find the next fragment boundary - continue; - } - array_offset += desired; - remaining = remaining.saturating_sub(desired); + .into(), + location: location!(), + }); } - } - // Create the final zone - if self.cur_zone_offset > 0 { - self.new_block(self.cur_fragment_id)?; - } + }; + // Update the current zone's null tracking + self.cur_zone_has_null = self.cur_zone_has_null || has_null; Ok(()) } - fn bloomfilter_stats_as_batch(&self) -> Result { - let fragment_ids = - UInt64Array::from_iter_values(self.blocks.iter().map(|block| block.fragment_id)); - - let zone_starts = - UInt64Array::from_iter_values(self.blocks.iter().map(|block| block.zone_start)); - - let zone_lengths = - UInt64Array::from_iter_values(self.blocks.iter().map(|block| block.zone_length as u64)); - - let has_nulls = arrow_array::BooleanArray::from( - self.blocks - .iter() - .map(|block| block.has_null) - .collect::>(), - ); - - // Convert bloom filters to binary data for serialization - let bloom_filter_data = if self.blocks.is_empty() { - Arc::new(arrow_array::BinaryArray::new_null(0)) as ArrayRef - } else { - let binary_data: Vec> = self - .blocks - .iter() - .map(|block| block.bloom_filter.to_bytes()) - .collect(); - let binary_refs: Vec> = binary_data - .iter() - .map(|bytes| Some(bytes.as_slice())) - .collect(); - Arc::new(arrow_array::BinaryArray::from_opt_vec(binary_refs)) as ArrayRef - }; - - let schema = Arc::new(arrow_schema::Schema::new(vec![ - Field::new("fragment_id", DataType::UInt64, false), - Field::new("zone_start", DataType::UInt64, false), - Field::new("zone_length", DataType::UInt64, false), - Field::new("has_null", DataType::Boolean, false), - Field::new("bloom_filter_data", DataType::Binary, false), - ])); - - let columns: Vec = vec![ - Arc::new(fragment_ids) as ArrayRef, - Arc::new(zone_starts) as ArrayRef, - 
Arc::new(zone_lengths) as ArrayRef, - Arc::new(has_nulls) as ArrayRef, - bloom_filter_data, - ]; - - Ok(RecordBatch::try_new(schema, columns)?) + fn finish_zone(&mut self, bound: ZoneBound) -> Result { + let bloom_filter = self.sbbf.as_ref().ok_or_else(|| { + Error::invalid_input( + "BloomFilterProcessor did not initialize bloom filter", + location!(), + ) + })?; + Ok(BloomFilterStatistics { + bound, + has_null: self.cur_zone_has_null, + bloom_filter: bloom_filter.clone(), + }) } - pub async fn write_index(self, index_store: &dyn IndexStore) -> Result<()> { - let record_batch = self.bloomfilter_stats_as_batch()?; - - let mut file_schema = record_batch.schema().as_ref().clone(); - file_schema.metadata.insert( - BLOOMFILTER_ITEM_META_KEY.to_string(), - self.params.number_of_items.to_string(), - ); - - file_schema.metadata.insert( - BLOOMFILTER_PROBABILITY_META_KEY.to_string(), - self.params.probability.to_string(), - ); - - let mut index_file = index_store - .new_index_file(BLOOMFILTER_FILENAME, Arc::new(file_schema)) - .await?; - index_file.write_record_batch(record_batch).await?; - index_file.finish().await?; + fn reset(&mut self) -> Result<()> { + self.sbbf = Some(Self::build_filter(&self.params)?); + self.cur_zone_has_null = false; Ok(()) } } @@ -1479,9 +1330,9 @@ mod tests { assert_eq!(index.probability, 0.01); // Check that we have one zone (since 100 items fit exactly in one zone of size 100) - assert_eq!(index.zones[0].fragment_id, 0u64); - assert_eq!(index.zones[0].zone_start, 0u64); - assert_eq!(index.zones[0].zone_length, 100); + assert_eq!(index.zones[0].bound.fragment_id, 0u64); + assert_eq!(index.zones[0].bound.start, 0u64); + assert_eq!(index.zones[0].bound.length, 100); // Test search functionality // The bloom filter should work correctly and find the value @@ -1560,22 +1411,22 @@ mod tests { assert_eq!(index.zones.len(), 4); // Check fragment 0 zones - assert_eq!(index.zones[0].fragment_id, 0u64); - assert_eq!(index.zones[0].zone_start, 0u64); - 
assert_eq!(index.zones[0].zone_length, 50); + assert_eq!(index.zones[0].bound.fragment_id, 0u64); + assert_eq!(index.zones[0].bound.start, 0u64); + assert_eq!(index.zones[0].bound.length, 50); - assert_eq!(index.zones[1].fragment_id, 0u64); - assert_eq!(index.zones[1].zone_start, 50u64); - assert_eq!(index.zones[1].zone_length, 50); + assert_eq!(index.zones[1].bound.fragment_id, 0u64); + assert_eq!(index.zones[1].bound.start, 50u64); + assert_eq!(index.zones[1].bound.length, 50); // Check fragment 1 zones - assert_eq!(index.zones[2].fragment_id, 1u64); - assert_eq!(index.zones[2].zone_start, 0u64); - assert_eq!(index.zones[2].zone_length, 50); + assert_eq!(index.zones[2].bound.fragment_id, 1u64); + assert_eq!(index.zones[2].bound.start, 0u64); + assert_eq!(index.zones[2].bound.length, 50); - assert_eq!(index.zones[3].fragment_id, 1u64); - assert_eq!(index.zones[3].zone_start, 50u64); - assert_eq!(index.zones[3].zone_length, 50); + assert_eq!(index.zones[3].bound.fragment_id, 1u64); + assert_eq!(index.zones[3].bound.start, 50u64); + assert_eq!(index.zones[3].bound.length, 50); // Test search functionality let query = BloomFilterQuery::Equals(ScalarValue::Int64(Some(150))); @@ -1736,9 +1587,9 @@ mod tests { // Verify zone structure for (i, block) in index.zones.iter().enumerate() { - assert_eq!(block.fragment_id, 0u64); - assert_eq!(block.zone_start, (i * 1000) as u64); - assert_eq!(block.zone_length, 1000); + assert_eq!(block.bound.fragment_id, 0u64); + assert_eq!(block.bound.start, (i * 1000) as u64); + assert_eq!(block.bound.length, 1000); // Check that the bloom filter has some data (non-zero bytes when serialized) assert!(!block.bloom_filter.to_bytes().is_empty()); } diff --git a/rust/lance-index/src/scalar/zoned.rs b/rust/lance-index/src/scalar/zoned.rs new file mode 100644 index 00000000000..02ef1098ee0 --- /dev/null +++ b/rust/lance-index/src/scalar/zoned.rs @@ -0,0 +1,858 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright The 
Lance Authors + +//! Shared Zone Training Utilities +//! +//! This module provides common infrastructure for building zone-based scalar indexes. +//! It handles chunking data streams into fixed-size zones while respecting fragment +//! boundaries and computing zone bounds that remain valid after row deletions. + +use arrow_array::{ArrayRef, UInt64Array}; +use datafusion::execution::SendableRecordBatchStream; +use futures::TryStreamExt; +use lance_core::error::Error; +use lance_core::utils::address::RowAddress; +use lance_core::utils::mask::RowAddrTreeMap; +use lance_core::{Result, ROW_ADDR}; +use lance_datafusion::chunker::chunk_concat_stream; +use snafu::location; + +// +// Example: Suppose we have two fragments, each with 4 rows. +// Fragment 0: start = 0, length = 4 // covers rows 0, 1, 2, 3 in fragment 0 +// The row addresses for fragment 0 are: 0, 1, 2, 3 +// Fragment 1: start = 0, length = 4 // covers rows 0, 1, 2, 3 in fragment 1 +// The row addresses for fragment 1 are: (1<<32), (1<<32)+1, (1<<32)+2, (1<<32)+3 +// +// Deletion is 0 index based. We delete the 0th and 1st row in fragment 0, +// and the 1st and 2nd row in fragment 1, +// Fragment 0: start = 2, length = 2 // covers rows 2, 3 in fragment 0 +// The row addresses for fragment 0 are: 2, 3 +// Fragment 1: start = 0, length = 4 // covers rows 0, 3 in fragment 1 +// The row addresses for fragment 1 are: (1<<32), (1<<32)+3 +/// Zone bound within a fragment +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct ZoneBound { + pub fragment_id: u64, + // start is start row of the zone in the fragment, also known + // as the local offset. To get the actual first row address, + // use `(fragment_id << 32) | start`. + pub start: u64, + // length is the span of row offsets between the first and last row in the zone, + // calculated as (last_row_offset - first_row_offset + 1). It is not the count + // of physical rows, since deletions may create gaps within the span. 
+    pub length: usize,
+}
+
+/// Index-specific logic used while building zones.
+pub trait ZoneProcessor {
+    type ZoneStatistics;
+
+    /// Process a slice of values that belongs to the current zone.
+    fn process_chunk(&mut self, values: &ArrayRef) -> Result<()>;
+
+    /// Emit statistics when the zone is full or the fragment changes.
+    fn finish_zone(&mut self, bound: ZoneBound) -> Result<Self::ZoneStatistics>;
+
+    /// Reset state so the processor can handle the next zone.
+    fn reset(&mut self) -> Result<()>;
+}
+
+/// Trainer that handles chunking, fragment boundaries, and zone flushing.
+#[derive(Debug)]
+pub struct ZoneTrainer<P> {
+    processor: P,
+    zone_capacity: u64,
+}
+
+impl<P> ZoneTrainer<P>
+where
+    P: ZoneProcessor,
+{
+    /// Create a new trainer that buffers at most `zone_capacity` rows per zone.
+    pub fn new(processor: P, zone_capacity: u64) -> Result<Self> {
+        if zone_capacity == 0 {
+            return Err(Error::invalid_input(
+                "zone capacity must be greater than zero",
+                location!(),
+            ));
+        }
+        Ok(Self {
+            processor,
+            zone_capacity,
+        })
+    }
+
+    /// Consume the `_rowaddr`-annotated stream, split it into zones, and let the
+    /// processor compute zone statistics.
+    ///
+    /// The caller must provide record batches where the first column is the
+    /// value array that the zone processor understands, and the schema includes
+    /// the `_rowaddr` column with physical row addresses. Future zone-based
+    /// indexes should maintain this ordering or extend the trainer to accept an
+    /// explicit column index.
+    pub async fn train(
+        mut self,
+        stream: SendableRecordBatchStream,
+    ) -> Result<Vec<P::ZoneStatistics>> {
+        let zone_size = usize::try_from(self.zone_capacity).map_err(|_| {
+            Error::invalid_input(
+                "zone capacity does not fit into usize on this platform",
+                location!(),
+            )
+        })?;
+
+        let mut batches = chunk_concat_stream(stream, zone_size);
+        let mut zones = Vec::new();
+        let mut current_fragment_id: Option<u64> = None;
+        let mut current_zone_len: usize = 0;
+        let mut zone_start_offset: Option<u64> = None;
+        let mut zone_end_offset: Option<u64> = None;
+
+        self.processor.reset()?;
+
+        while let Some(batch) = batches.try_next().await?
{
+            if batch.num_rows() == 0 {
+                continue;
+            }
+
+            let values = batch.column(0);
+            let row_addr_col = batch
+                .column_by_name(ROW_ADDR)
+                .unwrap()
+                .as_any()
+                .downcast_ref::<UInt64Array>()
+                .unwrap();
+
+            let mut batch_offset = 0usize;
+            while batch_offset < batch.num_rows() {
+                let row_addr = row_addr_col.value(batch_offset);
+                let fragment_id = row_addr >> 32;
+
+                // Zones cannot span fragments; flush current zone (if non-empty) at boundary
+                match current_fragment_id {
+                    Some(current) if current != fragment_id => {
+                        if current_zone_len > 0 {
+                            Self::flush_zone(
+                                &mut self.processor,
+                                &mut zones,
+                                current,
+                                &mut current_zone_len,
+                                &mut zone_start_offset,
+                                &mut zone_end_offset,
+                            )?;
+                        }
+                        current_fragment_id = Some(fragment_id);
+                    }
+                    None => {
+                        current_fragment_id = Some(fragment_id);
+                    }
+                    _ => {}
+                }
+
+                // Count consecutive rows in the same fragment
+                let run_len = (batch_offset..batch.num_rows())
+                    .take_while(|&idx| (row_addr_col.value(idx) >> 32) == fragment_id)
+                    .count();
+                let capacity = zone_size - current_zone_len;
+                let take = run_len.min(capacity);
+
+                self.processor
+                    .process_chunk(&values.slice(batch_offset, take))?;
+
+                // Track the first and last row offsets to handle non-contiguous offsets
+                // after deletions. Zone length (offset span) is computed as (last - first + 1),
+                // not the actual row count.
+                let first_offset =
+                    RowAddress::new_from_u64(row_addr_col.value(batch_offset)).row_offset() as u64;
+                let last_offset =
+                    RowAddress::new_from_u64(row_addr_col.value(batch_offset + take - 1))
+                        .row_offset() as u64;
+
+                if zone_start_offset.is_none() {
+                    zone_start_offset = Some(first_offset);
+                }
+                zone_end_offset = Some(last_offset);
+
+                current_zone_len += take;
+                batch_offset += take;
+
+                if current_zone_len == zone_size {
+                    Self::flush_zone(
+                        &mut self.processor,
+                        &mut zones,
+                        fragment_id,
+                        &mut current_zone_len,
+                        &mut zone_start_offset,
+                        &mut zone_end_offset,
+                    )?;
+                }
+            }
+        }
+
+        if current_zone_len > 0 {
+            if let Some(fragment_id) = current_fragment_id {
+                Self::flush_zone(
+                    &mut self.processor,
+                    &mut zones,
+                    fragment_id,
+                    &mut current_zone_len,
+                    &mut zone_start_offset,
+                    &mut zone_end_offset,
+                )?;
+            } else {
+                self.processor.reset()?;
+            }
+        }
+
+        Ok(zones)
+    }
+
+    /// Flushes a non-empty zone and resets the processor state.
+    fn flush_zone(
+        processor: &mut P,
+        zones: &mut Vec<P::ZoneStatistics>,
+        fragment_id: u64,
+        current_zone_len: &mut usize,
+        zone_start_offset: &mut Option<u64>,
+        zone_end_offset: &mut Option<u64>,
+    ) -> Result<()> {
+        let start = zone_start_offset.unwrap_or(0);
+        let inferred_end =
+            zone_end_offset.unwrap_or_else(|| start + (*current_zone_len as u64).saturating_sub(1));
+        if inferred_end < start {
+            return Err(Error::invalid_input(
+                "zone row offsets are out of order",
+                location!(),
+            ));
+        }
+        let bound = ZoneBound {
+            fragment_id,
+            start,
+            length: (inferred_end - start + 1) as usize,
+        };
+        let stats = processor.finish_zone(bound)?;
+        zones.push(stats);
+        *current_zone_len = 0;
+        *zone_start_offset = None;
+        *zone_end_offset = None;
+        processor.reset()?;
+        Ok(())
+    }
+}
+
+/// Shared search helper that loops over zones, records metrics, and
+/// collects row address ranges for matching zones. The result is always
+/// returned as `SearchResult::AtMost` because zone-level pruning can only
+/// guarantee a superset of the true matches.
+pub fn search_zones<T, F>(
+    zones: &[T],
+    metrics: &dyn crate::metrics::MetricsCollector,
+    mut zone_matches: F,
+) -> Result<crate::scalar::SearchResult>
+where
+    T: AsRef<ZoneBound>,
+    F: FnMut(&T) -> Result<bool>,
+{
+    metrics.record_comparisons(zones.len());
+    let mut row_addr_tree_map = RowAddrTreeMap::new();
+
+    // For each zone, check if it might contain the queried value
+    for zone in zones {
+        if zone_matches(zone)? {
+            let bound = zone.as_ref();
+            // Calculate the range of row addresses for this zone
+            let zone_start_addr = (bound.fragment_id << 32) + bound.start;
+            let zone_end_addr = zone_start_addr + bound.length as u64;
+
+            // Add all row addresses in this zone to the result
+            row_addr_tree_map.insert_range(zone_start_addr..zone_end_addr);
+        }
+    }
+
+    Ok(crate::scalar::SearchResult::AtMost(row_addr_tree_map))
+}
+
+/// Helper that retrains zones from `stream` and appends them to the existing
+/// statistics. Useful for index update paths that need to merge new fragments
+/// into an existing zone list.
+pub async fn rebuild_zones<P>(
+    existing: &[P::ZoneStatistics],
+    trainer: ZoneTrainer<P>,
+    stream: SendableRecordBatchStream,
+) -> Result<Vec<P::ZoneStatistics>>
+where
+    P: ZoneProcessor,
+    P::ZoneStatistics: Clone,
+{
+    let mut combined = existing.to_vec();
+    let mut new_zones = trainer.train(stream).await?;
+    combined.append(&mut new_zones);
+    Ok(combined)
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use crate::{metrics::LocalMetricsCollector, scalar::SearchResult};
+    use arrow_array::{ArrayRef, Int32Array, RecordBatch, UInt64Array};
+    use arrow_schema::{DataType, Field, Schema};
+    use datafusion::physical_plan::stream::RecordBatchStreamAdapter;
+    use futures::stream;
+    use lance_core::ROW_ADDR;
+    use std::sync::Arc;
+
+    #[derive(Debug, Clone, PartialEq)]
+    struct MockStats {
+        sum: i32,
+        bound: ZoneBound,
+    }
+
+    #[derive(Debug)]
+    struct MockProcessor {
+        current_sum: i32,
+    }
+
+    impl MockProcessor {
+        fn new() -> Self {
+            Self { current_sum: 0 }
+        }
+    }
+
+    impl ZoneProcessor for MockProcessor {
+        type ZoneStatistics = MockStats;
+
+        fn process_chunk(&mut self, values: &ArrayRef) -> Result<()> {
+            let arr = values.as_any().downcast_ref::<Int32Array>().unwrap();
+            self.current_sum += arr.iter().map(|v| v.unwrap_or(0)).sum::<i32>();
+            Ok(())
+        }
+
+        fn finish_zone(&mut self, bound: ZoneBound) -> Result<Self::ZoneStatistics> {
+            Ok(MockStats {
+                sum: self.current_sum,
+                bound,
+            })
+        }
+
+        fn reset(&mut self) -> Result<()> {
+            self.current_sum = 0;
+            Ok(())
+        }
+    }
+
+    fn batch(values: Vec<i32>, fragments: Vec<u64>, offsets: Vec<u64>) -> RecordBatch {
+        let val_array = Arc::new(Int32Array::from(values));
+        let row_addrs: Vec<u64> = fragments
+            .into_iter()
+            .zip(offsets)
+            .map(|(frag, off)| (frag << 32) | off)
+            .collect();
+        let addr_array = Arc::new(UInt64Array::from(row_addrs));
+        let schema = Arc::new(Schema::new(vec![
+            Field::new("value", DataType::Int32, false),
+            Field::new(ROW_ADDR, DataType::UInt64, false),
+        ]));
+        RecordBatch::try_new(schema, vec![val_array, addr_array]).unwrap()
+    }
+
+    #[tokio::test]
+    async fn splits_single_fragment() {
+        // Single fragment with 10 rows, zone capacity = 4.
+        // Expect three zones with lengths [4, 4, 2].
+        let values = vec![1; 10];
+        let offsets: Vec<u64> = (0..10).collect();
+        let batch = batch(values, vec![0; 10], offsets);
+        let stream = Box::pin(RecordBatchStreamAdapter::new(
+            batch.schema(),
+            stream::once(async { Ok(batch) }),
+        ));
+
+        let processor = MockProcessor::new();
+        let trainer = ZoneTrainer::new(processor, 4).unwrap();
+        let stats = trainer.train(stream).await.unwrap();
+
+        // Three zones: offsets [0..=3], [4..=7], [8..=9]
+        assert_eq!(stats.len(), 3);
+        assert_eq!(stats[0].bound.start, 0);
+        assert_eq!(stats[0].bound.length, 4);
+        assert_eq!(stats[1].bound.start, 4);
+        assert_eq!(stats[1].bound.length, 4);
+        assert_eq!(stats[2].bound.start, 8);
+        assert_eq!(stats[2].bound.length, 2); // Last zone has only 2 rows
+        assert_eq!(
+            stats.iter().map(|s| s.sum).collect::<Vec<_>>(),
+            vec![4, 4, 2]
+        );
+    }
+
+    #[tokio::test]
+    async fn flushes_on_fragment_boundary() {
+        // Two fragments back to back, capacity is large enough that only fragment
+        // boundaries cause zone flushes. Expect two zones (one per fragment).
+        let values = vec![1, 1, 1, 2, 2, 2];
+        let fragments = vec![0, 0, 0, 1, 1, 1];
+        let offsets = vec![0, 1, 2, 0, 1, 2];
+        let batch = batch(values, fragments, offsets);
+        let stream = Box::pin(RecordBatchStreamAdapter::new(
+            batch.schema(),
+            stream::once(async { Ok(batch) }),
+        ));
+
+        let processor = MockProcessor::new();
+        let trainer = ZoneTrainer::new(processor, 10).unwrap();
+        let stats = trainer.train(stream).await.unwrap();
+
+        // Two zones, one per fragment (capacity=10 is large enough)
+        assert_eq!(stats.len(), 2);
+        assert_eq!(stats[0].bound.fragment_id, 0);
+        assert_eq!(stats[0].bound.length, 3); // Fragment 0: offsets 0,1,2 → length = 2-0+1 = 3
+        assert_eq!(stats[1].bound.fragment_id, 1);
+        assert_eq!(stats[1].bound.length, 3); // Fragment 1: offsets 0,1,2 → length = 2-0+1 = 3
+    }
+
+    #[tokio::test]
+    async fn errors_on_out_of_order_offsets() {
+        // Offsets go backwards (5 -> 3).
Trainer should treat this as invalid input + // rather than silently emitting a zero-length zone. + let values = vec![1, 2, 3]; + let fragments = vec![0, 0, 0]; + let offsets = vec![5, 3, 4]; + let batch = batch(values, fragments, offsets); + let stream = Box::pin(RecordBatchStreamAdapter::new( + batch.schema(), + stream::once(async { Ok(batch) }), + )); + + let processor = MockProcessor::new(); + let trainer = ZoneTrainer::new(processor, 10).unwrap(); + let err = trainer.train(stream).await.unwrap_err(); + assert!( + format!("{}", err).contains("zone row offsets are out of order"), + "unexpected error: {err:?}" + ); + } + + #[tokio::test] + async fn handles_empty_batches() { + // Empty batches in the stream should be properly skipped without affecting zones. + let schema = Arc::new(Schema::new(vec![ + Field::new("value", DataType::Int32, false), + Field::new(ROW_ADDR, DataType::UInt64, false), + ])); + + let empty_batch = RecordBatch::new_empty(schema.clone()); + let valid_batch = batch(vec![1, 2, 3], vec![0, 0, 0], vec![0, 1, 2]); + + let stream = Box::pin(RecordBatchStreamAdapter::new( + schema, + stream::iter(vec![ + Ok(empty_batch.clone()), + Ok(valid_batch), + Ok(empty_batch), + ]), + )); + + let processor = MockProcessor::new(); + let trainer = ZoneTrainer::new(processor, 10).unwrap(); + let stats = trainer.train(stream).await.unwrap(); + + // One zone containing the 3 valid rows (empty batches skipped) + assert_eq!(stats.len(), 1); + assert_eq!(stats[0].sum, 6); + assert_eq!(stats[0].bound.fragment_id, 0); + assert_eq!(stats[0].bound.length, 3); + } + + #[tokio::test] + async fn handles_zone_capacity_one() { + // Each row becomes its own zone when capacity is 1. 
+        let values = vec![10, 20, 30];
+        let offsets = vec![0, 1, 2];
+        let batch = batch(values.clone(), vec![0, 0, 0], offsets.clone());
+        let stream = Box::pin(RecordBatchStreamAdapter::new(
+            batch.schema(),
+            stream::once(async { Ok(batch) }),
+        ));
+
+        let processor = MockProcessor::new();
+        let trainer = ZoneTrainer::new(processor, 1).unwrap();
+        let stats = trainer.train(stream).await.unwrap();
+
+        // Three zones, one per row (capacity=1)
+        assert_eq!(stats.len(), 3);
+        for (i, stat) in stats.iter().enumerate() {
+            assert_eq!(stat.bound.fragment_id, 0);
+            assert_eq!(stat.bound.start, offsets[i]);
+            assert_eq!(stat.bound.length, 1); // Each zone contains exactly one row
+            assert_eq!(stat.sum, values[i]);
+        }
+    }
+
+    #[tokio::test]
+    async fn handles_large_capacity() {
+        // When capacity >> data size, all data fits in one zone.
+        let values = vec![1; 100];
+        let offsets: Vec<u64> = (0..100).collect();
+        let batch = batch(values, vec![0; 100], offsets);
+        let stream = Box::pin(RecordBatchStreamAdapter::new(
+            batch.schema(),
+            stream::once(async { Ok(batch) }),
+        ));
+
+        let processor = MockProcessor::new();
+        let trainer = ZoneTrainer::new(processor, 10000).unwrap();
+        let stats = trainer.train(stream).await.unwrap();
+
+        // One zone containing all 100 rows (capacity is large enough)
+        assert_eq!(stats.len(), 1);
+        assert_eq!(stats[0].sum, 100);
+        assert_eq!(stats[0].bound.start, 0);
+        assert_eq!(stats[0].bound.length, 100);
+    }
+
+    #[tokio::test]
+    async fn rejects_zero_capacity() {
+        let processor = MockProcessor::new();
+        let result = ZoneTrainer::new(processor, 0);
+        assert!(result.is_err());
+        assert!(result
+            .unwrap_err()
+            .to_string()
+            .contains("zone capacity must be greater than zero"));
+    }
+
+    #[tokio::test]
+    async fn handles_multiple_batches_same_fragment() {
+        // Multiple batches from the same fragment should be properly accumulated into zones.
+ let b1 = batch(vec![1, 1], vec![0, 0], vec![0, 1]); + let b2 = batch(vec![1, 1], vec![0, 0], vec![2, 3]); + let b3 = batch(vec![1, 1], vec![0, 0], vec![4, 5]); + + let stream = Box::pin(RecordBatchStreamAdapter::new( + b1.schema(), + stream::iter(vec![Ok(b1), Ok(b2), Ok(b3)]), + )); + + let processor = MockProcessor::new(); + let trainer = ZoneTrainer::new(processor, 4).unwrap(); + let stats = trainer.train(stream).await.unwrap(); + + // Two zones: first 4 rows, then remaining 2 rows + assert_eq!(stats.len(), 2); + // First zone: offsets [0..=3] + assert_eq!(stats[0].bound.fragment_id, 0); + assert_eq!(stats[0].bound.start, 0); + assert_eq!(stats[0].bound.length, 4); + assert_eq!(stats[0].sum, 4); + // Second zone: offsets [4..=5] + assert_eq!(stats[1].bound.fragment_id, 0); + assert_eq!(stats[1].bound.start, 4); + assert_eq!(stats[1].bound.length, 2); + assert_eq!(stats[1].sum, 2); + } + + #[tokio::test] + async fn handles_multi_batch_with_fragment_change() { + // Complex scenario: multiple batches with fragment changes mid-batch. + // This tests that zones flush correctly at fragment boundaries. 
+ let b1 = batch(vec![1, 1], vec![0, 0], vec![0, 1]); + // b2 has fragment change: starts with frag 0, switches to frag 1 + let b2 = batch(vec![1, 1, 2, 2], vec![0, 0, 1, 1], vec![2, 3, 0, 1]); + + let stream = Box::pin(RecordBatchStreamAdapter::new( + b1.schema(), + stream::iter(vec![Ok(b1), Ok(b2)]), + )); + + let processor = MockProcessor::new(); + let trainer = ZoneTrainer::new(processor, 3).unwrap(); + let stats = trainer.train(stream).await.unwrap(); + + // Three zones: frag 0 full zone, frag 0 partial (flushed at boundary), frag 1 + assert_eq!(stats.len(), 3); + + // Zone 0: Fragment 0, offsets [0..=2] (fills capacity) + assert_eq!(stats[0].bound.fragment_id, 0); + assert_eq!(stats[0].bound.start, 0); + assert_eq!(stats[0].bound.length, 3); + assert_eq!(stats[0].sum, 3); + + // Zone 1: Fragment 0, offset 3 (partial, flushed at fragment boundary) + assert_eq!(stats[1].bound.fragment_id, 0); + assert_eq!(stats[1].bound.start, 3); + assert_eq!(stats[1].bound.length, 1); + assert_eq!(stats[1].sum, 1); + + // Zone 2: Fragment 1, offsets [0..=1] + assert_eq!(stats[2].bound.fragment_id, 1); + assert_eq!(stats[2].bound.start, 0); + assert_eq!(stats[2].bound.length, 2); + assert_eq!(stats[2].sum, 4); + } + + #[tokio::test] + async fn handles_non_contiguous_offsets_after_deletion() { + // CRITICAL: Test deletion scenario with non-contiguous row offsets. + // This is the main reason for tracking first/last offsets. + // Simulate a zone where rows 2, 3, 4, 6 have been deleted. + let values = vec![1, 1, 1, 1, 1, 1]; // 6 actual rows + let fragments = vec![0, 0, 0, 0, 0, 0]; + let offsets = vec![0, 1, 5, 7, 8, 9]; // Non-contiguous! 
+ + let batch = batch(values, fragments, offsets); + let stream = Box::pin(RecordBatchStreamAdapter::new( + batch.schema(), + stream::once(async { Ok(batch) }), + )); + + let processor = MockProcessor::new(); + let trainer = ZoneTrainer::new(processor, 4).unwrap(); + let stats = trainer.train(stream).await.unwrap(); + + // Should create 2 zones (capacity=4): + // Zone 0: rows at offsets [0, 1, 5, 7] (4 rows) + // Zone 1: rows at offsets [8, 9] (2 rows) + assert_eq!(stats.len(), 2); + + // First zone: 4 rows, but offset span is [0..=7] so length=8 (due to gaps) + assert_eq!(stats[0].sum, 4); + assert_eq!(stats[0].bound.fragment_id, 0); + assert_eq!(stats[0].bound.start, 0); + assert_eq!(stats[0].bound.length, 8); // Address span: 7 - 0 + 1 + + // Second zone: 2 rows, offset span is [8..=9] so length=2 + assert_eq!(stats[1].sum, 2); + assert_eq!(stats[1].bound.fragment_id, 0); + assert_eq!(stats[1].bound.start, 8); + assert_eq!(stats[1].bound.length, 2); // Address span: 9 - 8 + 1 + } + + #[tokio::test] + async fn handles_deletion_with_large_gaps() { + // Extreme deletion scenario: very large gaps between consecutive rows. + let values = vec![1, 1, 1]; + let fragments = vec![0, 0, 0]; + let offsets = vec![0, 100, 200]; // Huge gaps! + + let batch = batch(values, fragments, offsets); + let stream = Box::pin(RecordBatchStreamAdapter::new( + batch.schema(), + stream::once(async { Ok(batch) }), + )); + + let processor = MockProcessor::new(); + let trainer = ZoneTrainer::new(processor, 10).unwrap(); + let stats = trainer.train(stream).await.unwrap(); + + // One zone with 3 rows, but offset span [0..=200] so length=201 due to large gaps + assert_eq!(stats.len(), 1); + assert_eq!(stats[0].sum, 3); + assert_eq!(stats[0].bound.start, 0); + assert_eq!(stats[0].bound.length, 201); // Span: 200 - 0 + 1 + } + + #[tokio::test] + async fn handles_non_contiguous_fragment_ids() { + // CRITICAL: Test fragment IDs that are not consecutive (e.g., after fragment deletion). 
+ // Original code assumed fragment_id + 1, which would fail here. + // Fragment IDs: 0, 5, 10 (non-consecutive!) + let values = vec![1, 1, 2, 2, 3, 3]; + let fragments = vec![0, 0, 5, 5, 10, 10]; // Gaps in fragment IDs + let offsets = vec![0, 1, 0, 1, 0, 1]; + + let batch = batch(values, fragments, offsets); + let stream = Box::pin(RecordBatchStreamAdapter::new( + batch.schema(), + stream::once(async { Ok(batch) }), + )); + + let processor = MockProcessor::new(); + let trainer = ZoneTrainer::new(processor, 10).unwrap(); + let stats = trainer.train(stream).await.unwrap(); + + // Should create 3 zones (one per fragment) + assert_eq!(stats.len(), 3); + + // Fragment 0 + assert_eq!(stats[0].bound.fragment_id, 0); + assert_eq!(stats[0].bound.start, 0); + assert_eq!(stats[0].bound.length, 2); + assert_eq!(stats[0].sum, 2); + + // Fragment 5 (not 1!) + assert_eq!(stats[1].bound.fragment_id, 5); + assert_eq!(stats[1].bound.start, 0); + assert_eq!(stats[1].bound.length, 2); + assert_eq!(stats[1].sum, 4); + + // Fragment 10 (not 2!) + assert_eq!(stats[2].bound.fragment_id, 10); + assert_eq!(stats[2].bound.start, 0); + assert_eq!(stats[2].bound.length, 2); + assert_eq!(stats[2].sum, 6); + } + + #[test] + fn search_zones_collects_row_ranges() { + // Ensure the shared helper converts matching zones into the correct row-id + // ranges (fragment upper bits + local offsets) while skipping non-matching + // zones. This protects the helper if we modify how RowAddrTreeMap ranges are + // inserted in the future. 
+        #[derive(Debug)]
+        struct DummyZone {
+            bound: ZoneBound,
+            matches: bool,
+        }
+
+        impl AsRef<ZoneBound> for DummyZone {
+            fn as_ref(&self) -> &ZoneBound {
+                &self.bound
+            }
+        }
+
+        let zones = vec![
+            DummyZone {
+                bound: ZoneBound {
+                    fragment_id: 0,
+                    start: 0,
+                    length: 2,
+                },
+                matches: true,
+            },
+            DummyZone {
+                bound: ZoneBound {
+                    fragment_id: 1,
+                    start: 5,
+                    length: 3,
+                },
+                matches: false,
+            },
+            DummyZone {
+                bound: ZoneBound {
+                    fragment_id: 2,
+                    start: 10,
+                    length: 1,
+                },
+                matches: true,
+            },
+        ];
+
+        let metrics = LocalMetricsCollector::default();
+        let result = search_zones(&zones, &metrics, |zone| Ok(zone.matches)).unwrap();
+        let SearchResult::AtMost(map) = result else {
+            panic!("search_zones should return AtMost for dummy zones");
+        };
+
+        // Fragment 0, offsets 0 and 1
+        assert!(map.contains(0));
+        assert!(map.contains(1));
+        // Fragment 1 should be skipped entirely
+        assert!(!map.contains((1_u64 << 32) + 5));
+        assert!(!map.contains((1_u64 << 32) + 7));
+        // Fragment 2 includes only the single offset 10
+        assert!(map.contains((2_u64 << 32) + 10));
+        assert!(!map.contains((2_u64 << 32) + 11));
+    }
+
+    #[test]
+    fn search_zones_returns_empty_when_no_match() {
+        #[derive(Debug)]
+        struct DummyZone {
+            bound: ZoneBound,
+            matches: bool,
+        }
+
+        impl AsRef<ZoneBound> for DummyZone {
+            fn as_ref(&self) -> &ZoneBound {
+                &self.bound
+            }
+        }
+
+        // Both zones are marked as non-matching. The helper should return an empty map.
+ let zones = vec![ + DummyZone { + bound: ZoneBound { + fragment_id: 0, + start: 0, + length: 4, + }, + matches: false, + }, + DummyZone { + bound: ZoneBound { + fragment_id: 1, + start: 10, + length: 2, + }, + matches: false, + }, + ]; + + let metrics = LocalMetricsCollector::default(); + let result = search_zones(&zones, &metrics, |zone| Ok(zone.matches)).unwrap(); + let SearchResult::AtMost(map) = result else { + panic!("expected AtMost result"); + }; + // No zones should be inserted when every predicate evaluates to false + assert!(map.is_empty()); + } + + #[tokio::test] + async fn rebuild_zones_appends_new_stats() { + let existing = vec![MockStats { + sum: 50, + bound: ZoneBound { + fragment_id: 0, + start: 0, + length: 2, + }, + }]; + + let batch = batch(vec![3, 4], vec![1, 1], vec![0, 1]); + let stream = Box::pin(RecordBatchStreamAdapter::new( + batch.schema(), + stream::once(async { Ok(batch) }), + )); + + let trainer = ZoneTrainer::new(MockProcessor::new(), 2).unwrap(); + let rebuilt = rebuild_zones(&existing, trainer, stream).await.unwrap(); + // Existing zone should remain unchanged and new stats appended afterwards + assert_eq!(rebuilt.len(), 2); + assert_eq!(rebuilt[0].sum, 50); + assert_eq!(rebuilt[1].sum, 7); + assert_eq!(rebuilt[1].bound.fragment_id, 1); + assert_eq!(rebuilt[1].bound.start, 0); + assert_eq!(rebuilt[1].bound.length, 2); + } + + #[tokio::test] + async fn rebuild_zones_handles_multi_fragment_stream() { + let existing = vec![MockStats { + sum: 10, + bound: ZoneBound { + fragment_id: 0, + start: 0, + length: 1, + }, + }]; + + // Construct a stream with two fragments. Trainer should emit two zones that + // get appended after the existing entries. 
+ let batch = batch(vec![5, 5, 6, 6], vec![1, 1, 2, 2], vec![0, 1, 0, 1]); + let stream = Box::pin(RecordBatchStreamAdapter::new( + batch.schema(), + stream::once(async { Ok(batch) }), + )); + + let trainer = ZoneTrainer::new(MockProcessor::new(), 2).unwrap(); + let rebuilt = rebuild_zones(&existing, trainer, stream).await.unwrap(); + // Existing zone plus two new fragments should yield three total zones + assert_eq!(rebuilt.len(), 3); + assert_eq!(rebuilt[0].bound.fragment_id, 0); + assert_eq!(rebuilt[1].bound.fragment_id, 1); + assert_eq!(rebuilt[2].bound.fragment_id, 2); + assert_eq!(rebuilt[1].sum, 10); + assert_eq!(rebuilt[2].sum, 12); + } +} diff --git a/rust/lance-index/src/scalar/zonemap.rs b/rust/lance-index/src/scalar/zonemap.rs index 0b4b94e7e30..f41d97ee57d 100644 --- a/rust/lance-index/src/scalar/zonemap.rs +++ b/rust/lance-index/src/scalar/zonemap.rs @@ -23,10 +23,7 @@ use crate::scalar::{ use crate::Any; use datafusion::functions_aggregate::min_max::{MaxAccumulator, MinAccumulator}; use datafusion_expr::Accumulator; -use futures::TryStreamExt; use lance_core::cache::{LanceCache, WeakLanceCache}; -use lance_core::ROW_ADDR; -use lance_datafusion::chunker::chunk_concat_stream; use serde::{Deserialize, Serialize}; use std::sync::LazyLock; @@ -42,29 +39,18 @@ use crate::vector::VectorIndex; use crate::{Index, IndexType}; use async_trait::async_trait; use deepsize::DeepSizeOf; +use lance_core::Error; use lance_core::Result; -use lance_core::{utils::address::RowAddress, utils::mask::RowAddrTreeMap, Error}; use roaring::RoaringBitmap; use snafu::location; + +use super::zoned::{rebuild_zones, search_zones, ZoneBound, ZoneProcessor, ZoneTrainer}; const ROWS_PER_ZONE_DEFAULT: u64 = 8192; // 1 zone every two batches const ZONEMAP_FILENAME: &str = "zonemap.lance"; const ZONEMAP_SIZE_META_KEY: &str = "rows_per_zone"; const ZONEMAP_INDEX_VERSION: u32 = 0; -// -// Example: Suppose we have two fragments, each with 4 rows. 
-// Fragment 0: zone_start = 0, zone_length = 4 // covers rows 0, 1, 2, 3 in fragment 0 -// The row addresses for fragment 0 are: 0, 1, 2, 3 -// Fragment 1: zone_start = 0, zone_length = 4 // covers rows 0, 1, 2, 3 in fragment 1 -// The row addresses for fragment 1 are: 32>>1, 32>>1 + 1, 32>>1 + 2, 32>>1 + 3 -// -// Deletion is 0 index based. We delete the 0th and 1st row in fragment 0, -// and the 1st and 2nd row in fragment 1, -// Fragment 0: zone_start = 2, zone_length = 2 // covers rows 2, 3 in fragment 0 -// The row addresses for fragment 0 are: 2, 3 -// Fragment 1: zone_start = 0, zone_length = 4 // covers rows 0, 3 in fragment 1 -// The row addresses for fragment 1 are: 32>>1, 32>>1 + 3 /// Basic stats about zonemap index #[derive(Debug, PartialEq, Clone)] struct ZoneMapStatistics { @@ -73,14 +59,9 @@ struct ZoneMapStatistics { null_count: u32, // only apply to float type nan_count: u32, - fragment_id: u64, - // zone_start is start row of the zone in the fragment, also known - // as the local offset. To get the actual first row address, - // you can do `fragment_id << 32 + zone_start` - zone_start: u64, - // zone_length is the `row offset span` between the first and the last row in the zone - // calculated as: (last_row_offset - first_row_offset + 1) - zone_length: usize, + // Bound of this zone within the fragment. Persisted as three separate columns + // (fragment_id, zone_start, zone_length) in the index file. + bound: ZoneBound, } impl DeepSizeOf for ZoneMapStatistics { @@ -93,6 +74,12 @@ impl DeepSizeOf for ZoneMapStatistics { } } +impl AsRef for ZoneMapStatistics { + fn as_ref(&self) -> &ZoneBound { + &self.bound + } +} + /// ZoneMap index /// At high level it's a columnar database technique for predicate push down and scan pruning. 
/// It breaks data into fixed-size chunks called `zones` and store summary statistics(min, max, null_count, @@ -475,15 +462,16 @@ impl ZoneMapIndex { let max = ScalarValue::try_from_array(max_col, i)?; let null_count = null_count_col.value(i); let nan_count = nan_count_col.value(i); - zones.push(ZoneMapStatistics { min, max, null_count, nan_count, - fragment_id: fragment_id_col.value(i), - zone_start: zone_start_col.value(i), - zone_length: zone_length.value(i) as usize, + bound: ZoneBound { + fragment_id: fragment_id_col.value(i), + start: zone_start_col.value(i), + length: zone_length.value(i) as usize, + }, }); } @@ -536,7 +524,7 @@ impl Index for ZoneMapIndex { // Loop through zones and add unique fragment IDs to the bitmap for zone in &self.zones { - frag_ids.insert(zone.fragment_id as u32); + frag_ids.insert(zone.bound.fragment_id as u32); } Ok(frag_ids) @@ -550,25 +538,10 @@ impl ScalarIndex for ZoneMapIndex { query: &dyn AnyQuery, metrics: &dyn MetricsCollector, ) -> Result { - metrics.record_comparisons(self.zones.len()); let query = query.as_any().downcast_ref::().unwrap(); - - let mut row_addr_tree_map = RowAddrTreeMap::new(); - - // Loop through zones and check each one - for zone in self.zones.iter() { - // Check if this zone matches the query - if self.evaluate_zone_against_query(zone, query)? 
{ - // Calculate the range of row addresses for this zone - let zone_start_addr = (zone.fragment_id << 32) + zone.zone_start; - let zone_end_addr = zone_start_addr + zone.zone_length as u64; - - // Add all row addresses in this zone to the result - row_addr_tree_map.insert_range(zone_start_addr..zone_end_addr); - } - } - - Ok(SearchResult::AtMost(row_addr_tree_map)) + search_zones(&self.zones, metrics, |zone| { + self.evaluate_zone_against_query(zone, query) + }) } fn can_remap(&self) -> bool { @@ -593,34 +566,20 @@ impl ScalarIndex for ZoneMapIndex { new_data: SendableRecordBatchStream, dest_store: &dyn IndexStore, ) -> Result { - // Process the new data to create zones - let batches_source = new_data; - let value_type = batches_source.schema().field(0).data_type().clone(); - - let mut builder = ZoneMapIndexBuilder::try_new( - ZoneMapIndexBuilderParams::new(self.rows_per_zone), - value_type, - )?; - - builder.train(batches_source).await?; + // Train new zones for the incoming data stream + let schema = new_data.schema(); + let value_type = schema.field(0).data_type().clone(); - // Get the new zones from the builder - let new_zone_stats = builder.maps; + let options = ZoneMapIndexBuilderParams::new(self.rows_per_zone); + let processor = ZoneMapProcessor::new(value_type.clone())?; + let trainer = ZoneTrainer::new(processor, self.rows_per_zone)?; + let updated_zones = rebuild_zones(&self.zones, trainer, new_data).await?; - // Combine existing zones with new zones - let mut all_zones = self.zones.clone(); - all_zones.extend(new_zone_stats); - - // Create a new builder with all zones to write them out - let mut combined_builder = ZoneMapIndexBuilder::try_new( - ZoneMapIndexBuilderParams::new(self.rows_per_zone), - self.data_type.clone(), - )?; - combined_builder.maps = all_zones; - combined_builder.options.rows_per_zone = self.rows_per_zone; - - // Write the updated index to dest_store - combined_builder.write_index(dest_store).await?; + // Serialize the combined zones 
back into the index file + let mut builder = ZoneMapIndexBuilder::try_new(options, self.data_type.clone())?; + builder.options.rows_per_zone = self.rows_per_zone; + builder.maps = updated_zones; + builder.write_index(dest_store).await?; Ok(CreatedIndex { index_details: prost_types::Any::from_msg(&pbold::ZoneMapIndexDetails::default()) @@ -682,206 +641,24 @@ pub struct ZoneMapIndexBuilder { items_type: DataType, maps: Vec, - // The local offset within the current zone - cur_zone_offset: usize, - cur_fragment_id: u32, - // Track the actual first and last row offsets in the current zone - // This handles non-contiguous offsets after deletions - cur_zone_first_row_offset: Option, - cur_zone_last_row_offset: Option, - - min: MinAccumulator, - max: MaxAccumulator, - null_count: u32, - nan_count: u32, } impl ZoneMapIndexBuilder { pub fn try_new(options: ZoneMapIndexBuilderParams, items_type: DataType) -> Result { - let min = MinAccumulator::try_new(&items_type)?; - let max = MaxAccumulator::try_new(&items_type)?; Ok(Self { options, items_type, maps: Vec::new(), - cur_zone_offset: 0, - cur_fragment_id: 0, - cur_zone_first_row_offset: None, - cur_zone_last_row_offset: None, - min, - max, - null_count: 0, - nan_count: 0, }) } - fn count_nans(array: &ArrayRef) -> u32 { - match array.data_type() { - DataType::Float16 => { - let array = array - .as_any() - .downcast_ref::() - .unwrap(); - array.values().iter().filter(|&&x| x.is_nan()).count() as u32 - } - DataType::Float32 => { - let array = array - .as_any() - .downcast_ref::() - .unwrap(); - array.values().iter().filter(|&&x| x.is_nan()).count() as u32 - } - DataType::Float64 => { - let array = array - .as_any() - .downcast_ref::() - .unwrap(); - array.values().iter().filter(|&&x| x.is_nan()).count() as u32 - } - _ => 0, // Non-float types don't have NaNs - } - } - - fn update_stats(&mut self, array: &ArrayRef) -> Result<()> { - self.null_count += array.null_count() as u32; - self.nan_count += Self::count_nans(array); - 
self.min.update_batch(std::slice::from_ref(array))?; - self.max.update_batch(std::slice::from_ref(array))?; - Ok(()) - } - - fn new_map(&mut self, fragment_id: u32) -> Result<()> { - let zone_start = self.cur_zone_first_row_offset.unwrap_or(0) as u64; - let zone_length = self - .cur_zone_last_row_offset - .map(|last_row_offset| { - (last_row_offset - self.cur_zone_first_row_offset.unwrap_or(0) + 1) as usize - }) - .unwrap_or(self.cur_zone_offset); - - let new_map = ZoneMapStatistics { - min: self.min.evaluate()?, - max: self.max.evaluate()?, - null_count: self.null_count, - nan_count: self.nan_count, - fragment_id: fragment_id as u64, - zone_start, - zone_length, - }; - - self.maps.push(new_map); - - self.cur_zone_offset = 0; - self.cur_zone_first_row_offset = None; - self.cur_zone_last_row_offset = None; - self.min = MinAccumulator::try_new(&self.items_type)?; - self.max = MaxAccumulator::try_new(&self.items_type)?; - self.null_count = 0; - self.nan_count = 0; - Ok(()) - } - + /// Train the builder using the shared zone trainer. The input stream must contain + /// the value column followed by `_rowaddr`, matching the dataset scan order enforced + /// by the scalar index registry. pub async fn train(&mut self, batches_source: SendableRecordBatchStream) -> Result<()> { - assert!(batches_source.schema().field_with_name(ROW_ADDR).is_ok()); - - let mut batches_source = - chunk_concat_stream(batches_source, self.options.rows_per_zone as usize); - - while let Some(batch) = batches_source.try_next().await? 
{ - if batch.num_rows() == 0 { - continue; - } - - let data_array: &arrow_array::ArrayRef = batch.column(0); - let row_addrs_array = batch - .column_by_name(ROW_ADDR) - .unwrap() - .as_any() - .downcast_ref::() - .unwrap(); - - let mut remaining = batch.num_rows(); - let mut array_offset: usize = 0; - - // Initialize cur_fragment_id from the first row address if this is the first batch - if self.maps.is_empty() && self.cur_zone_offset == 0 { - let first_row_addr = row_addrs_array.value(0); - self.cur_fragment_id = (first_row_addr >> 32) as u32; - } - - while remaining > 0 { - // Find the next fragment boundary in this batch - let next_fragment_index = (array_offset..row_addrs_array.len()).find(|&i| { - let row_addr = row_addrs_array.value(i); - let fragment_id = (row_addr >> 32) as u32; - fragment_id == self.cur_fragment_id + 1 - }); - let empty_rows_left_in_cur_zone: usize = - (self.options.rows_per_zone - self.cur_zone_offset as u64) as usize; - - // Check if there is enough data from the current fragment to fill the current zone - let desired = if let Some(idx) = next_fragment_index { - self.cur_fragment_id = (row_addrs_array.value(idx) >> 32) as u32; - // Take the minimum between distance to boundary and space left in zone - // to ensure we don't exceed the zone size limit - std::cmp::min(idx - array_offset, empty_rows_left_in_cur_zone) - } else { - empty_rows_left_in_cur_zone - }; - - if desired > remaining { - // Not enough data to fill a map, just increment counts - self.update_stats(&data_array.slice(array_offset, remaining))?; - - // Track first and last row offsets (local offsets within fragment) - let first_row_offset = - RowAddress::new_from_u64(row_addrs_array.value(array_offset)).row_offset(); - let last_row_offset = RowAddress::new_from_u64( - row_addrs_array.value(array_offset + remaining - 1), - ) - .row_offset(); - if self.cur_zone_first_row_offset.is_none() { - self.cur_zone_first_row_offset = Some(first_row_offset); - } - 
self.cur_zone_last_row_offset = Some(last_row_offset); - - self.cur_zone_offset += remaining; - break; - } else if desired > 0 { - // There is enough data, create a new zone map - self.update_stats(&data_array.slice(array_offset, desired))?; - - // Track first and last row offsets - let first_row_offset = - RowAddress::new_from_u64(row_addrs_array.value(array_offset)).row_offset(); - let last_row_offset = - RowAddress::new_from_u64(row_addrs_array.value(array_offset + desired - 1)) - .row_offset(); - if self.cur_zone_first_row_offset.is_none() { - self.cur_zone_first_row_offset = Some(first_row_offset); - } - self.cur_zone_last_row_offset = Some(last_row_offset); - - self.cur_zone_offset += desired; - self.new_map((row_addrs_array.value(array_offset) >> 32) as u32)?; - } else if desired == 0 { - // The new batch starts with a new fragment. Flush the current zone if it's not empty - if self.cur_zone_offset > 0 { - self.new_map(self.cur_fragment_id.wrapping_sub(1))?; - } - // Let the loop run again - // to find the next fragment boundary - continue; - } - array_offset += desired; - remaining = remaining.saturating_sub(desired); - } - } - // Create the final map - if self.cur_zone_offset > 0 { - self.new_map(self.cur_fragment_id)?; - } - + let processor = ZoneMapProcessor::new(self.items_type.clone())?; + let trainer = ZoneTrainer::new(processor, self.options.rows_per_zone)?; + self.maps = trainer.train(batches_source).await?; Ok(()) } @@ -903,13 +680,13 @@ impl ZoneMapIndexBuilder { let nan_counts = UInt32Array::from_iter_values(self.maps.iter().map(|stat| stat.nan_count)); let fragment_ids = - UInt64Array::from_iter_values(self.maps.iter().map(|stat| stat.fragment_id)); + UInt64Array::from_iter_values(self.maps.iter().map(|stat| stat.bound.fragment_id)); let zone_lengths = - UInt64Array::from_iter_values(self.maps.iter().map(|stat| stat.zone_length as u64)); + UInt64Array::from_iter_values(self.maps.iter().map(|stat| stat.bound.length as u64)); let zone_starts = - 
UInt64Array::from_iter_values(self.maps.iter().map(|stat| stat.zone_start)); + UInt64Array::from_iter_values(self.maps.iter().map(|stat| stat.bound.start)); let schema = Arc::new(arrow_schema::Schema::new(vec![ // min and max can be null if the entire batch is null values @@ -952,6 +729,87 @@ impl ZoneMapIndexBuilder { } } +/// Index-specific processor that computes min/max statistics for each zone while the +/// trainer takes care of chunking and fragment boundaries. +struct ZoneMapProcessor { + data_type: DataType, + min: MinAccumulator, + max: MaxAccumulator, + null_count: u32, + nan_count: u32, +} + +impl ZoneMapProcessor { + fn new(data_type: DataType) -> Result { + let min = MinAccumulator::try_new(&data_type)?; + let max = MaxAccumulator::try_new(&data_type)?; + Ok(Self { + data_type, + min, + max, + null_count: 0, + nan_count: 0, + }) + } + + fn count_nans(array: &ArrayRef) -> u32 { + match array.data_type() { + DataType::Float16 => { + let array = array + .as_any() + .downcast_ref::() + .unwrap(); + array.values().iter().filter(|&&x| x.is_nan()).count() as u32 + } + DataType::Float32 => { + let array = array + .as_any() + .downcast_ref::() + .unwrap(); + array.values().iter().filter(|&&x| x.is_nan()).count() as u32 + } + DataType::Float64 => { + let array = array + .as_any() + .downcast_ref::() + .unwrap(); + array.values().iter().filter(|&&x| x.is_nan()).count() as u32 + } + _ => 0, + } + } +} + +impl ZoneProcessor for ZoneMapProcessor { + type ZoneStatistics = ZoneMapStatistics; + + fn process_chunk(&mut self, array: &ArrayRef) -> Result<()> { + self.null_count += array.null_count() as u32; + self.nan_count += Self::count_nans(array); + self.min.update_batch(std::slice::from_ref(array))?; + self.max.update_batch(std::slice::from_ref(array))?; + Ok(()) + } + + fn finish_zone(&mut self, bound: ZoneBound) -> Result { + Ok(ZoneMapStatistics { + min: self.min.evaluate()?, + max: self.max.evaluate()?, + null_count: self.null_count, + nan_count: self.nan_count, 
+ bound, + }) + } + + fn reset(&mut self) -> Result<()> { + self.min = MinAccumulator::try_new(&self.data_type)?; + self.max = MaxAccumulator::try_new(&self.data_type)?; + self.null_count = 0; + self.nan_count = 0; + Ok(()) + } +} + #[derive(Debug, Default)] pub struct ZoneMapIndexPlugin; @@ -1080,6 +938,7 @@ mod tests { use crate::scalar::{zonemap::ROWS_PER_ZONE_DEFAULT, IndexStore}; use std::sync::Arc; + use crate::scalar::zoned::ZoneBound; use crate::scalar::zonemap::{ZoneMapIndexPlugin, ZoneMapStatistics}; use arrow::datatypes::Float32Type; use arrow_array::{Array, RecordBatch, UInt64Array}; @@ -1211,8 +1070,8 @@ mod tests { for (i, zone) in index.zones.iter().enumerate() { assert_eq!(zone.null_count, 1000); assert_eq!(zone.nan_count, 0, "Zone {} should have nan_count = 0", i); - assert_eq!(zone.zone_length, 5000); - assert_eq!(zone.fragment_id, i as u64); + assert_eq!(zone.bound.length, 5000); + assert_eq!(zone.bound.fragment_id, i as u64); } // Equals query: null (should match all zones since they contain null values) @@ -1265,8 +1124,8 @@ mod tests { // Verify the new zone was added let new_zone = &updated_index.zones[10]; // Last zone should be the new one - assert_eq!(new_zone.fragment_id, 10u64); // New fragment ID - assert_eq!(new_zone.zone_length, 5000); + assert_eq!(new_zone.bound.fragment_id, 10u64); // New fragment ID + assert_eq!(new_zone.bound.length, 5000); assert_eq!(new_zone.null_count, 0); // New data has no nulls assert_eq!(new_zone.nan_count, 0); // New data has no NaN values @@ -1360,12 +1219,12 @@ mod tests { for (i, zone) in index.zones.iter().enumerate() { assert_eq!(zone.nan_count, 20, "Zone {} should have 20 NaN values", i); assert_eq!( - zone.zone_length, 100, + zone.bound.length, 100, "Zone {} should have zone_length 100", i ); assert_eq!( - zone.fragment_id, 0u64, + zone.bound.fragment_id, 0u64, "Zone {} should have fragment_id 0", i ); @@ -1583,18 +1442,22 @@ mod tests { max: ScalarValue::Int32(Some(99)), null_count: 0, nan_count: 
0, - fragment_id: 0, - zone_start: 0, - zone_length: 100, + bound: ZoneBound { + fragment_id: 0, + start: 0, + length: 100, + }, }, ZoneMapStatistics { min: ScalarValue::Int32(Some(100)), max: ScalarValue::Int32(Some(100)), null_count: 0, nan_count: 0, - fragment_id: 0, - zone_start: 100, - zone_length: 1, + bound: ZoneBound { + fragment_id: 0, + start: 100, + length: 1, + }, } ] ); @@ -1761,27 +1624,33 @@ mod tests { max: ScalarValue::Int64(Some(8191)), null_count: 0, nan_count: 0, - fragment_id: 0, - zone_start: 0, - zone_length: 8192, + bound: ZoneBound { + fragment_id: 0, + start: 0, + length: 8192, + }, }, ZoneMapStatistics { min: ScalarValue::Int64(Some(8192)), max: ScalarValue::Int64(Some(16383)), null_count: 0, nan_count: 0, - fragment_id: 0, - zone_start: 8192, - zone_length: 8192, + bound: ZoneBound { + fragment_id: 0, + start: 8192, + length: 8192, + }, }, ZoneMapStatistics { min: ScalarValue::Int64(Some(16384)), max: ScalarValue::Int64(Some(16425)), null_count: 0, nan_count: 0, - fragment_id: 0, - zone_start: 16384, - zone_length: 42, + bound: ZoneBound { + fragment_id: 0, + start: 16384, + length: 42, + }, } ] ); @@ -1915,45 +1784,55 @@ mod tests { max: ScalarValue::Int64(Some(4999)), null_count: 0, nan_count: 0, - fragment_id: 0, - zone_start: 0, - zone_length: 5000, + bound: ZoneBound { + fragment_id: 0, + start: 0, + length: 5000, + }, }, ZoneMapStatistics { min: ScalarValue::Int64(Some(5000)), max: ScalarValue::Int64(Some(8191)), null_count: 0, nan_count: 0, - fragment_id: 0, - zone_start: 5000, - zone_length: 3192, + bound: ZoneBound { + fragment_id: 0, + start: 5000, + length: 3192, + }, }, ZoneMapStatistics { min: ScalarValue::Int64(Some(8192)), max: ScalarValue::Int64(Some(13191)), null_count: 0, nan_count: 0, - fragment_id: 1, - zone_start: 0, - zone_length: 5000, + bound: ZoneBound { + fragment_id: 1, + start: 0, + length: 5000, + }, }, ZoneMapStatistics { min: ScalarValue::Int64(Some(13192)), max: ScalarValue::Int64(Some(16383)), null_count: 
0, nan_count: 0, - fragment_id: 1, - zone_start: 5000, - zone_length: 3192, + bound: ZoneBound { + fragment_id: 1, + start: 5000, + length: 3192, + }, }, ZoneMapStatistics { min: ScalarValue::Int64(Some(16384)), max: ScalarValue::Int64(Some(16425)), null_count: 0, nan_count: 0, - fragment_id: 2, - zone_start: 0, - zone_length: 42, + bound: ZoneBound { + fragment_id: 2, + start: 0, + length: 42, + }, } ] ); @@ -2113,27 +1992,33 @@ mod tests { max: ScalarValue::Int64(Some(8191)), null_count: 0, nan_count: 0, - fragment_id: 0, - zone_start: 0, - zone_length: 8192, + bound: ZoneBound { + fragment_id: 0, + start: 0, + length: 8192, + }, }, ZoneMapStatistics { min: ScalarValue::Int64(Some(8192)), max: ScalarValue::Int64(Some(16383)), null_count: 0, nan_count: 0, - fragment_id: 1, - zone_start: 0, - zone_length: 8192, + bound: ZoneBound { + fragment_id: 1, + start: 0, + length: 8192, + }, }, ZoneMapStatistics { min: ScalarValue::Int64(Some(16384)), max: ScalarValue::Int64(Some(16425)), null_count: 0, nan_count: 0, - fragment_id: 2, - zone_start: 0, - zone_length: 42, + bound: ZoneBound { + fragment_id: 2, + start: 0, + length: 42, + }, } ] ); @@ -2182,27 +2067,33 @@ mod tests { max: ScalarValue::Int64(Some(8191)), null_count: 0, nan_count: 0, - fragment_id: 0, - zone_start: 0, - zone_length: 8192, + bound: ZoneBound { + fragment_id: 0, + start: 0, + length: 8192, + }, }, ZoneMapStatistics { min: ScalarValue::Int64(Some(8192)), max: ScalarValue::Int64(Some(16383)), null_count: 0, nan_count: 0, - fragment_id: 1, - zone_start: 0, - zone_length: 8192, + bound: ZoneBound { + fragment_id: 1, + start: 0, + length: 8192, + }, }, ZoneMapStatistics { min: ScalarValue::Int64(Some(16384)), max: ScalarValue::Int64(Some(16425)), null_count: 0, nan_count: 0, - fragment_id: 2, - zone_start: 0, - zone_length: 42, + bound: ZoneBound { + fragment_id: 2, + start: 0, + length: 42, + }, } ] );