Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
156 changes: 156 additions & 0 deletions rust/lance-linalg/src/kernels.rs
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,60 @@ pub fn normalize_fsl(fsl: &FixedSizeListArray) -> Result<FixedSizeListArray> {
}
}

/// L2-normalize each fixed-size vector in `fsl`, mutating the values buffer
/// in place when it is uniquely owned.
///
/// Falls back to the copying path (`do_normalize_fsl`) when `into_builder`
/// cannot reclaim the buffer (e.g. it is shared with another array, or the
/// array is sliced).
///
/// NOTE: a zero vector divides by a zero norm and thus produces NaN
/// components; callers are expected to filter non-finite rows afterwards.
fn do_normalize_fsl_inplace<T: ArrowPrimitiveType>(
    fsl: FixedSizeListArray,
) -> Result<FixedSizeListArray>
where
    T::Native: Float + Sum + AsPrimitive<f32>,
{
    let dim = fsl.value_length() as usize;
    // Decompose the list array so we hold the values `Arc<dyn Array>` directly.
    let (field, size, values_array, nulls) = fsl.into_parts();

    // Clone the PrimitiveArray (shares the underlying buffer), then drop the
    // Arc<dyn Array> so the buffer's refcount drops to 1.
    // The order matters: `into_builder` below only succeeds if `prim` is the
    // sole owner of the buffer at that point.
    let prim = values_array
        .as_any()
        .downcast_ref::<PrimitiveArray<T>>()
        .expect("values must be PrimitiveArray")
        .clone();
    drop(values_array);

    // into_builder gives mutable access when the buffer is uniquely owned,
    // avoiding a full copy of the (potentially multi-GB) training data.
    match prim.into_builder() {
        Ok(mut builder) => {
            // Each `dim`-sized chunk is one vector; scale it by 1/||v||.
            for chunk in builder.values_slice_mut().chunks_mut(dim) {
                let l2_norm = chunk.iter().map(|x| x.powi(2)).sum::<T::Native>().sqrt();
                for x in chunk.iter_mut() {
                    *x = *x / l2_norm;
                }
            }
            FixedSizeListArray::try_new(field, size, Arc::new(builder.finish()), nulls)
        }
        Err(prim) => {
            // Buffer is still shared: reassemble the array and take the
            // by-reference copying path instead.
            let fsl = FixedSizeListArray::try_new(field, size, Arc::new(prim), nulls)?;
            do_normalize_fsl::<T>(&fsl)
        }
    }
}

/// L2 normalize a [FixedSizeListArray] (of vectors), attempting in-place mutation.
///
/// If the underlying buffer is uniquely owned, normalization is performed in-place
/// to avoid allocating a second copy. Otherwise falls back to the copy path used
/// by [`normalize_fsl`].
pub fn normalize_fsl_owned(fsl: FixedSizeListArray) -> Result<FixedSizeListArray> {
    // Take an owned copy of the value type so `fsl` can be moved into the arms.
    let value_type = fsl.value_type();
    match value_type {
        DataType::Float16 => do_normalize_fsl_inplace::<Float16Type>(fsl),
        DataType::Float32 => do_normalize_fsl_inplace::<Float32Type>(fsl),
        DataType::Float64 => do_normalize_fsl_inplace::<Float64Type>(fsl),
        other => Err(ArrowError::SchemaError(format!(
            "Normalize only supports float array, got: {}",
            other
        ))),
    }
}

fn hash_numeric_type<T: ArrowNumericType>(array: &PrimitiveArray<T>) -> Result<UInt64Array>
where
T::Native: Hash,
Expand Down Expand Up @@ -451,4 +505,106 @@ mod tests {
assert_relative_eq!(values.value(2), 0.0);
assert_relative_eq!(values.value(3), 1.0);
}

fn make_fsl(values: &[f32], dim: i32) -> FixedSizeListArray {
let field = Arc::new(Field::new("item", DataType::Float32, true));
FixedSizeListArray::try_new(
field,
dim,
Arc::new(Float32Array::from_iter_values(values.iter().copied())),
None,
)
.unwrap()
}

