diff --git a/rust/lance/src/index/vector/utils.rs b/rust/lance/src/index/vector/utils.rs index 8b1a000fb1b..3358f9093d5 100644 --- a/rust/lance/src/index/vector/utils.rs +++ b/rust/lance/src/index/vector/utils.rs @@ -3,12 +3,12 @@ use std::sync::Arc; -use arrow_array::{cast::AsArray, ArrayRef, FixedSizeListArray, RecordBatch}; +use arrow_array::{cast::AsArray, Array, ArrayRef, FixedSizeListArray, RecordBatch}; use futures::StreamExt; use lance_arrow::{interleave_batches, DataTypeExt}; use lance_core::datatypes::Schema; use lance_linalg::distance::DistanceType; -use log::info; +use log::{info, warn}; use rand::rngs::SmallRng; use rand::seq::{IteratorRandom, SliceRandom}; use rand::SeedableRng; @@ -84,6 +84,58 @@ fn get_column_from_batch(batch: &RecordBatch, column: &str) -> Result Ok(current_array) } +async fn estimate_multivector_vectors_per_row( + dataset: &Dataset, + column: &str, + num_rows: usize, +) -> Result { + if num_rows == 0 { + return Ok(1030); + } + + let projection = dataset.schema().project(&[column])?; + + // Try a few random samples first (fast path). + let sample_batch_size = std::cmp::min(64, num_rows); + for _ in 0..8 { + let batch = dataset.sample(sample_batch_size, &projection).await?; + let array = get_column_from_batch(&batch, column)?; + let list_array = array.as_list::(); + for i in 0..list_array.len() { + if list_array.is_null(i) { + continue; + } + let len = list_array.value_length(i) as usize; + if len > 0 { + return Ok(len); + } + } + } + + // Fallback: scan a small prefix to find a non-null example. This avoids rare + // flakiness when values are extremely sparse. + let mut scanner = dataset.scan(); + scanner.project(&[column])?; + let column_expr = lance_datafusion::logical_expr::field_path_to_expr(column)?; + scanner.filter_expr(column_expr.is_not_null()); + scanner.limit(Some(std::cmp::min(num_rows, 1024) as i64), None)?; + let batch = scanner.try_into_batch().await?; + let array = get_column_from_batch(&batch, column)?; + let list_array = array.as_list::(); + for i in 0..list_array.len() { + let len = list_array.value_length(i) as usize; + if len > 0 { + return Ok(len); + } + } + + warn!( + "Could not find a non-empty multivector value for column {}, falling back to n=1030", + column + ); + Ok(1030) +} + /// Get the vector dimension of the given column in the schema. pub fn get_vector_dim(schema: &Schema, column: &str) -> Result { let field = schema.field(column).ok_or(Error::Index { @@ -231,11 +283,12 @@ pub async fn maybe_sample_training_data( arrow::datatypes::DataType::List(_) => { // for multivector, we need `sample_size_hint` vectors for training, // but each multivector is a list of vectors, but we don't know how many - // vectors are in each multivector. For now we just assume there are 1030 vectors - // in each multivector (Copali case). + // vectors are in each multivector. Estimate this by looking at a non-null row. // Set a minimum sample size of 128 to avoid too small samples, // it's not a problem because 128 multivectors is just about 64 MiB - sample_size_hint.div_ceil(1030).max(128) + let vectors_per_row = + estimate_multivector_vectors_per_row(dataset, column, num_rows).await?; + sample_size_hint.div_ceil(vectors_per_row).max(128) } _ => sample_size_hint, }; @@ -439,6 +492,11 @@ fn random_ranges( mod tests { use super::*; + use arrow_array::types::Float32Type; + use lance_datagen::{array, gen_batch, ArrayGeneratorExt, Dimension, RowCount}; + + use crate::dataset::InsertBuilder; + #[rstest::rstest] #[test] fn test_random_ranges( @@ -461,4 +519,60 @@ mod tests { }); assert_eq!(ranges, expected.collect::>()); } + + #[tokio::test] + async fn test_maybe_sample_training_data_multivector_infers_vectors_per_row() { + let nrows: usize = 2000; + let dims: u32 = 8; + let vectors_per_row: u32 = 2; + + let mv = array::cycle_vec_var( + array::rand_vec::(Dimension::from(dims)), + Dimension::from(vectors_per_row), + Dimension::from(vectors_per_row + 1), + ); + + let data = gen_batch() + .col("mv", mv) + .into_batch_rows(RowCount::from(nrows as u64)) + .unwrap(); + + let dataset = InsertBuilder::new("memory://") + .execute(vec![data]) + .await + .unwrap(); + + let training_data = maybe_sample_training_data(&dataset, "mv", 1000) + .await + .unwrap(); + assert_eq!(training_data.len(), 1000); + } + + #[tokio::test] + async fn test_estimate_multivector_vectors_per_row_fallback_1030() { + let nrows: usize = 256; + let dims: u32 = 8; + + let mv = array::cycle_vec_var( + array::rand_vec::(Dimension::from(dims)), + Dimension::from(2), + Dimension::from(3), + ) + .with_random_nulls(1.0); + + let data = gen_batch() + .col("mv", mv) + .into_batch_rows(RowCount::from(nrows as u64)) + .unwrap(); + + let dataset = InsertBuilder::new("memory://") + .execute(vec![data]) + .await + .unwrap(); + + let n = estimate_multivector_vectors_per_row(&dataset, "mv", nrows) + .await + .unwrap(); + assert_eq!(n, 1030); + } }