Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
124 changes: 119 additions & 5 deletions rust/lance/src/index/vector/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,12 @@

use std::sync::Arc;

use arrow_array::{cast::AsArray, ArrayRef, FixedSizeListArray, RecordBatch};
use arrow_array::{cast::AsArray, Array, ArrayRef, FixedSizeListArray, RecordBatch};
use futures::StreamExt;
use lance_arrow::{interleave_batches, DataTypeExt};
use lance_core::datatypes::Schema;
use lance_linalg::distance::DistanceType;
use log::info;
use log::{info, warn};
use rand::rngs::SmallRng;
use rand::seq::{IteratorRandom, SliceRandom};
use rand::SeedableRng;
Expand Down Expand Up @@ -84,6 +84,58 @@ fn get_column_from_batch(batch: &RecordBatch, column: &str) -> Result<ArrayRef>
Ok(current_array)
}

/// Estimate how many vectors a single multivector row of `column` contains.
///
/// The estimate is the length of the first non-null, non-empty list value found:
///
/// 1. Up to eight random samples of at most 64 rows each are inspected.
/// 2. If none of them yields a value, at most 1024 rows passing a not-null filter on
///    the column are scanned.
///
/// Returns 1030 when `num_rows` is zero, or when neither step finds a non-empty value
/// (a warning is logged in the latter case).
///
/// # Errors
///
/// Propagates errors from projecting the schema, sampling, scanning, building the
/// filter expression, or extracting the column from a batch.
///
/// NOTE(review): `as_list::<i32>()` presumably panics for anything other than a `List`
/// column, so callers are expected to have checked the data type — confirm.
async fn estimate_multivector_vectors_per_row(
    dataset: &Dataset,
    column: &str,
    num_rows: usize,
) -> Result<usize> {
    // Value returned when nothing can be observed in the data.
    const DEFAULT_VECTORS_PER_ROW: usize = 1030;
    // Number of random samples to try before falling back to a scan.
    const SAMPLE_ATTEMPTS: usize = 8;
    // Upper bound on the rows fetched by each random sample.
    const ROWS_PER_SAMPLE: usize = 64;
    // Upper bound on the rows fetched by the fallback scan.
    const MAX_SCANNED_ROWS: usize = 1024;

    if num_rows == 0 {
        return Ok(DEFAULT_VECTORS_PER_ROW);
    }

    let projection = dataset.schema().project(&[column])?;

    // Fast path: a handful of small random samples usually hits a populated row.
    let rows_per_sample = num_rows.min(ROWS_PER_SAMPLE);
    for _attempt in 0..SAMPLE_ATTEMPTS {
        let sampled = dataset.sample(rows_per_sample, &projection).await?;
        let values = get_column_from_batch(&sampled, column)?;
        let lists = values.as_list::<i32>();
        let first_non_empty = (0..lists.len())
            .filter(|&row| !lists.is_null(row))
            .map(|row| lists.value_length(row) as usize)
            .find(|&count| count > 0);
        if let Some(count) = first_non_empty {
            return Ok(count);
        }
    }

    // Slow path: when populated values are extremely sparse the random samples can all
    // miss, so scan a bounded number of rows that pass a not-null filter instead.
    let mut scanner = dataset.scan();
    scanner.project(&[column])?;
    let not_null = lance_datafusion::logical_expr::field_path_to_expr(column)?.is_not_null();
    scanner.filter_expr(not_null);
    scanner.limit(Some(num_rows.min(MAX_SCANNED_ROWS) as i64), None)?;
    let scanned = scanner.try_into_batch().await?;
    let values = get_column_from_batch(&scanned, column)?;
    let lists = values.as_list::<i32>();
    let first_non_empty = (0..lists.len())
        .map(|row| lists.value_length(row) as usize)
        .find(|&count| count > 0);
    if let Some(count) = first_non_empty {
        return Ok(count);
    }

    warn!(
        "Could not find a non-empty multivector value for column {}, falling back to n=1030",
        column
    );
    Ok(DEFAULT_VECTORS_PER_ROW)
}