/// Assert FSL values match expected, where None means NaN.
fn assert_fsl_eq(actual: &FixedSizeListArray, expected: &[Option<f32>], label: &str) {
let vals = actual.values().as_primitive::<Float32Type>();
assert_eq!(vals.len(), expected.len(), "{label}: length mismatch");
for (i, exp) in expected.iter().enumerate() {
match exp {
None => assert!(vals.value(i).is_nan(), "{label}[{i}]: expected NaN"),
Some(v) => assert_relative_eq!(vals.value(i), *v, epsilon = 1e-6),
}
}
}

/// normalize_fsl_owned produces correct values and matches normalize_fsl.
/// Zero vectors yield NaN (cosine is undefined; downstream is_finite filters them).
#[test]
fn test_normalize_fsl_owned_values() {
#[allow(clippy::type_complexity)]
let cases: &[(&str, &[f32], &[Option<f32>])] = &[
(
"basic",
&[3.0, 4.0, 5.0, 12.0],
&[Some(0.6), Some(0.8), Some(5.0 / 13.0), Some(12.0 / 13.0)],
),
(
"zero_vector",
&[3.0, 4.0, 0.0, 0.0, 5.0, 12.0],
&[
Some(0.6),
Some(0.8),
None,
None,
Some(5.0 / 13.0),
Some(12.0 / 13.0),
],
),
];
for (name, input, expected) in cases {
let fsl = make_fsl(input, 2);
assert_fsl_eq(&normalize_fsl(&fsl).unwrap(), expected, name);
assert_fsl_eq(&normalize_fsl_owned(fsl).unwrap(), expected, name);
}
}

/// Uniquely-owned buffer is mutated in-place (no copy).
#[test]
fn test_normalize_fsl_owned_inplace() {
let fsl = make_fsl(&[3.0, 4.0, 5.0, 12.0], 2);
let ptr = fsl.values().as_primitive::<Float32Type>().values().as_ptr();
let result = normalize_fsl_owned(fsl).unwrap();
let new_ptr = result
.values()
.as_primitive::<Float32Type>()
.values()
.as_ptr();
assert_eq!(ptr, new_ptr, "expected in-place mutation");
}

/// Sliced inputs normalize correctly via the by-reference path.
/// (normalize_fsl_owned uses into_builder which does not support sliced
/// arrays; use normalize_fsl for sliced data.)
#[test]
fn test_normalize_fsl_sliced_input() {
let sliced = {
let fsl = make_fsl(&[1.0, 0.0, 0.0, 1.0, 3.0, 4.0], 2);
fsl.slice(1, 2)
};

let expected = &[Some(0.0), Some(1.0), Some(0.6), Some(0.8)];
assert_fsl_eq(&normalize_fsl(&sliced).unwrap(), expected, "sliced_ref");
}

/// Shared buffer falls back to copy path and still produces correct values.
#[test]
fn test_normalize_fsl_owned_shared_buffer_fallback() {
let fsl = make_fsl(&[3.0, 4.0, 5.0, 12.0], 2);
let _hold = fsl.clone(); // force shared buffer
let expected = &[Some(0.6), Some(0.8), Some(5.0 / 13.0), Some(12.0 / 13.0)];
assert_fsl_eq(&normalize_fsl_owned(fsl).unwrap(), expected, "fallback");
}

