diff --git a/rust/lance/src/index/vector/builder.rs b/rust/lance/src/index/vector/builder.rs index 3bdc707bd63..e103f6746c7 100644 --- a/rust/lance/src/index/vector/builder.rs +++ b/rust/lance/src/index/vector/builder.rs @@ -1271,7 +1271,7 @@ impl IvfIndexBuilder // get top REASSIGN_RANGE centroids from c0 let (reassign_part_ids, reassign_part_centroids) = - self.select_reassign_candidates(ivf, &c0)?; + self.select_reassign_candidates(ivf, part_idx, &c0)?; // compute the distance between the vectors and the 3 centroids (original one and the 2 new ones) let d0 = self.distance_type.arrow_batch_func()(&c0, vectors)?; @@ -1495,7 +1495,7 @@ impl IvfIndexBuilder // get top REASSIGN_RANGE centroids from c0 let (reassign_part_ids, reassign_part_centroids) = - self.select_reassign_candidates(ivf, &c0)?; + self.select_reassign_candidates(ivf, part_idx, &c0)?; let new_part_id = |idx: usize| -> usize { if idx < part_idx { @@ -1680,23 +1680,11 @@ impl IvfIndexBuilder fn select_reassign_candidates( &self, ivf: &IvfModel, + part_idx: usize, c0: &ArrayRef, ) -> Result<(UInt32Array, FixedSizeListArray)> { - let reassign_range = std::cmp::min(REASSIGN_RANGE + 1, ivf.num_partitions()); - let centroids = ivf.centroids_array().unwrap(); - let centroid_dists = self.distance_type.arrow_batch_func()(&c0, centroids)?; - let reassign_range_candidates = - sort_to_indices(centroid_dists.as_ref(), None, Some(reassign_range))?; - // exclude the original centroid itself - let reassign_candidate_ids = &reassign_range_candidates.slice(1, reassign_range - 1); - let reassign_candidate_centroids = - arrow::compute::take(centroids, reassign_candidate_ids, None)?; - Ok(( - reassign_candidate_ids.clone(), - reassign_candidate_centroids.as_fixed_size_list().clone(), - )) + select_reassign_candidates_impl(self.distance_type, ivf, part_idx, c0) } - // assign the vectors of original partition #[allow(clippy::too_many_arguments)] fn assign_vectors( @@ -1797,6 +1785,34 @@ impl IvfIndexBuilder } } +fn 
select_reassign_candidates_impl( + distance_type: DistanceType, + ivf: &IvfModel, + part_idx: usize, + c0: &ArrayRef, +) -> Result<(UInt32Array, FixedSizeListArray)> { + let reassign_range = std::cmp::min(REASSIGN_RANGE + 1, ivf.num_partitions()); + let centroids = ivf.centroids_array().unwrap(); + let centroid_dists = distance_type.arrow_batch_func()(&c0, centroids)?; + let reassign_range_candidates = + sort_to_indices(centroid_dists.as_ref(), None, Some(reassign_range))?; + let selection_len = reassign_range.saturating_sub(1); + let filtered_ids = reassign_range_candidates + .values() + .iter() + .copied() + .filter(|&idx| idx as usize != part_idx) + .take(selection_len) + .collect::<Vec<_>>(); + let reassign_candidate_ids = UInt32Array::from(filtered_ids); + let reassign_candidate_centroids = + arrow::compute::take(centroids, &reassign_candidate_ids, None)?; + Ok(( + reassign_candidate_ids, + reassign_candidate_centroids.as_fixed_size_list().clone(), + )) +} + struct AssignResult { // the batches of new vectors that are assigned to the partition, // and the deleted row ids @@ -1837,3 +1853,37 @@ pub(crate) fn index_type_string(sub_index: SubIndexType, quantizer: Quantization } } } + +#[cfg(test)] +mod tests { + use super::*; + use arrow_array::Float32Array; + + #[test] + fn select_reassign_candidates_skips_deleted_partition() { + let dim = 4; + let centroid_values = Float32Array::from(vec![0.0_f32; dim * 2]); + let centroids = + FixedSizeListArray::try_new_from_values(centroid_values, dim as i32).unwrap(); + let mut ivf = IvfModel::new(centroids, None); + ivf.lengths = vec![10, 20]; + ivf.offsets = vec![0, 10]; + + let c0 = ivf.centroid(1).unwrap(); + let (reassign_ids, reassign_centroids) = + select_reassign_candidates_impl(DistanceType::L2, &ivf, 1, &c0).unwrap(); + + assert_eq!(reassign_ids.len(), 1); + assert_eq!(reassign_ids.value(0), 0); + assert_eq!(reassign_centroids.len(), 1); + + let expected_centroid = ivf.centroid(0).unwrap(); + assert_eq!( + 
reassign_centroids + .value(0) + .as_primitive::<Float32Type>() + .values(), + expected_centroid.as_primitive::<Float32Type>().values() + ); + } +}