/// Get the vector dimension of the given column in the schema.
pub fn get_vector_dim(schema: &Schema, column: &str) -> Result<usize> {
let field = schema.field(column).ok_or(Error::Index {
Expand Down Expand Up @@ -231,11 +283,12 @@ pub async fn maybe_sample_training_data(
arrow::datatypes::DataType::List(_) => {
// For multivector, we need `sample_size_hint` vectors for training. Each
// multivector is a list of vectors, and we don't know in advance how many
// vectors are in each multivector. For now we just assume there are 1030 vectors
// in each multivector (Copali case).
// vectors are in each multivector. Estimate this by looking at a non-null row.
// Set a minimum sample size of 128 to avoid too small samples,
// it's not a problem because 128 multivectors is just about 64 MiB
sample_size_hint.div_ceil(1030).max(128)
let vectors_per_row =
estimate_multivector_vectors_per_row(dataset, column, num_rows).await?;
sample_size_hint.div_ceil(vectors_per_row).max(128)
}
_ => sample_size_hint,
};
Expand Down Expand Up @@ -439,6 +492,11 @@ fn random_ranges(
mod tests {
use super::*;

use arrow_array::types::Float32Type;
use lance_datagen::{array, gen_batch, ArrayGeneratorExt, Dimension, RowCount};

use crate::dataset::InsertBuilder;

#[rstest::rstest]
#[test]
fn test_random_ranges(
Expand All @@ -461,4 +519,60 @@ mod tests {
});
assert_eq!(ranges, expected.collect::<Vec<_>>());
}

/// Sampling training data from a multivector column should return exactly as many
/// vectors as the sample size hint asks for.
#[tokio::test]
async fn test_maybe_sample_training_data_multivector_infers_vectors_per_row() {
    const NUM_ROWS: u64 = 2000;
    const DIM: u32 = 8;
    const VECTORS_PER_ROW: u32 = 2;
    const SAMPLE_SIZE_HINT: usize = 1000;

    // NOTE(review): presumably the upper bound passed to `cycle_vec_var` is exclusive,
    // so every row holds exactly `VECTORS_PER_ROW` vectors — confirm in lance_datagen.
    let multivectors = array::cycle_vec_var(
        array::rand_vec::<Float32Type>(Dimension::from(DIM)),
        Dimension::from(VECTORS_PER_ROW),
        Dimension::from(VECTORS_PER_ROW + 1),
    );

    let batch = gen_batch()
        .col("mv", multivectors)
        .into_batch_rows(RowCount::from(NUM_ROWS))
        .unwrap();

    let dataset = InsertBuilder::new("memory://")
        .execute(vec![batch])
        .await
        .unwrap();

    let sampled = maybe_sample_training_data(&dataset, "mv", SAMPLE_SIZE_HINT)
        .await
        .unwrap();
    assert_eq!(sampled.len(), 1000);
}

/// When no non-empty multivector value can be found, the estimate should fall back to
/// 1030 vectors per row.
#[tokio::test]
async fn test_estimate_multivector_vectors_per_row_fallback_1030() {
    const NUM_ROWS: usize = 256;
    const DIM: u32 = 8;

    // A null probability of 1.0 is requested so that no row carries a usable value.
    let vectors = array::rand_vec::<Float32Type>(Dimension::from(DIM));
    let all_null_multivectors =
        array::cycle_vec_var(vectors, Dimension::from(2), Dimension::from(3))
            .with_random_nulls(1.0);

    let batch = gen_batch()
        .col("mv", all_null_multivectors)
        .into_batch_rows(RowCount::from(NUM_ROWS as u64))
        .unwrap();

    let dataset = InsertBuilder::new("memory://")
        .execute(vec![batch])
        .await
        .unwrap();

    let estimate = estimate_multivector_vectors_per_row(&dataset, "mv", NUM_ROWS)
        .await
        .unwrap();
    assert_eq!(estimate, 1030);
}
}
Loading