/// Null buffer is preserved through normalization.
#[test]
fn test_normalize_fsl_owned_preserves_nulls() {
let values = Float32Array::from_iter_values([3.0, 4.0, 0.0, 0.0, 5.0, 12.0]);
let nulls = NullBuffer::from(vec![true, false, true]);
let field = Arc::new(Field::new("item", DataType::Float32, true));
let fsl =
FixedSizeListArray::try_new(field, 2, Arc::new(values), Some(nulls.clone())).unwrap();
assert_eq!(normalize_fsl_owned(fsl).unwrap().nulls(), Some(&nulls));
}
}
2 changes: 1 addition & 1 deletion rust/lance/src/dataset/scanner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3371,7 +3371,7 @@ impl Scanner {
}
// No index: flat search all target fragments
let flat_match_plan = self
.plan_flat_match_query(target_fragments.to_vec(), query, params, filter_plan)
.plan_flat_match_query(target_fragments.clone(), query, params, filter_plan)
.await?;
(None, Some(flat_match_plan))
}
Expand Down
10 changes: 4 additions & 6 deletions rust/lance/src/index/vector/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,6 @@ use lance_index::vector::quantizer::{QuantizerMetadata, QuantizerStorage};
use lance_index::vector::shared::{write_unified_ivf_and_index_metadata, SupportedIvfIndexType};
use lance_index::vector::storage::STORAGE_METADATA_KEY;
use lance_index::vector::transform::Flatten;
use lance_index::vector::utils::is_finite;
use lance_index::vector::v3::shuffler::{EmptyReader, IvfShufflerReader};
use lance_index::vector::v3::subindex::SubIndexType;
use lance_index::vector::{ivf::storage::IvfModel, PART_ID_FIELD};
Expand Down Expand Up @@ -453,14 +452,13 @@ impl<S: IvfSubIndex + 'static, Q: Quantization + 'static> IvfIndexBuilder<S, Q>
// If metric type is cosine, normalize the training data, and after this point,
// treat the metric type as L2.
let training_data = if self.distance_type == DistanceType::Cosine {
lance_linalg::kernels::normalize_fsl(&training_data)?
lance_linalg::kernels::normalize_fsl_owned(training_data)?
} else {
training_data
};

// we filtered out nulls when sampling, but we still need to filter out NaNs and INFs here
let training_data = arrow::compute::filter(&training_data, &is_finite(&training_data))?;
let training_data = training_data.as_fixed_size_list();
let training_data = utils::filter_finite_training_data(training_data)?;

let training_data = match (self.ivf.as_ref(), Q::use_residual(self.distance_type)) {
(Some(ivf), true) => {
Expand All @@ -470,9 +468,9 @@ impl<S: IvfSubIndex + 'static, Q: Quantization + 'static> IvfIndexBuilder<S, Q>
vec![],
);
span!(Level::INFO, "compute residual for PQ training")
.in_scope(|| ivf_transformer.compute_residual(training_data))?
.in_scope(|| ivf_transformer.compute_residual(&training_data))?
}
_ => training_data.clone(),
_ => training_data,
};

