From 0503804b4904cfcaf3a048168ec9d031cb8e8c39 Mon Sep 17 00:00:00 2001 From: BubbleCal Date: Thu, 18 Dec 2025 22:27:08 +0800 Subject: [PATCH 1/3] Infer multivector sample rows --- rust/lance/src/index/vector/utils.rs | 94 ++++++++++++++++++++++++++-- 1 file changed, 89 insertions(+), 5 deletions(-) diff --git a/rust/lance/src/index/vector/utils.rs b/rust/lance/src/index/vector/utils.rs index 8b1a000fb1b..f78bbcdc5c5 100644 --- a/rust/lance/src/index/vector/utils.rs +++ b/rust/lance/src/index/vector/utils.rs @@ -3,12 +3,12 @@ use std::sync::Arc; -use arrow_array::{cast::AsArray, ArrayRef, FixedSizeListArray, RecordBatch}; +use arrow_array::{cast::AsArray, Array, ArrayRef, FixedSizeListArray, RecordBatch}; use futures::StreamExt; use lance_arrow::{interleave_batches, DataTypeExt}; use lance_core::datatypes::Schema; use lance_linalg::distance::DistanceType; -use log::info; +use log::{info, warn}; use rand::rngs::SmallRng; use rand::seq::{IteratorRandom, SliceRandom}; use rand::SeedableRng; @@ -84,6 +84,58 @@ fn get_column_from_batch(batch: &RecordBatch, column: &str) -> Result Ok(current_array) } +async fn estimate_multivector_vectors_per_row( + dataset: &Dataset, + column: &str, + num_rows: usize, +) -> Result { + if num_rows == 0 { + return Ok(1); + } + + let projection = dataset.schema().project(&[column])?; + + // Try a few random samples first (fast path). + let sample_batch_size = std::cmp::min(64, num_rows); + for _ in 0..8 { + let batch = dataset.sample(sample_batch_size, &projection).await?; + let array = get_column_from_batch(&batch, column)?; + let list_array = array.as_list::(); + for i in 0..list_array.len() { + if list_array.is_null(i) { + continue; + } + let len = list_array.value_length(i) as usize; + if len > 0 { + return Ok(len); + } + } + } + + // Fallback: scan a small prefix to find a non-null example. This avoids rare + // flakiness when values are extremely sparse. + let mut scanner = dataset.scan(); + scanner.project(&[column])?; + let column_expr = lance_datafusion::logical_expr::field_path_to_expr(column)?; + scanner.filter_expr(column_expr.is_not_null()); + scanner.limit(Some(std::cmp::min(num_rows, 1024) as i64), None)?; + let batch = scanner.try_into_batch().await?; + let array = get_column_from_batch(&batch, column)?; + let list_array = array.as_list::(); + for i in 0..list_array.len() { + let len = list_array.value_length(i) as usize; + if len > 0 { + return Ok(len); + } + } + + warn!( + "Could not find a non-empty multivector value for column {}, falling back to n=1", + column + ); + Ok(1) +} + /// Get the vector dimension of the given column in the schema. pub fn get_vector_dim(schema: &Schema, column: &str) -> Result { let field = schema.field(column).ok_or(Error::Index { @@ -231,11 +283,12 @@ pub async fn maybe_sample_training_data( arrow::datatypes::DataType::List(_) => { // for multivector, we need `sample_size_hint` vectors for training, // but each multivector is a list of vectors, but we don't know how many - // vectors are in each multivector. For now we just assume there are 1030 vectors - // in each multivector (Copali case). + // vectors are in each multivector. Estimate this by looking at a non-null row. // Set a minimum sample size of 128 to avoid too small samples, // it's not a problem because 128 multivectors is just about 64 MiB - sample_size_hint.div_ceil(1030).max(128) + let vectors_per_row = + estimate_multivector_vectors_per_row(dataset, column, num_rows).await?; + sample_size_hint.div_ceil(vectors_per_row).max(128) } _ => sample_size_hint, }; @@ -439,6 +492,11 @@ fn random_ranges( mod tests { use super::*; + use arrow_array::types::Float32Type; + use lance_datagen::{array, gen_batch, Dimension, RowCount}; + + use crate::dataset::InsertBuilder; + #[rstest::rstest] #[test] fn test_random_ranges( @@ -461,4 +519,30 @@ mod tests { }); assert_eq!(ranges, expected.collect::>()); } + + #[tokio::test] + async fn test_maybe_sample_training_data_multivector_infers_vectors_per_row() { + let nrows: usize = 2000; + let dims: u32 = 8; + let vectors_per_row: u32 = 2; + + let mv = array::cycle_vec_var( + array::rand_vec::(Dimension::from(dims)), + Dimension::from(vectors_per_row), + Dimension::from(vectors_per_row + 1), + ); + + let data = gen_batch() + .col("mv", mv) + .into_batch_rows(RowCount::from(nrows as u64)) + .unwrap(); + + let dataset = InsertBuilder::new("memory://") + .execute(vec![data]) + .await + .unwrap(); + + let training_data = maybe_sample_training_data(&dataset, "mv", 1000).await.unwrap(); + assert_eq!(training_data.len(), 1000); + } } From 7ac46bf1316a238860b8e57af010c5beff7ea523 Mon Sep 17 00:00:00 2001 From: BubbleCal Date: Thu, 18 Dec 2025 22:38:56 +0800 Subject: [PATCH 2/3] Fallback multivector n=1030 --- rust/lance/src/index/vector/utils.rs | 36 ++++++++++++++++++++++++---- 1 file changed, 32 insertions(+), 4 deletions(-) diff --git a/rust/lance/src/index/vector/utils.rs b/rust/lance/src/index/vector/utils.rs index f78bbcdc5c5..ae4198ebeb8 100644 --- a/rust/lance/src/index/vector/utils.rs +++ b/rust/lance/src/index/vector/utils.rs @@ -90,7 +90,7 @@ async fn estimate_multivector_vectors_per_row( num_rows: usize, ) -> Result { if num_rows == 0 { - return Ok(1); + return Ok(1030); } let projection = dataset.schema().project(&[column])?; @@ -130,10 +130,10 @@ async fn estimate_multivector_vectors_per_row( } warn!( - "Could not find a non-empty multivector value for column {}, falling back to n=1", + "Could not find a non-empty multivector value for column {}, falling back to n=1030", column ); - Ok(1) + Ok(1030) } /// Get the vector dimension of the given column in the schema. @@ -493,7 +493,7 @@ mod tests { use super::*; use arrow_array::types::Float32Type; - use lance_datagen::{array, gen_batch, Dimension, RowCount}; + use lance_datagen::{array, gen_batch, ArrayGeneratorExt, Dimension, RowCount}; use crate::dataset::InsertBuilder; @@ -545,4 +545,32 @@ mod tests { let training_data = maybe_sample_training_data(&dataset, "mv", 1000).await.unwrap(); assert_eq!(training_data.len(), 1000); } + + #[tokio::test] + async fn test_estimate_multivector_vectors_per_row_fallback_1030() { + let nrows: usize = 256; + let dims: u32 = 8; + + let mv = array::cycle_vec_var( + array::rand_vec::(Dimension::from(dims)), + Dimension::from(2), + Dimension::from(3), + ) + .with_random_nulls(1.0); + + let data = gen_batch() + .col("mv", mv) + .into_batch_rows(RowCount::from(nrows as u64)) + .unwrap(); + + let dataset = InsertBuilder::new("memory://") + .execute(vec![data]) + .await + .unwrap(); + + let n = estimate_multivector_vectors_per_row(&dataset, "mv", nrows) + .await + .unwrap(); + assert_eq!(n, 1030); + } } From b1996b74e51d44fd6a9348a8930b4ca974833bf1 Mon Sep 17 00:00:00 2001 From: BubbleCal Date: Thu, 18 Dec 2025 22:58:51 +0800 Subject: [PATCH 3/3] fmt: vector utils --- rust/lance/src/index/vector/utils.rs | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/rust/lance/src/index/vector/utils.rs b/rust/lance/src/index/vector/utils.rs index ae4198ebeb8..3358f9093d5 100644 --- a/rust/lance/src/index/vector/utils.rs +++ b/rust/lance/src/index/vector/utils.rs @@ -542,7 +542,9 @@ mod tests { .await .unwrap(); - let training_data = maybe_sample_training_data(&dataset, "mv", 1000).await.unwrap(); + let training_data = maybe_sample_training_data(&dataset, "mv", 1000) + .await + .unwrap(); assert_eq!(training_data.len(), 1000); }