diff --git a/java/lance-jni/src/vector_trainer.rs b/java/lance-jni/src/vector_trainer.rs index e2d6012859e..97611ae4acb 100755 --- a/java/lance-jni/src/vector_trainer.rs +++ b/java/lance-jni/src/vector_trainer.rs @@ -1,6 +1,8 @@ // SPDX-License-Identifier: Apache-2.0 // SPDX-FileCopyrightText: Copyright The Lance Authors +use std::sync::Arc; + use crate::blocking_dataset::{BlockingDataset, NATIVE_DATASET}; use crate::error::{Error, Result}; use crate::ffi::JNIEnvExt; @@ -11,6 +13,7 @@ use jni::objects::{JClass, JFloatArray, JObject, JString}; use jni::sys::jfloatArray; use jni::JNIEnv; use lance::index::vector::utils::get_vector_dim; +use lance::index::NoopIndexBuildProgress; use lance_index::vector::ivf::builder::IvfBuildParams as RustIvfBuildParams; use lance_index::vector::pq::builder::PQBuildParams as RustPQBuildParams; use lance_linalg::distance::MetricType; @@ -114,6 +117,7 @@ fn inner_train_ivf_centroids<'local>( dim, metric_type, &ivf_params, + Arc::new(NoopIndexBuildProgress), ))?; let centroids = ivf_model diff --git a/python/src/indices.rs b/python/src/indices.rs index 412f329d83a..060d4e10fdd 100644 --- a/python/src/indices.rs +++ b/python/src/indices.rs @@ -2,6 +2,7 @@ // SPDX-FileCopyrightText: Copyright The Lance Authors use std::collections::HashSet; +use std::sync::Arc; use arrow::pyarrow::{PyArrowType, ToPyArrow}; use arrow_array::{Array, FixedSizeListArray}; @@ -10,6 +11,7 @@ use chrono::{DateTime, Utc}; use lance::dataset::Dataset as LanceDataset; use lance::index::vector::ivf::builder::write_vector_storage; use lance::io::ObjectStore; +use lance_index::progress::NoopIndexBuildProgress; use lance_index::vector::ivf::shuffler::{shuffle_vectors, IvfShuffler}; use lance_index::vector::{ ivf::{storage::IvfModel, IvfBuildParams}, @@ -141,6 +143,7 @@ async fn do_train_ivf_model( dimension, distance_type, ¶ms, + Arc::new(NoopIndexBuildProgress), ) .await .infer_error()?; diff --git a/rust/lance-index/src/lib.rs b/rust/lance-index/src/lib.rs index 
776619e5036..7506fd52e6b 100644 --- a/rust/lance-index/src/lib.rs +++ b/rust/lance-index/src/lib.rs @@ -26,6 +26,7 @@ pub mod mem_wal; pub mod metrics; pub mod optimize; pub mod prefilter; +pub mod progress; pub mod registry; pub mod scalar; pub mod traits; diff --git a/rust/lance-index/src/progress.rs b/rust/lance-index/src/progress.rs new file mode 100644 index 00000000000..4ac664c7623 --- /dev/null +++ b/rust/lance-index/src/progress.rs @@ -0,0 +1,52 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright The Lance Authors + +use async_trait::async_trait; +use lance_core::Result; +use std::sync::Arc; + +/// Progress callback for index building. +/// +/// Called at stage boundaries during index construction. Stages are sequential: +/// `stage_complete` is always called before the next `stage_start`, so only one +/// stage is active at a time. Stage names are index-type-specific (e.g. +/// "train_ivf", "shuffle", "build_partitions" for vector indices; "load_data", +/// "build_pages" for scalar indices). +/// +/// Methods take `&self` to allow concurrent calls from within a single stage. +/// Implementations must be thread-safe. +#[async_trait] +pub trait IndexBuildProgress: std::fmt::Debug + Sync + Send { + /// A named stage has started. + /// + /// `total` is the number of work units if known, and `unit` describes + /// what is being counted (e.g. "partitions", "batches", "rows"). + async fn stage_start(&self, stage: &str, total: Option<u64>, unit: &str) -> Result<()>; + + /// Progress within the current stage. + async fn stage_progress(&self, stage: &str, completed: u64) -> Result<()>; + + /// A named stage has completed. 
+ async fn stage_complete(&self, stage: &str) -> Result<()>; +} + +#[derive(Debug, Clone, Default)] +pub struct NoopIndexBuildProgress; + +#[async_trait] +impl IndexBuildProgress for NoopIndexBuildProgress { + async fn stage_start(&self, _: &str, _: Option<u64>, _: &str) -> Result<()> { + Ok(()) + } + async fn stage_progress(&self, _: &str, _: u64) -> Result<()> { + Ok(()) + } + async fn stage_complete(&self, _: &str) -> Result<()> { + Ok(()) + } +} + +/// Helper to create a default noop progress instance. +pub fn noop_progress() -> Arc<dyn IndexBuildProgress> { + Arc::new(NoopIndexBuildProgress) +} diff --git a/rust/lance-index/src/vector/kmeans.rs b/rust/lance-index/src/vector/kmeans.rs index be76fade6f6..58cfddd3dc3 100644 --- a/rust/lance-index/src/vector/kmeans.rs +++ b/rust/lance-index/src/vector/kmeans.rs @@ -56,7 +56,6 @@ pub enum KMeanInit { } /// KMean Training Parameters -#[derive(Debug)] pub struct KMeansParams { /// Max number of iterations. pub max_iters: u32, @@ -87,6 +86,24 @@ pub struct KMeansParams { /// Higher would split the clusters more aggressively, which would be more accurate but slower. /// hierarchical kmeans is enabled only if hierarchical_k > 1 and k > 256. pub hierarchical_k: usize, + + /// Optional sync callback for iteration progress: (current_iteration, max_iterations). 
+ pub on_progress: Option<Arc<dyn Fn(u32, u32) + Send + Sync>>, +} + +impl std::fmt::Debug for KMeansParams { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("KMeansParams") + .field("max_iters", &self.max_iters) + .field("tolerance", &self.tolerance) + .field("redos", &self.redos) + .field("init", &self.init) + .field("distance_type", &self.distance_type) + .field("balance_factor", &self.balance_factor) + .field("hierarchical_k", &self.hierarchical_k) + .field("on_progress", &self.on_progress.as_ref().map(|_| "...")) + .finish() + } } impl Default for KMeansParams { @@ -99,6 +116,7 @@ impl Default for KMeansParams { distance_type: DistanceType::L2, balance_factor: 0.0, hierarchical_k: 16, + on_progress: None, } } } @@ -133,6 +151,11 @@ impl KMeansParams { self } + pub fn with_on_progress(mut self, cb: Arc<dyn Fn(u32, u32) + Send + Sync>) -> Self { + self.on_progress = Some(cb); + self + } + /// Set the number of clusters to train in each hierarchical level. /// /// Higher would split the clusters more aggressively, which would be more accurate but slower. 
@@ -663,6 +686,9 @@ impl KMeans { let mut loss = f64::MAX; for i in 1..=params.max_iters { + if let Some(cb) = &params.on_progress { + cb(i, params.max_iters); + } if i % 10 == 0 { info!( "KMeans training: iteration {} / {}, redo={}", diff --git a/rust/lance-index/src/vector/v3/shuffler.rs b/rust/lance-index/src/vector/v3/shuffler.rs index 79147721a24..27583c86b4c 100644 --- a/rust/lance-index/src/vector/v3/shuffler.rs +++ b/rust/lance-index/src/vector/v3/shuffler.rs @@ -69,6 +69,7 @@ pub struct IvfShuffler { // options precomputed_shuffle_buffers: Option>, + progress: Arc<dyn crate::progress::IndexBuildProgress>, } impl IvfShuffler { @@ -78,9 +79,15 @@ impl IvfShuffler { output_dir, num_partitions, precomputed_shuffle_buffers: None, + progress: crate::progress::noop_progress(), } } + pub fn with_progress(mut self, progress: Arc<dyn crate::progress::IndexBuildProgress>) -> Self { + self.progress = progress; + self + } + pub fn with_precomputed_shuffle_buffers( mut self, precomputed_shuffle_buffers: Option>, @@ -159,6 +166,7 @@ impl Shuffler for IvfShuffler { .buffered(get_num_compute_intensive_cpus()); let mut total_loss = 0.0; + let mut counter: u64 = 0; while let Some(shuffled) = parallel_sort_stream.next().await { let (shuffled, loss) = shuffled?; total_loss += loss; @@ -172,6 +180,9 @@ } } try_join_all(futs).await?; + + counter += 1; + self.progress.stage_progress("shuffle", counter).await?; } // finish all writers diff --git a/rust/lance/src/index.rs b/rust/lance/src/index.rs index d74829d15f4..0f308e9c1d9 100644 --- a/rust/lance/src/index.rs +++ b/rust/lance/src/index.rs @@ -28,6 +28,7 @@ use lance_index::frag_reuse::{FragReuseIndex, FRAG_REUSE_INDEX_NAME}; use lance_index::mem_wal::{MemWalIndex, MEM_WAL_INDEX_NAME}; use lance_index::optimize::OptimizeOptions; use lance_index::pb::index::Implementation; +pub use lance_index::progress::{IndexBuildProgress, NoopIndexBuildProgress}; use lance_index::scalar::expression::{ IndexInformationProvider, MultiQueryParser, ScalarQueryParser, }; diff --git 
a/rust/lance/src/index/create.rs b/rust/lance/src/index/create.rs index d929dbffffd..90dba8a6fc5 100644 --- a/rust/lance/src/index/create.rs +++ b/rust/lance/src/index/create.rs @@ -18,6 +18,7 @@ use crate::{ }; use futures::future::BoxFuture; use lance_core::datatypes::format_field_path; +use lance_index::progress::{IndexBuildProgress, NoopIndexBuildProgress}; use lance_index::{ metrics::NoOpMetricsCollector, scalar::{inverted::tokenizer::InvertedIndexParams, ScalarIndexParams, LANCE_SCALAR_INDEX}, @@ -54,6 +55,7 @@ pub struct CreateIndexBuilder<'a> { fragments: Option>, index_uuid: Option, preprocessed_data: Option>, + progress: Arc<dyn IndexBuildProgress>, } impl<'a> CreateIndexBuilder<'a> { @@ -74,6 +76,7 @@ impl<'a> CreateIndexBuilder<'a> { fragments: None, index_uuid: None, preprocessed_data: None, + progress: Arc::new(NoopIndexBuildProgress), } } @@ -110,6 +113,11 @@ impl<'a> CreateIndexBuilder<'a> { self } + pub fn progress(mut self, p: Arc<dyn IndexBuildProgress>) -> Self { + self.progress = p; + self + } + #[instrument(skip_all)] pub async fn execute_uncommitted(&mut self) -> Result { if self.columns.len() != 1 { @@ -324,6 +332,7 @@ impl<'a> CreateIndexBuilder<'a> { vec_params, fri, self.fragments.as_ref().unwrap(), + self.progress.clone(), )) .await?; } else { @@ -335,6 +344,7 @@ impl<'a> CreateIndexBuilder<'a> { &index_id.to_string(), vec_params, fri, + self.progress.clone(), )) .await?; } diff --git a/rust/lance/src/index/vector.rs b/rust/lance/src/index/vector.rs index 1dd015f789b..22a79117924 100644 --- a/rust/lance/src/index/vector.rs +++ b/rust/lance/src/index/vector.rs @@ -23,6 +23,7 @@ use lance_file::previous::reader::FileReader as PreviousFileReader; use lance_index::frag_reuse::FragReuseIndex; use lance_index::metrics::NoOpMetricsCollector; use lance_index::optimize::OptimizeOptions; +use lance_index::progress::{noop_progress, IndexBuildProgress}; use lance_index::vector::bq::builder::RabitQuantizer; use lance_index::vector::bq::RQBuildParams; use 
lance_index::vector::flat::index::{FlatBinQuantizer, FlatIndex, FlatQuantizer}; @@ -298,6 +299,7 @@ impl IndexParams for VectorIndexParams { } /// Build a Distributed Vector Index for specific fragments +#[allow(clippy::too_many_arguments)] #[instrument(level = "debug", skip(dataset))] pub(crate) async fn build_distributed_vector_index( dataset: &Dataset, @@ -307,6 +309,7 @@ pub(crate) async fn build_distributed_vector_index( params: &VectorIndexParams, frag_reuse_index: Option>, fragment_ids: &[u32], + progress: Arc, ) -> Result<()> { let stages = ¶ms.stages; @@ -373,7 +376,7 @@ for concurrent distributed create_index" let temp_dir = TempStdDir::default(); let temp_dir_path = Path::from_filesystem_path(&temp_dir)?; - let shuffler = IvfShuffler::new(temp_dir_path, num_partitions); + let shuffler = IvfShuffler::new(temp_dir_path, num_partitions).with_progress(progress.clone()); let filtered_dataset = dataset.clone(); @@ -441,6 +444,7 @@ please provide PQBuildParams.codebook for distributed indexing" )? .with_ivf(ivf_model) .with_fragment_filter(fragment_filter) + .with_progress(progress.clone()) .build() .await?; } @@ -461,6 +465,7 @@ please provide PQBuildParams.codebook for distributed indexing" )? .with_ivf(ivf_model) .with_fragment_filter(fragment_filter) + .with_progress(progress.clone()) .build() .await?; } @@ -517,6 +522,7 @@ please provide PQBuildParams.codebook for distributed indexing" // and transpose only after all shards are merged. .with_transpose(false) .with_fragment_filter(fragment_filter) + .with_progress(progress.clone()) .build() .await?; } @@ -548,6 +554,7 @@ please provide PQBuildParams.codebook for distributed indexing" frag_reuse_index, )? .with_fragment_filter(fragment_filter) + .with_progress(progress.clone()) .build() .await?; } @@ -577,6 +584,7 @@ please provide PQBuildParams.codebook for distributed indexing" frag_reuse_index, )? 
.with_fragment_filter(fragment_filter) + .with_progress(progress.clone()) .build() .await?; } @@ -622,6 +630,7 @@ please provide PQBuildParams.codebook for distributed indexing" // and transpose only after all shards are merged. .with_transpose(false) .with_fragment_filter(fragment_filter) + .with_progress(progress.clone()) .build() .await?; } @@ -660,6 +669,7 @@ please provide PQBuildParams.codebook for distributed indexing" frag_reuse_index, )? .with_fragment_filter(fragment_filter) + .with_progress(progress.clone()) .build() .await?; } @@ -698,6 +708,7 @@ pub(crate) async fn build_vector_index( uuid: &str, params: &VectorIndexParams, frag_reuse_index: Option>, + progress: Arc, ) -> Result<()> { let stages = ¶ms.stages; @@ -741,7 +752,7 @@ pub(crate) async fn build_vector_index( let temp_dir = TempStdDir::default(); let temp_dir_path = Path::from_filesystem_path(&temp_dir)?; - let shuffler = IvfShuffler::new(temp_dir_path, num_partitions); + let shuffler = IvfShuffler::new(temp_dir_path, num_partitions).with_progress(progress.clone()); match index_type { IndexType::IvfFlat => match element_type { DataType::Float16 | DataType::Float32 | DataType::Float64 => { @@ -756,6 +767,7 @@ pub(crate) async fn build_vector_index( (), frag_reuse_index, )? + .with_progress(progress.clone()) .build() .await?; } @@ -771,6 +783,7 @@ pub(crate) async fn build_vector_index( (), frag_reuse_index, )? + .with_progress(progress.clone()) .build() .await?; } @@ -800,6 +813,7 @@ pub(crate) async fn build_vector_index( params.metric_type, &ivf_params, pq_params, + progress.clone(), ) .await?; } @@ -815,6 +829,7 @@ pub(crate) async fn build_vector_index( (), frag_reuse_index, )? + .with_progress(progress.clone()) .build() .await?; } @@ -839,6 +854,7 @@ pub(crate) async fn build_vector_index( (), frag_reuse_index, )? + .with_progress(progress.clone()) .build() .await?; } @@ -861,6 +877,7 @@ pub(crate) async fn build_vector_index( (), frag_reuse_index, )? 
+ .with_progress(progress.clone()) .build() .await?; } @@ -882,6 +899,7 @@ pub(crate) async fn build_vector_index( hnsw_params.clone(), frag_reuse_index, )? + .with_progress(progress.clone()) .build() .await?; } @@ -909,6 +927,7 @@ pub(crate) async fn build_vector_index( hnsw_params.clone(), frag_reuse_index, )? + .with_progress(progress.clone()) .build() .await?; } @@ -936,6 +955,7 @@ pub(crate) async fn build_vector_index( hnsw_params.clone(), frag_reuse_index, )? + .with_progress(progress.clone()) .build() .await?; } @@ -959,6 +979,7 @@ pub(crate) async fn build_vector_index_incremental( params: &VectorIndexParams, existing_index: Arc, frag_reuse_index: Option>, + progress: Arc, ) -> Result<()> { let stages = ¶ms.stages; @@ -1008,7 +1029,9 @@ pub(crate) async fn build_vector_index_incremental( let temp_dir = TempStdDir::default(); let temp_dir_path = Path::from_filesystem_path(&temp_dir)?; - let shuffler = Box::new(IvfShuffler::new(temp_dir_path, ivf_model.num_partitions())); + let shuffler = Box::new( + IvfShuffler::new(temp_dir_path, ivf_model.num_partitions()).with_progress(progress.clone()), + ); let index_dir = dataset.indices_dir().child(uuid); @@ -1031,6 +1054,7 @@ pub(crate) async fn build_vector_index_incremental( )? .with_ivf(ivf_model) .with_quantizer(quantizer.try_into()?) + .with_progress(progress.clone()) .build() .await?; } @@ -1047,6 +1071,7 @@ pub(crate) async fn build_vector_index_incremental( )? .with_ivf(ivf_model) .with_quantizer(quantizer.try_into()?) + .with_progress(progress.clone()) .build() .await?; } @@ -1071,6 +1096,7 @@ pub(crate) async fn build_vector_index_incremental( )? .with_ivf(ivf_model) .with_quantizer(quantizer.try_into()?) + .with_progress(progress.clone()) .build() .await?; } @@ -1088,6 +1114,7 @@ pub(crate) async fn build_vector_index_incremental( )? .with_ivf(ivf_model) .with_quantizer(quantizer.try_into()?) 
+ .with_progress(progress.clone()) .build() .await?; } @@ -1105,6 +1132,7 @@ pub(crate) async fn build_vector_index_incremental( )? .with_ivf(ivf_model) .with_quantizer(quantizer.try_into()?) + .with_progress(progress.clone()) .build() .await?; } @@ -1134,6 +1162,7 @@ pub(crate) async fn build_vector_index_incremental( )? .with_ivf(ivf_model) .with_quantizer(quantizer.try_into()?) + .with_progress(progress.clone()) .build() .await?; } @@ -1150,6 +1179,7 @@ pub(crate) async fn build_vector_index_incremental( )? .with_ivf(ivf_model) .with_quantizer(quantizer.try_into()?) + .with_progress(progress.clone()) .build() .await?; } @@ -1166,6 +1196,7 @@ pub(crate) async fn build_vector_index_incremental( )? .with_ivf(ivf_model) .with_quantizer(quantizer.try_into()?) + .with_progress(progress.clone()) .build() .await?; } @@ -1542,6 +1573,7 @@ pub async fn initialize_vector_index( ¶ms, source_vector_index, frag_reuse_index, + noop_progress(), ) .await?; @@ -2143,9 +2175,16 @@ mod tests { ..Default::default() }; let dim = utils::get_vector_dim(dataset.schema(), "vector").unwrap(); - let ivf_model = build_ivf_model(&dataset, "vector", dim, MetricType::L2, &ivf_params) - .await - .unwrap(); + let ivf_model = build_ivf_model( + &dataset, + "vector", + dim, + MetricType::L2, + &ivf_params, + noop_progress(), + ) + .await + .unwrap(); // Attach precomputed global centroids to ivf_params for distributed build. 
ivf_params.centroids = ivf_model.centroids.clone().map(Arc::new); @@ -2160,6 +2199,7 @@ mod tests { &params, None, &[invalid_id], + noop_progress(), ) .await; @@ -2187,9 +2227,16 @@ mod tests { ..Default::default() }; let dim = utils::get_vector_dim(dataset.schema(), "vector").unwrap(); - let ivf_model = build_ivf_model(&dataset, "vector", dim, MetricType::L2, &ivf_params) - .await - .unwrap(); + let ivf_model = build_ivf_model( + &dataset, + "vector", + dim, + MetricType::L2, + &ivf_params, + noop_progress(), + ) + .await + .unwrap(); // Attach precomputed global centroids to ivf_params for distributed build. ivf_params.centroids = ivf_model.centroids.clone().map(Arc::new); @@ -2204,6 +2251,7 @@ mod tests { &params, None, &[], + noop_progress(), ) .await; @@ -2214,6 +2262,86 @@ mod tests { ); } + #[tokio::test] + async fn test_train_ivf_progress_is_emitted_before_completion() { + use std::sync::atomic::{AtomicBool, Ordering}; + + #[derive(Debug)] + struct RecordingProgress { + train_ivf_complete: AtomicBool, + saw_train_ivf_progress_before_complete: AtomicBool, + saw_train_ivf_progress_after_complete: AtomicBool, + } + + #[async_trait::async_trait] + impl IndexBuildProgress for RecordingProgress { + async fn stage_start(&self, _: &str, _: Option<u64>, _: &str) -> Result<()> { + Ok(()) + } + + async fn stage_progress(&self, stage: &str, _: u64) -> Result<()> { + if stage == "train_ivf" { + if self.train_ivf_complete.load(Ordering::Relaxed) { + self.saw_train_ivf_progress_after_complete + .store(true, Ordering::Relaxed); + } else { + self.saw_train_ivf_progress_before_complete + .store(true, Ordering::Relaxed); + } + } + Ok(()) + } + + async fn stage_complete(&self, stage: &str) -> Result<()> { + if stage == "train_ivf" { + self.train_ivf_complete.store(true, Ordering::Relaxed); + } + Ok(()) + } + } + + let test_dir = TempStrDir::default(); + let uri = format!("{}/ds", test_dir.as_str()); + let reader = lance_datagen::gen_batch() + .col("id", array::step::<Int32Type>()) + .col("vector", 
array::rand_vec::(32.into())) + .into_reader_rows(RowCount::from(128), BatchCount::from(1)); + let dataset = Dataset::write(reader, &uri, None).await.unwrap(); + + let params = VectorIndexParams::ivf_flat(4, MetricType::L2); + let uuid = Uuid::new_v4().to_string(); + let progress = Arc::new(RecordingProgress { + train_ivf_complete: AtomicBool::new(false), + saw_train_ivf_progress_before_complete: AtomicBool::new(false), + saw_train_ivf_progress_after_complete: AtomicBool::new(false), + }); + + build_vector_index( + &dataset, + "vector", + "vector_ivf_flat_progress", + &uuid, + ¶ms, + None, + progress.clone(), + ) + .await + .unwrap(); + + assert!( + progress + .saw_train_ivf_progress_before_complete + .load(Ordering::Relaxed), + "expected at least one train_ivf progress event before completion" + ); + assert!( + !progress + .saw_train_ivf_progress_after_complete + .load(Ordering::Relaxed), + "found train_ivf progress after completion" + ); + } + #[tokio::test] async fn test_build_distributed_training_metadata_missing() { let test_dir = TempStrDir::default(); @@ -2260,6 +2388,7 @@ mod tests { ¶ms, None, &[valid_id], + noop_progress(), ) .await; diff --git a/rust/lance/src/index/vector/builder.rs b/rust/lance/src/index/vector/builder.rs index 59300550608..da6e11eaa53 100644 --- a/rust/lance/src/index/vector/builder.rs +++ b/rust/lance/src/index/vector/builder.rs @@ -32,6 +32,7 @@ use lance_file::writer::FileWriter; use lance_index::frag_reuse::FragReuseIndex; use lance_index::metrics::NoOpMetricsCollector; use lance_index::optimize::OptimizeOptions; +use lance_index::progress::{IndexBuildProgress, NoopIndexBuildProgress}; use lance_index::vector::bq::storage::{unpack_codes, RABIT_CODE_COLUMN}; use lance_index::vector::kmeans::KMeansParams; use lance_index::vector::pq::storage::transpose; @@ -150,6 +151,8 @@ pub struct IvfIndexBuilder { merged_num: usize, // whether to transpose codes when building storage transpose_codes: bool, + + progress: Arc, } type BuildStream = 
@@ -192,6 +195,7 @@ impl IvfIndexBuilder optimize_options: None, merged_num: 0, transpose_codes: true, + progress: Arc::new(NoopIndexBuildProgress), }) } @@ -259,27 +263,44 @@ impl IvfIndexBuilder optimize_options: None, merged_num: 0, transpose_codes: true, + progress: Arc::new(NoopIndexBuildProgress), }) } // build the index with the all data in the dataset, // return the number of indices merged pub async fn build(&mut self) -> Result { + let progress = self.progress.clone(); + // step 1. train IVF & quantizer + let max_iters = self.ivf_params.as_ref().map(|p| p.max_iters as u64); + progress + .stage_start("train_ivf", max_iters, "iterations") + .await?; self.with_ivf(self.load_or_build_ivf().await?); + progress.stage_complete("train_ivf").await?; + progress.stage_start("train_quantizer", None, "").await?; self.with_quantizer(self.load_or_build_quantizer().await?); + progress.stage_complete("train_quantizer").await?; // step 2. shuffle the dataset if self.shuffle_reader.is_none() { + progress.stage_start("shuffle", None, "batches").await?; self.shuffle_dataset().await?; + progress.stage_complete("shuffle").await?; } // step 3. build partitions + let num_partitions = self.ivf.as_ref().map(|ivf| ivf.num_partitions() as u64); + progress + .stage_start("build_partitions", num_partitions, "partitions") + .await?; let build_idx_stream = self.build_partitions().boxed().await?; // step 4. 
merge all partitions self.merge_partitions(build_idx_stream).await?; + progress.stage_complete("build_partitions").await?; Ok(self.merged_num) } @@ -365,6 +386,12 @@ impl IvfIndexBuilder self } + /// Set progress callback for index building + pub fn with_progress(&mut self, progress: Arc) -> &mut Self { + self.progress = progress; + self + } + #[instrument(name = "load_or_build_ivf", level = "debug", skip_all)] async fn load_or_build_ivf(&self) -> Result { match &self.ivf { @@ -381,8 +408,15 @@ impl IvfIndexBuilder "IVF build params not set", location!(), ))?; - super::build_ivf_model(dataset, &self.column, dim, self.distance_type, ivf_params) - .await + super::build_ivf_model( + dataset, + &self.column, + dim, + self.distance_type, + ivf_params, + self.progress.clone(), + ) + .await } } } @@ -1049,9 +1083,11 @@ impl IvfIndexBuilder let mut part_id = 0; let mut total_loss = 0.0; + let progress = self.progress.clone(); log::info!("merging {} partitions", ivf.num_partitions()); while let Some(part) = build_stream.try_next().await? 
{ part_id += 1; + progress.stage_progress("build_partitions", part_id).await?; let Some((storage, index, loss)) = part else { log::warn!("partition {} is empty, skipping", part_id); diff --git a/rust/lance/src/index/vector/ivf.rs b/rust/lance/src/index/vector/ivf.rs index c14cdeada81..7d6280c468c 100644 --- a/rust/lance/src/index/vector/ivf.rs +++ b/rust/lance/src/index/vector/ivf.rs @@ -97,6 +97,7 @@ use serde_json::json; use snafu::location; use std::collections::HashSet; use std::{any::Any, collections::HashMap, sync::Arc}; +use tokio::sync::mpsc; use tracing::instrument; use uuid::Uuid; @@ -1228,6 +1229,7 @@ pub async fn build_ivf_model( dim: usize, metric_type: MetricType, params: &IvfBuildParams, + progress: std::sync::Arc, ) -> Result { let num_partitions = params.num_partitions.unwrap(); let centroids = params.centroids.clone(); @@ -1276,7 +1278,7 @@ pub async fn build_ivf_model( info!("Start to train IVF model"); let start = std::time::Instant::now(); - let ivf = train_ivf_model(centroids, training_data, mt, params).await?; + let ivf = train_ivf_model(centroids, training_data, mt, params, progress).await?; info!( "Trained IVF model in {:02} seconds", start.elapsed().as_secs_f32() @@ -1290,6 +1292,7 @@ async fn build_ivf_model_and_pq( metric_type: MetricType, ivf_params: &IvfBuildParams, pq_params: &PQBuildParams, + progress: std::sync::Arc, ) -> Result<(IvfModel, ProductQuantizer)> { sanity_check_params(ivf_params, pq_params)?; @@ -1306,7 +1309,8 @@ async fn build_ivf_model_and_pq( get_vector_type(dataset.schema(), column)?; let dim = get_vector_dim(dataset.schema(), column)?; - let ivf_model = build_ivf_model(dataset, column, dim, metric_type, ivf_params).await?; + let ivf_model = + build_ivf_model(dataset, column, dim, metric_type, ivf_params, progress).await?; let ivf_residual = if matches!(metric_type, MetricType::Cosine | MetricType::L2) { Some(&ivf_model) @@ -1349,6 +1353,7 @@ pub async fn load_precomputed_partitions_if_available( } } 
+#[allow(clippy::too_many_arguments)] pub async fn build_ivf_pq_index( dataset: &Dataset, column: &str, @@ -1357,9 +1362,17 @@ pub async fn build_ivf_pq_index( metric_type: MetricType, ivf_params: &IvfBuildParams, pq_params: &PQBuildParams, + progress: std::sync::Arc, ) -> Result<()> { - let (ivf_model, pq) = - build_ivf_model_and_pq(dataset, column, metric_type, ivf_params, pq_params).await?; + let (ivf_model, pq) = build_ivf_model_and_pq( + dataset, + column, + metric_type, + ivf_params, + pq_params, + progress, + ) + .await?; let stream = scan_index_field_stream(dataset, column).await?; let precomputed_partitions = load_precomputed_partitions_if_available(ivf_params).await?; @@ -1393,8 +1406,15 @@ pub async fn build_ivf_hnsw_pq_index( hnsw_params: &HnswBuildParams, pq_params: &PQBuildParams, ) -> Result<()> { - let (ivf_model, pq) = - build_ivf_model_and_pq(dataset, column, metric_type, ivf_params, pq_params).await?; + let (ivf_model, pq) = build_ivf_model_and_pq( + dataset, + column, + metric_type, + ivf_params, + pq_params, + lance_index::progress::noop_progress(), + ) + .await?; let stream = scan_index_field_stream(dataset, column).await?; let precomputed_partitions = load_precomputed_partitions_if_available(ivf_params).await?; @@ -2099,21 +2119,51 @@ async fn do_train_ivf_model( dimension: usize, metric_type: MetricType, params: &IvfBuildParams, + progress: std::sync::Arc, ) -> Result where ::Native: Dot + L2 + Normalize, PrimitiveArray: From>, { const REDOS: usize = 1; + let (progress_tx, mut progress_rx) = mpsc::unbounded_channel::(); + let progress_worker = { + let progress = progress.clone(); + tokio::spawn(async move { + while let Some(iter) = progress_rx.recv().await { + if let Err(e) = progress.stage_progress("train_ivf", iter).await { + warn!("Progress callback error during train_ivf: {e}"); + } + } + }) + }; + + let on_progress: Arc = { + let progress_tx = progress_tx.clone(); + let cumulative_iters = std::sync::atomic::AtomicU64::new(0); + 
Arc::new(move |_iter: u32, _max_iters: u32| { + // Track cumulative iterations across all kmeans runs in this stage + // (flat and hierarchical both invoke the callback per-iteration). + let total = cumulative_iters.fetch_add(1, std::sync::atomic::Ordering::Relaxed) + 1; + // Non-blocking send from sync kmeans loop into async progress worker. + let _ = progress_tx.send(total); + }) + }; let kmeans_params = KMeansParams::new(centroids, params.max_iters as u32, REDOS, metric_type) - .with_balance_factor(1.0); + .with_balance_factor(1.0) + .with_on_progress(on_progress); let kmeans = lance_index::vector::kmeans::train_kmeans::( data, kmeans_params, dimension, params.num_partitions.unwrap_or(32), params.sample_rate, - )?; + ); + drop(progress_tx); + if let Err(e) = progress_worker.await { + warn!("Progress worker join error during train_ivf: {e}"); + } + let kmeans = kmeans?; Ok(IvfModel::new( FixedSizeListArray::try_new_from_values(kmeans.centroids, dimension as i32)?, Some(kmeans.loss), @@ -2126,6 +2176,7 @@ async fn train_ivf_model( data: &FixedSizeListArray, distance_type: DistanceType, params: &IvfBuildParams, + progress: std::sync::Arc, ) -> Result { assert!( distance_type != DistanceType::Cosine, @@ -2141,6 +2192,7 @@ async fn train_ivf_model( dim, distance_type, params, + progress.clone(), ) .await } @@ -2151,6 +2203,7 @@ async fn train_ivf_model( dim, distance_type, params, + progress.clone(), ) .await } @@ -2161,6 +2214,7 @@ async fn train_ivf_model( dim, distance_type, params, + progress.clone(), ) .await } @@ -2175,6 +2229,7 @@ async fn train_ivf_model( dim, distance_type, params, + progress.clone(), ) .await } @@ -2185,6 +2240,7 @@ async fn train_ivf_model( dim, distance_type, params, + progress.clone(), ) .await } @@ -2609,6 +2665,7 @@ mod tests { MetricType::L2, &ivf_params, &pq_params, + lance_index::progress::noop_progress(), ) .await .unwrap(); @@ -3059,9 +3116,16 @@ mod tests { let (dataset, _) = generate_test_dataset(test_uri, 1000.0..1100.0).await; 
let ivf_params = IvfBuildParams::new(2); - let ivf_model = build_ivf_model(&dataset, "vector", DIM, MetricType::L2, &ivf_params) - .await - .unwrap(); + let ivf_model = build_ivf_model( + &dataset, + "vector", + DIM, + MetricType::L2, + &ivf_params, + lance_index::progress::noop_progress(), + ) + .await + .unwrap(); assert_eq!(2, ivf_model.centroids.as_ref().unwrap().len()); assert_eq!(32, ivf_model.centroids.as_ref().unwrap().value_length()); assert_eq!(2, ivf_model.num_partitions()); @@ -3087,9 +3151,16 @@ let (dataset, _) = generate_test_dataset(test_uri, 1000.0..1100.0).await; let ivf_params = IvfBuildParams::new(2); - let ivf_model = build_ivf_model(&dataset, "vector", DIM, MetricType::Cosine, &ivf_params) - .await - .unwrap(); + let ivf_model = build_ivf_model( + &dataset, + "vector", + DIM, + MetricType::Cosine, + &ivf_params, + lance_index::progress::noop_progress(), + ) + .await + .unwrap(); assert_eq!(2, ivf_model.centroids.as_ref().unwrap().len()); assert_eq!(32, ivf_model.centroids.as_ref().unwrap().value_length()); assert_eq!(2, ivf_model.num_partitions()); @@ -3646,6 +3717,76 @@ mod tests { assert!(object_store.exists(&keep_root_file).await.unwrap()); } + #[tokio::test(flavor = "multi_thread")] + async fn test_build_ivf_model_progress_callback() { + use lance_index::progress::IndexBuildProgress; + use tokio::sync::Mutex; + + #[derive(Debug)] + struct RecordingProgress { + calls: Arc<Mutex<Vec<(String, u64)>>>, + } + + #[async_trait::async_trait] + impl IndexBuildProgress for RecordingProgress { + async fn stage_start(&self, _: &str, _: Option<u64>, _: &str) -> Result<()> { + Ok(()) + } + async fn stage_progress(&self, stage: &str, completed: u64) -> Result<()> { + self.calls.lock().await.push((stage.to_string(), completed)); + Ok(()) + } + async fn stage_complete(&self, _: &str) -> Result<()> { + Ok(()) + } + } + + let test_dir = TempStrDir::default(); + let test_uri = test_dir.as_str(); + + let (dataset, _) = generate_test_dataset(test_uri, 1000.0..1100.0).await; + 
let ivf_params = IvfBuildParams::new(2); + let calls: Arc<Mutex<Vec<(String, u64)>>> = Arc::new(Mutex::new(Vec::new())); + let progress: Arc<dyn IndexBuildProgress> = Arc::new(RecordingProgress { + calls: calls.clone(), + }); + + let ivf_model = build_ivf_model( + &dataset, + "vector", + DIM, + MetricType::L2, + &ivf_params, + progress, + ) + .await + .unwrap(); + assert_eq!(2, ivf_model.num_partitions()); + + // Let spawned progress tasks complete. + tokio::task::yield_now().await; + + let recorded = calls.lock().await; + assert!( + !recorded.is_empty(), + "Expected progress callbacks to be called" + ); + // All calls should be for train_ivf stage + for (stage, _) in recorded.iter() { + assert_eq!(stage, "train_ivf"); + } + // Completed values should be monotonically increasing + for window in recorded.windows(2) { + assert!( + window[1].1 >= window[0].1, + "Expected monotonically increasing progress: {} >= {}", + window[1].1, + window[0].1, + ); + } + } + #[tokio::test] async fn test_cleanup_idempotent() { let object_store = ObjectStore::memory(); diff --git a/rust/lance/src/index/vector/pq.rs b/rust/lance/src/index/vector/pq.rs index 1b7a6b43076..08631706951 100644 --- a/rust/lance/src/index/vector/pq.rs +++ b/rust/lance/src/index/vector/pq.rs @@ -713,9 +713,16 @@ mod tests { let (dataset, vectors) = generate_dataset(test_uri, 100.0..120.0).await; let ivf_params = IvfBuildParams::new(4); - let ivf = build_ivf_model(&dataset, "vector", DIM, MetricType::Cosine, &ivf_params) - .await - .unwrap(); + let ivf = build_ivf_model( + &dataset, + "vector", + DIM, + MetricType::Cosine, + &ivf_params, + lance_index::progress::noop_progress(), + ) + .await + .unwrap(); let params = PQBuildParams::new(16, 8); let pq = build_pq_model( &dataset,