info!("Start to train quantizer");
Expand Down
12 changes: 5 additions & 7 deletions rust/lance/src/index/vector/ivf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
use super::{builder::IvfIndexBuilder, utils::PartitionLoadLock};
use super::{
pq::{build_pq_model, PQIndex},
utils::maybe_sample_training_data,
utils::{filter_finite_training_data, maybe_sample_training_data},
};
use crate::index::vector::utils::{get_vector_dim, get_vector_type};
use crate::index::DatasetIndexInternalExt;
Expand Down Expand Up @@ -57,7 +57,6 @@ use lance_index::vector::ivf::storage::{IvfModel, IVF_METADATA_KEY};
use lance_index::vector::kmeans::KMeansParams;
use lance_index::vector::pq::storage::transpose;
use lance_index::vector::quantizer::QuantizationType;
use lance_index::vector::utils::is_finite;
use lance_index::vector::v3::shuffler::IvfShuffler;
use lance_index::vector::v3::subindex::{IvfSubIndex, SubIndexType};
use lance_index::vector::DISTANCE_TYPE_KEY;
Expand Down Expand Up @@ -86,7 +85,7 @@ use lance_io::{
traits::{Reader, WriteExt, Writer},
};
use lance_linalg::distance::{DistanceType, Dot, MetricType, L2};
use lance_linalg::{distance::Normalize, kernels::normalize_fsl};
use lance_linalg::{distance::Normalize, kernels::normalize_fsl_owned};
use log::{info, warn};
use object_store::path::Path;
use prost::Message;
Expand Down Expand Up @@ -1271,19 +1270,18 @@ pub async fn build_ivf_model(
// If metric type is cosine, normalize the training data, and after this point,
// treat the metric type as L2.
let (training_data, mt) = if metric_type == MetricType::Cosine {
let training_data = normalize_fsl(&training_data)?;
let training_data = normalize_fsl_owned(training_data)?;
(training_data, MetricType::L2)
} else {
(training_data, metric_type)
};

// we filtered out nulls when sampling, but we still need to filter out NaNs and INFs here
let training_data = arrow::compute::filter(&training_data, &is_finite(&training_data))?;
let training_data = training_data.as_fixed_size_list();
let training_data = filter_finite_training_data(training_data)?;

info!("Start to train IVF model");
let start = std::time::Instant::now();
let ivf = train_ivf_model(centroids, training_data, mt, params, progress).await?;
let ivf = train_ivf_model(centroids, &training_data, mt, params, progress).await?;
info!(
"Trained IVF model in {:02} seconds",
start.elapsed().as_secs_f32()
Expand Down
5 changes: 3 additions & 2 deletions rust/lance/src/index/vector/pq.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ use snafu::location;
use tracing::{instrument, span, Level};
// Re-export
pub use lance_index::vector::pq::PQBuildParams;
use lance_linalg::kernels::normalize_fsl;
use lance_linalg::kernels::normalize_fsl_owned;

use super::VectorIndex;
use crate::index::prefilter::PreFilter;
Expand Down Expand Up @@ -561,7 +561,7 @@ pub async fn build_pq_model(

if metric_type == MetricType::Cosine {
info!("Normalize training data for PQ training: Cosine");
training_data = normalize_fsl(&training_data)?;
training_data = normalize_fsl_owned(training_data)?;
}

let training_data = if let Some(ivf) = ivf {
Expand Down Expand Up @@ -638,6 +638,7 @@ mod tests {
use arrow_array::RecordBatchIterator;
use arrow_schema::{Field, Schema};
use lance_core::utils::tempfile::TempStrDir;
use lance_linalg::kernels::normalize_fsl;

use crate::index::vector::ivf::build_ivf_model;
use lance_core::utils::mask::RowAddrMask;
Expand Down
37 changes: 36 additions & 1 deletion rust/lance/src/index/vector/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -314,6 +314,21 @@ pub async fn maybe_sample_training_data(
}
}

/// Filter out non-finite vectors from sampled training data.
///
/// This is a no-op when all rows are finite, avoiding an unnecessary copy.
pub fn filter_finite_training_data(
    training_data: FixedSizeListArray,
) -> Result<FixedSizeListArray> {
    let finite_mask = lance_index::vector::utils::is_finite(&training_data);
    // Only materialize a filtered copy if at least one row is non-finite.
    if finite_mask.true_count() != training_data.len() {
        let filtered = arrow::compute::filter(&training_data, &finite_mask)?;
        return Ok(filtered.as_fixed_size_list().clone());
    }
    Ok(training_data)
}

#[derive(Debug)]
pub struct PartitionLoadLock {
partition_locks: Vec<Arc<Mutex<()>>>,
Expand Down Expand Up @@ -761,7 +776,8 @@ fn random_ranges(
mod tests {
use super::*;

use arrow_array::types::Float32Type;
use arrow_array::{types::Float32Type, Float32Array};
use arrow_schema::{DataType, Field};
use lance_arrow::FixedSizeListArrayExt;
use lance_datagen::{array, gen_batch, ArrayGeneratorExt, Dimension, RowCount};

Expand Down Expand Up @@ -933,6 +949,25 @@ mod tests {
assert_eq!(result, &expected[..]);
}

#[test]
fn test_filter_finite_training_data() {
let values = Float32Array::from_iter_values([
1.0,
2.0, // finite
f32::NAN,
0.0, // non-finite
3.0,
4.0, // finite
]);
let field = Arc::new(Field::new("item", DataType::Float32, true));
let training_data = FixedSizeListArray::try_new(field, 2, Arc::new(values), None).unwrap();

let filtered = filter_finite_training_data(training_data).unwrap();
assert_eq!(filtered.len(), 2);
let vals = filtered.values().as_primitive::<Float32Type>();
assert_eq!(vals.values(), &[1.0, 2.0, 3.0, 4.0]);
}

#[tokio::test]
async fn test_estimate_multivector_vectors_per_row_fallback_1030() {
let nrows: usize = 256;
Expand Down