diff --git a/rust/lance/src/index/vector/utils.rs b/rust/lance/src/index/vector/utils.rs index 3358f9093d5..e39ed73662f 100644 --- a/rust/lance/src/index/vector/utils.rs +++ b/rust/lance/src/index/vector/utils.rs @@ -3,15 +3,18 @@ use std::sync::Arc; +use arrow::array::ArrayData; +use arrow::datatypes::DataType; use arrow_array::{cast::AsArray, Array, ArrayRef, FixedSizeListArray, RecordBatch}; +use arrow_buffer::{Buffer, MutableBuffer}; use futures::StreamExt; -use lance_arrow::{interleave_batches, DataTypeExt}; +use lance_arrow::DataTypeExt; use lance_core::datatypes::Schema; use lance_linalg::distance::DistanceType; use log::{info, warn}; use rand::rngs::SmallRng; use rand::seq::{IteratorRandom, SliceRandom}; -use rand::SeedableRng; +use rand::{Rng, SeedableRng}; use snafu::location; use tokio::sync::Mutex; @@ -293,98 +296,53 @@ pub async fn maybe_sample_training_data( _ => sample_size_hint, }; - let batch = if num_rows > sample_size_hint && !is_nullable { - let projection = dataset.schema().project(&[column])?; - let batch = dataset.sample(sample_size_hint, &projection).await?; - info!( - "Sample training data: retrieved {} rows by sampling", - batch.num_rows() - ); - batch - } else if num_rows > sample_size_hint && is_nullable { - // Use min block size + vector size to determine sample granularity - // For example, on object storage, block size is 64 KB. A 768-dim 32-bit - // vector is 3 KB. So we can sample every 64 KB / 3 KB = 21 vectors. - let block_size = dataset.object_store().block_size(); - // We provide a fallback in case of multi-vector, which will have - // a variable size. We use 4 KB as a fallback. 
- let byte_width = vector_field - .data_type() - .byte_width_opt() - .unwrap_or(4 * 1024); - - let ranges = random_ranges(num_rows, sample_size_hint, block_size, byte_width); - - let mut collected = Vec::with_capacity(ranges.size_hint().0); - let mut indices = Vec::with_capacity(sample_size_hint); - let mut num_non_null = 0; - - let mut scan = dataset.take_scan( - Box::pin(futures::stream::iter(ranges).map(Ok)), - Arc::new(dataset.schema().project(&[column])?), - dataset.object_store().io_parallelism(), - ); - - while let Some(batch) = scan.next().await { - let batch = batch?; - - let array = get_column_from_batch(&batch, column)?; - let null_count = array.logical_null_count(); - if null_count < array.len() { - num_non_null += array.len() - null_count; + let should_sample = num_rows > sample_size_hint; + if should_sample { + sample_training_data( + dataset, + column, + sample_size_hint, + num_rows, + vector_field, + is_nullable, + ) + .await + } else { + // too small to require sampling + let batch = scan_all_training_data(dataset, column, is_nullable).await?; + vector_column_to_fsl(&batch, column) + } +} - let batch_i = collected.len(); - if let Some(null_buffer) = array.nulls() { - for i in null_buffer.valid_indices() { - indices.push((batch_i, i)); - } - } else { - indices.extend((0..array.len()).map(|i| (batch_i, i))); - } +#[derive(Debug)] +pub struct PartitionLoadLock { + partition_locks: Vec<Arc<Mutex<()>>>, +} - collected.push(batch); - } - if num_non_null >= sample_size_hint { - break; - } +impl PartitionLoadLock { + pub fn new(num_partitions: usize) -> Self { + Self { + partition_locks: (0..num_partitions) + .map(|_| Arc::new(Mutex::new(()))) + .collect(), + } + } - let batch = interleave_batches(&collected, &indices).map_err(|err| Error::Index { - message: format!("Sample training data: {}", err), - location: location!(), - })?; - info!( - "Sample training data: retrieved {} rows by sampling after filtering out nulls", - batch.num_rows() - ); - - // it's possible that 
we have more rows than sample_size_hint for this case, - // truncate the batch to sample_size_hint - if batch.num_rows() > sample_size_hint { - batch.slice(0, sample_size_hint) - } else { - batch - } - } else { - let mut scanner = dataset.scan(); - scanner.project(&[column])?; - if is_nullable { - let column_expr = lance_datafusion::logical_expr::field_path_to_expr(column)?; - scanner.filter_expr(column_expr.is_not_null()); - } - let batch = scanner.try_into_batch().await?; - info!( - "Sample training data: retrieved {} rows scanning full datasets", - batch.num_rows() - ); - batch - }; + pub fn get_partition_mutex(&self, partition_id: usize) -> Arc<Mutex<()>> { + let mtx = &self.partition_locks[partition_id]; - let array = get_column_from_batch(&batch, column)?; + mtx.clone() + } +} +/// Extract a vector column from a batch as a flat [`FixedSizeListArray`]. +/// +/// Handles both regular vector columns (FixedSizeList) and multivector columns +/// (List\<FixedSizeList\>), flattening the latter. +fn vector_column_to_fsl(batch: &RecordBatch, column: &str) -> Result<FixedSizeListArray> { + let array = get_column_from_batch(batch, column)?; match array.data_type() { arrow::datatypes::DataType::FixedSizeList(_, _) => Ok(array.as_fixed_size_list().clone()), - // for multivector, flatten the vectors into a FixedSizeListArray arrow::datatypes::DataType::List(_) => { let list_array = array.as_list::<i32>(); let vectors = list_array.values().as_fixed_size_list(); @@ -392,7 +350,7 @@ pub async fn maybe_sample_training_data( } _ => Err(Error::Index { message: format!( - "Sample training data: column {} is not a FixedSizeListArray", + "Sample training data: column {} is not a vector column", column ), location: location!(), @@ -400,27 +358,338 @@ pub async fn maybe_sample_training_data( } } -#[derive(Debug)] -pub struct PartitionLoadLock { - partition_locks: Vec<Arc<Mutex<()>>>, +/// Scan the entire dataset to collect training data, optionally filtering nulls. 
+/// +/// Used when the dataset is small enough that random sampling is unnecessary. +async fn scan_all_training_data( + dataset: &Dataset, + column: &str, + is_nullable: bool, +) -> Result<RecordBatch> { + let mut scanner = dataset.scan(); + scanner.project(&[column])?; + if is_nullable { + let column_expr = lance_datafusion::logical_expr::field_path_to_expr(column)?; + scanner.filter_expr(column_expr.is_not_null()); + } + let batch = scanner.try_into_batch().await?; + info!( + "Sample training data: retrieved {} rows scanning full dataset", + batch.num_rows() + ); + Ok(batch) +} -impl PartitionLoadLock { - pub fn new(num_partitions: usize) -> Self { - Self { - partition_locks: (0..num_partitions) - .map(|_| Arc::new(Mutex::new(()))) - .collect(), +/// Sample training data from the dataset. +/// +/// Dispatches to the most efficient strategy based on column type and nullability: +/// - Non-nullable FSL: [`sample_fsl_uniform`] — true uniform random row indices via chunked `take`. +/// - Nullable FSL: [`sample_nullable_fsl`] — streaming range-based reads with null filtering. +/// - Non-FSL (multivector): [`sample_nullable_fallback`] — streaming range-based reads. 
+async fn sample_training_data( + dataset: &Dataset, + column: &str, + sample_size_hint: usize, + num_rows: usize, + vector_field: &lance_core::datatypes::Field, + is_nullable: bool, +) -> Result<FixedSizeListArray> { + let byte_width = vector_field + .data_type() + .byte_width_opt() + .unwrap_or(4 * 1024); + + match vector_field.data_type() { + DataType::FixedSizeList(_, _) if !is_nullable => { + sample_fsl_uniform( + dataset, + column, + sample_size_hint, + num_rows, + byte_width, + vector_field, + ) + .await + } + DataType::FixedSizeList(_, _) => { + let scan = + sample_training_data_scan(dataset, column, sample_size_hint, num_rows, byte_width)?; + sample_nullable_fsl(column, sample_size_hint, byte_width, vector_field, scan).await + } + _ => { + let scan = + sample_training_data_scan(dataset, column, sample_size_hint, num_rows, byte_width)?; + sample_nullable_fallback(column, sample_size_hint, is_nullable, scan).await + } + } +} - pub fn get_partition_mutex(&self, partition_id: usize) -> Arc<Mutex<()>> { - let mtx = &self.partition_locks[partition_id]; +/// Create a streaming scan over random ranges for sampling. +fn sample_training_data_scan( + dataset: &Dataset, + column: &str, + sample_size_hint: usize, + num_rows: usize, + byte_width: usize, +) -> Result<crate::dataset::scanner::DatasetRecordBatchStream> { + let block_size = dataset.object_store().block_size(); + let ranges = random_ranges(num_rows, sample_size_hint, block_size, byte_width); + Ok(dataset.take_scan( + Box::pin(futures::stream::iter(ranges).map(Ok)), + Arc::new(dataset.schema().project(&[column])?), + dataset.object_store().io_parallelism(), + )) +} - mtx.clone() +/// Build a FixedSizeListArray from raw flat value bytes. 
+fn fsl_values_to_array( + field: &lance_core::datatypes::Field, + mut values_buf: MutableBuffer, + num_rows: usize, +) -> Result<FixedSizeListArray> { + let (inner_field, dim) = match field.data_type() { + DataType::FixedSizeList(f, d) => (f, d as usize), + other => { + return Err(Error::Index { + message: format!("Expected FixedSizeList, got {:?}", other), + location: location!(), + }) + } + }; + + let elem_size = inner_field + .data_type() + .primitive_width() + .ok_or_else(|| Error::Index { + message: format!( + "FixedSizeList inner type {:?} has no fixed width", + inner_field.data_type() + ), + location: location!(), + })?; + + let expected_bytes = num_rows * dim * elem_size; + debug_assert_eq!(values_buf.len(), expected_bytes); + values_buf.truncate(expected_bytes); + let buf: Buffer = values_buf.into(); + let values_array = arrow_array::make_array(ArrayData::try_new( + inner_field.data_type().clone(), + num_rows * dim, + None, + 0, + vec![buf], + vec![], + )?); + + Ok(FixedSizeListArray::try_new( + inner_field, + dim as i32, + values_array, + None, + )?) +} + +/// Stream-and-compact sampling for nullable FixedSizeList vector columns. +/// +/// Unlike [`sample_nullable_fallback`], which must collect all source batches +/// in memory, this exploits the fixed-width layout of FSL columns to +/// accumulate non-null vector bytes directly into a flat buffer, dropping +/// each source batch immediately. This keeps peak memory proportional to the +/// output sample rather than the input scan. 
+async fn sample_nullable_fsl( + column: &str, + sample_size_hint: usize, + byte_width: usize, + vector_field: &lance_core::datatypes::Field, + mut scan: crate::dataset::scanner::DatasetRecordBatchStream, +) -> Result<FixedSizeListArray> { + let mut values_buf = MutableBuffer::with_capacity(sample_size_hint * byte_width); + let mut num_non_null: usize = 0; + + while num_non_null < sample_size_hint { + let Some(batch) = scan.next().await else { + break; + }; + let batch = batch?; + let array = get_column_from_batch(&batch, column)?; + if array.logical_null_count() >= array.len() { + continue; + } + accumulate_fsl_values(&mut values_buf, &mut num_non_null, &array, byte_width, true)?; + } + + let num_rows_out = num_non_null.min(sample_size_hint); + values_buf.truncate(num_rows_out * byte_width); + + info!( + "Sample training data: retrieved {} rows by sampling after filtering out nulls", + num_rows_out + ); + + fsl_values_to_array(vector_field, values_buf, num_rows_out) +} + +/// True uniform random sampling for non-nullable FixedSizeList columns. +/// +/// Generates truly random row indices, sorts them, and fetches via +/// `dataset.take()` in chunks. Each chunk's RecordBatch is consumed into a flat +/// byte buffer and dropped immediately, keeping peak memory proportional to the +/// output sample. 
+async fn sample_fsl_uniform( + dataset: &Dataset, + column: &str, + sample_size_hint: usize, + num_rows: usize, + byte_width: usize, + vector_field: &lance_core::datatypes::Field, +) -> Result<FixedSizeListArray> { + let indices = generate_random_indices(num_rows, sample_size_hint); + let projection = Arc::new(dataset.schema().project(&[column])?); + + let mut values_buf = MutableBuffer::with_capacity(sample_size_hint * byte_width); + let mut total_rows: usize = 0; + + const TAKE_CHUNK_SIZE: usize = 8192; + for chunk in indices.chunks(TAKE_CHUNK_SIZE) { + let batch = dataset.take(chunk, projection.clone()).await?; + let array = get_column_from_batch(&batch, column)?; + accumulate_fsl_values(&mut values_buf, &mut total_rows, &array, byte_width, false)?; + } + + info!( + "Sample training data: retrieved {} rows by uniform random sampling", + total_rows, + ); + + fsl_values_to_array(vector_field, values_buf, total_rows) +} + +/// Append values from a FixedSizeList array into a flat byte buffer. +/// +/// When `filter_nulls` is false and there are no nulls, copies raw bytes +/// directly from the FSL values buffer (accounting for child array offset). +/// When `filter_nulls` is true, uses Arrow's `filter` kernel to remove nulls. +fn accumulate_fsl_values( + values_buf: &mut MutableBuffer, + num_rows: &mut usize, + array: &ArrayRef, + byte_width: usize, + filter_nulls: bool, +) -> Result<()> { + let needs_filter = filter_nulls && array.null_count() > 0; + + if needs_filter { + let nulls = array.nulls().unwrap(); + let mask = arrow_array::BooleanArray::from(nulls.inner().clone()); + let filtered = arrow::compute::filter(array, &mask)?; + let fsl = filtered.as_fixed_size_list(); + let values_data = fsl.values().to_data(); + let value_bytes = &values_data.buffers()[0].as_slice()[..fsl.len() * byte_width]; + values_buf.extend_from_slice(value_bytes); + *num_rows += fsl.len(); + } else { + // No nulls: copy raw bytes directly, accounting for child array offset. 
+ let fsl = array.as_fixed_size_list(); + let values = fsl.values(); + let values_data = values.to_data(); + let elem_size = byte_width / fsl.value_length() as usize; + let offset_bytes = values_data.offset() * elem_size; + let total_bytes = fsl.len() * byte_width; + let buf = &values_data.buffers()[0].as_slice()[offset_bytes..offset_bytes + total_bytes]; + values_buf.extend_from_slice(buf); + *num_rows += fsl.len(); + } + Ok(()) +} + +/// Fallback sampling for non-FixedSizeList columns (e.g. multivector List +/// columns). Collects batches and concatenates them. When `is_nullable` is +/// true, filters null rows from each batch. +async fn sample_nullable_fallback( + column: &str, + sample_size_hint: usize, + is_nullable: bool, + mut scan: crate::dataset::scanner::DatasetRecordBatchStream, +) -> Result<FixedSizeListArray> { + let mut schema = None; + let mut filtered = Vec::new(); + let mut num_non_null: usize = 0; + + while num_non_null < sample_size_hint { + let Some(batch) = scan.next().await else { + break; + }; + let batch = batch?; + let array = get_column_from_batch(&batch, column)?; + if is_nullable && array.logical_null_count() >= array.len() { + continue; + } + schema.get_or_insert_with(|| batch.schema()); + let batch = if is_nullable { + filter_non_null_rows(array, batch)? + } else { + batch + }; + num_non_null += batch.num_rows(); + filtered.push(batch); + } + + let Some(schema) = schema else { + return Err(Error::Index { + message: "No non-null training data found".to_string(), + location: location!(), + }); + }; + let batch = arrow::compute::concat_batches(&schema, &filtered)?; + let num_rows_out = batch.num_rows().min(sample_size_hint); + let batch = batch.slice(0, num_rows_out); + + info!( + "Sample training data (fallback): retrieved {} rows by sampling after filtering out nulls", + num_rows_out + ); + + vector_column_to_fsl(&batch, column) +} + +/// Filter a batch to only include rows where `array` is non-null. 
+fn filter_non_null_rows(array: ArrayRef, batch: RecordBatch) -> Result<RecordBatch> { + if let Some(nulls) = array.nulls() { + let mask = arrow_array::BooleanArray::from(nulls.inner().clone()); + Ok(arrow::compute::filter_record_batch(&batch, &mask)?) + } else { + Ok(batch) + } +} +/// Generate `k` unique sorted random row indices from `[0, num_rows)`. +/// +/// Uses two strategies depending on sparsity: +/// - Sparse (`k * 2 < num_rows`): HashSet rejection sampling, O(k) expected. +/// - Dense: Fisher-Yates partial shuffle, O(num_rows) allocation. +fn generate_random_indices(num_rows: usize, k: usize) -> Vec<u64> { + assert!(k <= num_rows); + let mut rng = SmallRng::from_os_rng(); + let mut indices = if k * 2 < num_rows { + let mut set = std::collections::HashSet::with_capacity(k); + while set.len() < k { + set.insert(rng.random_range(0..num_rows as u64)); + } + set.into_iter().collect::<Vec<_>>() + } else { + let mut all: Vec<u64> = (0..num_rows as u64).collect(); + // Partial Fisher-Yates: only shuffle first k elements. + for i in 0..k { + let j = rng.random_range(i..all.len()); + all.swap(i, j); + } + all.truncate(k); + all + }; + indices.sort_unstable(); + indices +} + /// Generate random ranges to sample from a dataset. /// /// This will return an iterator of ranges that cover the whole dataset. 
It @@ -493,6 +762,7 @@ mod tests { use super::*; use arrow_array::types::Float32Type; + use lance_arrow::FixedSizeListArrayExt; use lance_datagen::{array, gen_batch, ArrayGeneratorExt, Dimension, RowCount}; use crate::dataset::InsertBuilder; @@ -548,6 +818,121 @@ mod tests { assert_eq!(training_data.len(), 1000); } + #[rstest::rstest] + #[case::f16(arrow::datatypes::DataType::Float16, 2)] + #[case::f32(arrow::datatypes::DataType::Float32, 4)] + #[case::f64(arrow::datatypes::DataType::Float64, 8)] + #[test] + fn test_fsl_values_to_array_roundtrip( + #[case] elem_type: arrow::datatypes::DataType, + #[case] elem_size: usize, + ) { + let dim = 4; + let num_rows = 3; + // Fill with recognizable byte patterns: each element gets its index as bytes. + let num_elems = num_rows * dim; + let values_vec: Vec<u8> = (0..num_elems) + .flat_map(|i| { + let mut bytes = vec![0u8; elem_size]; + // Write index into the first bytes (little-endian). + let i_bytes = (i as u32).to_le_bytes(); + bytes[..i_bytes.len().min(elem_size)] + .copy_from_slice(&i_bytes[..i_bytes.len().min(elem_size)]); + bytes + }) + .collect(); + let expected_bytes = values_vec.clone(); + let values_buf = MutableBuffer::from(values_vec); + + let dt = DataType::FixedSizeList( + Arc::new(arrow::datatypes::Field::new("item", elem_type, true)), + dim as i32, + ); + let field = lance_core::datatypes::Field::new_arrow("vec", dt, true).unwrap(); + let fsl = fsl_values_to_array(&field, values_buf, num_rows).unwrap(); + assert_eq!(fsl.len(), num_rows); + assert_eq!(fsl.value_length(), dim as i32); + + // Verify the raw bytes round-tripped correctly. 
+ let out_data = fsl.values().to_data(); + let out_bytes = out_data.buffers()[0].as_slice(); + assert_eq!(&out_bytes[..expected_bytes.len()], &expected_bytes[..]); + } + + #[rstest::rstest] + #[case::f32_nullable(array::rand_vec::<arrow::datatypes::Float32Type>(Dimension::from(8)), true)] + #[case::f64_nullable(array::rand_vec::<arrow::datatypes::Float64Type>(Dimension::from(8)), true)] + #[case::f32_non_nullable(array::rand_vec::<arrow::datatypes::Float32Type>(Dimension::from(8)), false)] + #[case::f64_non_nullable(array::rand_vec::<arrow::datatypes::Float64Type>(Dimension::from(8)), false)] + #[tokio::test] + async fn test_maybe_sample_training_data_fsl( + #[case] vec_gen: Box<dyn lance_datagen::ArrayGenerator>, + #[case] nullable: bool, + ) { + let nrows: usize = 2000; + let dims: u32 = 8; + let sample_size: usize = 500; + + let col_gen = if nullable { + vec_gen.with_random_nulls(0.5) + } else { + vec_gen + }; + let data = gen_batch() + .col("vec", col_gen) + .into_batch_rows(RowCount::from(nrows as u64)) + .unwrap(); + + let dataset = InsertBuilder::new("memory://fsl_sample_test") + .execute(vec![data]) + .await + .unwrap(); + + let training_data = maybe_sample_training_data(&dataset, "vec", sample_size) + .await + .unwrap(); + + assert!(training_data.len() > 0 && training_data.len() <= sample_size); + assert_eq!(training_data.null_count(), 0); + assert_eq!(training_data.value_length(), dims as i32); + } + + #[rstest::rstest] + #[case::sparse(1_000_000, 100)] + #[case::dense(100, 80)] + #[case::exact(100, 100)] + #[test] + fn test_generate_random_indices(#[case] num_rows: usize, #[case] k: usize) { + let indices = generate_random_indices(num_rows, k); + assert_eq!(indices.len(), k); + assert!(indices.windows(2).all(|w| w[0] < w[1])); + assert!(indices.iter().all(|&i| (i as usize) < num_rows)); + } + + #[test] + fn test_accumulate_fsl_values_with_sliced_array() { + let dim = 4usize; + let values: Vec<f32> = (0..40).map(|i| i as f32).collect(); + let fsl = FixedSizeListArray::try_new_from_values( + arrow_array::Float32Array::from(values), + dim as i32, + ) + .unwrap(); + let sliced = fsl.slice(3, 4); + + let byte_width = 
dim * std::mem::size_of::<f32>(); + let mut buf = MutableBuffer::new(0); + let mut num_rows = 0usize; + let sliced_ref: ArrayRef = Arc::new(sliced); + accumulate_fsl_values(&mut buf, &mut num_rows, &sliced_ref, byte_width, false).unwrap(); + + assert_eq!(num_rows, 4); + let result: &[f32] = + unsafe { std::slice::from_raw_parts(buf.as_ptr() as *const f32, 4 * dim) }; + let expected: Vec<f32> = (12..28).map(|i| i as f32).collect(); + assert_eq!(result, &expected[..]); + } + #[tokio::test] async fn test_estimate_multivector_vectors_per_row_fallback_1030() { let nrows: usize = 256;