Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions java/lance-jni/src/vector_trainer.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright The Lance Authors

use std::sync::Arc;

use crate::blocking_dataset::{BlockingDataset, NATIVE_DATASET};
use crate::error::{Error, Result};
use crate::ffi::JNIEnvExt;
Expand All @@ -11,6 +13,7 @@ use jni::objects::{JClass, JFloatArray, JObject, JString};
use jni::sys::jfloatArray;
use jni::JNIEnv;
use lance::index::vector::utils::get_vector_dim;
use lance::index::NoopIndexBuildProgress;
use lance_index::vector::ivf::builder::IvfBuildParams as RustIvfBuildParams;
use lance_index::vector::pq::builder::PQBuildParams as RustPQBuildParams;
use lance_linalg::distance::MetricType;
Expand Down Expand Up @@ -114,6 +117,7 @@ fn inner_train_ivf_centroids<'local>(
dim,
metric_type,
&ivf_params,
Arc::new(NoopIndexBuildProgress),
))?;

let centroids = ivf_model
Expand Down
3 changes: 3 additions & 0 deletions python/src/indices.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
// SPDX-FileCopyrightText: Copyright The Lance Authors

use std::collections::HashSet;
use std::sync::Arc;

use arrow::pyarrow::{PyArrowType, ToPyArrow};
use arrow_array::{Array, FixedSizeListArray};
Expand All @@ -10,6 +11,7 @@ use chrono::{DateTime, Utc};
use lance::dataset::Dataset as LanceDataset;
use lance::index::vector::ivf::builder::write_vector_storage;
use lance::io::ObjectStore;
use lance_index::progress::NoopIndexBuildProgress;
use lance_index::vector::ivf::shuffler::{shuffle_vectors, IvfShuffler};
use lance_index::vector::{
ivf::{storage::IvfModel, IvfBuildParams},
Expand Down Expand Up @@ -141,6 +143,7 @@ async fn do_train_ivf_model(
dimension,
distance_type,
&params,
Arc::new(NoopIndexBuildProgress),
)
.await
.infer_error()?;
Expand Down
1 change: 1 addition & 0 deletions rust/lance-index/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ pub mod mem_wal;
pub mod metrics;
pub mod optimize;
pub mod prefilter;
pub mod progress;
pub mod registry;
pub mod scalar;
pub mod traits;
Expand Down
52 changes: 52 additions & 0 deletions rust/lance-index/src/progress.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright The Lance Authors

use async_trait::async_trait;
use lance_core::Result;
use std::sync::Arc;

/// Progress callback for index building.
///
/// Called at stage boundaries during index construction. Stages are sequential:
/// `stage_complete` is always called before the next `stage_start`, so only one
/// stage is active at a time. Stage names are index-type-specific (e.g.
/// "train_ivf", "shuffle", "build_partitions" for vector indices; "load_data",
/// "build_pages" for scalar indices).
///
/// Methods take `&self` to allow concurrent calls from within a single stage.
/// Implementations must be thread-safe.
///
/// # Errors
///
/// Every method returns a [`Result`], so an implementation can report a
/// failure of its own. What happens to an `Err` is decided by the call site.
#[async_trait]
pub trait IndexBuildProgress: std::fmt::Debug + Sync + Send {
/// A named stage has started.
///
/// `stage` is the name of the stage. `total` is the number of work units if
/// known (`None` otherwise), and `unit` describes what is being counted
/// (e.g. "partitions", "batches", "rows").
async fn stage_start(&self, stage: &str, total: Option<u64>, unit: &str) -> Result<()>;

/// Progress within the current stage.
///
/// `stage` is the name of the stage being reported on. `completed` is
/// presumably the cumulative number of finished work units, counted in the
/// `unit` passed to `stage_start` — NOTE(review): confirm against call sites.
async fn stage_progress(&self, stage: &str, completed: u64) -> Result<()>;

/// A named stage has completed.
///
/// `stage` is the name of the stage that finished.
async fn stage_complete(&self, stage: &str) -> Result<()>;
}

/// An [`IndexBuildProgress`] implementation that ignores every notification.
///
/// Each callback returns `Ok(())` immediately and does nothing else, so it can
/// be passed wherever a progress sink is required but no reporting is wanted.
#[derive(Debug, Clone, Default)]
pub struct NoopIndexBuildProgress;

#[async_trait]
impl IndexBuildProgress for NoopIndexBuildProgress {
    /// Ignores the stage-start notification.
    async fn stage_start(&self, _stage: &str, _total: Option<u64>, _unit: &str) -> Result<()> {
        Ok(())
    }

    /// Ignores the in-stage progress notification.
    async fn stage_progress(&self, _stage: &str, _completed: u64) -> Result<()> {
        Ok(())
    }

    /// Ignores the stage-complete notification.
    async fn stage_complete(&self, _stage: &str) -> Result<()> {
        Ok(())
    }
}

/// Returns a freshly allocated [`NoopIndexBuildProgress`] as a shared
/// `dyn IndexBuildProgress` trait object.
///
/// Convenient as the default value for anything that stores an
/// `Arc<dyn IndexBuildProgress>`.
pub fn noop_progress() -> Arc<dyn IndexBuildProgress> {
    Arc::new(NoopIndexBuildProgress) as Arc<dyn IndexBuildProgress>
}
28 changes: 27 additions & 1 deletion rust/lance-index/src/vector/kmeans.rs
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,6 @@ pub enum KMeanInit {
}

/// KMeans training parameters.
#[derive(Debug)]
pub struct KMeansParams {
/// Max number of iterations.
pub max_iters: u32,
Expand Down Expand Up @@ -87,6 +86,24 @@ pub struct KMeansParams {
/// A higher value splits the clusters more aggressively, which is more accurate but slower.
/// Hierarchical kmeans is enabled only if `hierarchical_k > 1` and `k > 256`.
pub hierarchical_k: usize,

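/// Optional synchronous callback for iteration progress, invoked at the start of each
/// training iteration with `(current_iteration, max_iterations)`; `current_iteration`
/// starts at 1.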
pub on_progress: Option<Arc<dyn Fn(u32, u32) + Send + Sync>>,
}

/// Hand-written `Debug` output for [`KMeansParams`].
///
/// The `on_progress` callback is a `dyn Fn` trait object, which has no `Debug`
/// implementation, so only its presence is shown.
impl std::fmt::Debug for KMeansParams {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Rendered as `Some("...")` when a callback is set and `None` otherwise.
        let callback_marker = self.on_progress.as_ref().map(|_| "...");

        let mut out = f.debug_struct("KMeansParams");
        out.field("max_iters", &self.max_iters);
        out.field("tolerance", &self.tolerance);
        out.field("redos", &self.redos);
        out.field("init", &self.init);
        out.field("distance_type", &self.distance_type);
        out.field("balance_factor", &self.balance_factor);
        out.field("hierarchical_k", &self.hierarchical_k);
        out.field("on_progress", &callback_marker);
        out.finish()
    }
}

impl Default for KMeansParams {
Expand All @@ -99,6 +116,7 @@ impl Default for KMeansParams {
distance_type: DistanceType::L2,
balance_factor: 0.0,
hierarchical_k: 16,
on_progress: None,
}
}
}
Expand Down Expand Up @@ -133,6 +151,11 @@ impl KMeansParams {
self
}

pub fn with_on_progress(mut self, cb: Arc<dyn Fn(u32, u32) + Send + Sync>) -> Self {
self.on_progress = Some(cb);
self
}

/// Set the number of clusters to train in each hierarchical level.
///
/// Higher would split the clusters more aggressively, which would be more accurate but slower.
Expand Down Expand Up @@ -663,6 +686,9 @@ impl KMeans {

let mut loss = f64::MAX;
for i in 1..=params.max_iters {
if let Some(cb) = &params.on_progress {
cb(i, params.max_iters);
}
if i % 10 == 0 {
info!(
"KMeans training: iteration {} / {}, redo={}",
Expand Down
11 changes: 11 additions & 0 deletions rust/lance-index/src/vector/v3/shuffler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ pub struct IvfShuffler {

// options
precomputed_shuffle_buffers: Option<Vec<String>>,
progress: Arc<dyn crate::progress::IndexBuildProgress>,
}

impl IvfShuffler {
Expand All @@ -78,9 +79,15 @@ impl IvfShuffler {
output_dir,
num_partitions,
precomputed_shuffle_buffers: None,
progress: crate::progress::noop_progress(),
}
}

/// Attach the progress reporter this shuffler notifies while shuffling.
///
/// Replaces the no-op reporter installed by the constructor.
pub fn with_progress(self, progress: Arc<dyn crate::progress::IndexBuildProgress>) -> Self {
    let mut shuffler = self;
    shuffler.progress = progress;
    shuffler
}

pub fn with_precomputed_shuffle_buffers(
mut self,
precomputed_shuffle_buffers: Option<Vec<String>>,
Expand Down Expand Up @@ -159,6 +166,7 @@ impl Shuffler for IvfShuffler {
.buffered(get_num_compute_intensive_cpus());

let mut total_loss = 0.0;
let mut counter: u64 = 0;
while let Some(shuffled) = parallel_sort_stream.next().await {
let (shuffled, loss) = shuffled?;
total_loss += loss;
Expand All @@ -172,6 +180,9 @@ impl Shuffler for IvfShuffler {
}
}
try_join_all(futs).await?;

counter += 1;
self.progress.stage_progress("shuffle", counter).await?;
}

// finish all writers
Expand Down
1 change: 1 addition & 0 deletions rust/lance/src/index.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ use lance_index::frag_reuse::{FragReuseIndex, FRAG_REUSE_INDEX_NAME};
use lance_index::mem_wal::{MemWalIndex, MEM_WAL_INDEX_NAME};
use lance_index::optimize::OptimizeOptions;
use lance_index::pb::index::Implementation;
pub use lance_index::progress::{IndexBuildProgress, NoopIndexBuildProgress};
use lance_index::scalar::expression::{
IndexInformationProvider, MultiQueryParser, ScalarQueryParser,
};
Expand Down
10 changes: 10 additions & 0 deletions rust/lance/src/index/create.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ use crate::{
};
use futures::future::BoxFuture;
use lance_core::datatypes::format_field_path;
use lance_index::progress::{IndexBuildProgress, NoopIndexBuildProgress};
use lance_index::{
metrics::NoOpMetricsCollector,
scalar::{inverted::tokenizer::InvertedIndexParams, ScalarIndexParams, LANCE_SCALAR_INDEX},
Expand Down Expand Up @@ -54,6 +55,7 @@ pub struct CreateIndexBuilder<'a> {
fragments: Option<Vec<u32>>,
index_uuid: Option<String>,
preprocessed_data: Option<Box<dyn RecordBatchReader + Send + 'static>>,
progress: Arc<dyn IndexBuildProgress>,
}

impl<'a> CreateIndexBuilder<'a> {
Expand All @@ -74,6 +76,7 @@ impl<'a> CreateIndexBuilder<'a> {
fragments: None,
index_uuid: None,
preprocessed_data: None,
progress: Arc::new(NoopIndexBuildProgress),
}
}

Expand Down Expand Up @@ -110,6 +113,11 @@ impl<'a> CreateIndexBuilder<'a> {
self
}

pub fn progress(mut self, p: Arc<dyn IndexBuildProgress>) -> Self {
self.progress = p;
self
}

#[instrument(skip_all)]
pub async fn execute_uncommitted(&mut self) -> Result<IndexMetadata> {
if self.columns.len() != 1 {
Expand Down Expand Up @@ -324,6 +332,7 @@ impl<'a> CreateIndexBuilder<'a> {
vec_params,
fri,
self.fragments.as_ref().unwrap(),
self.progress.clone(),
))
.await?;
} else {
Expand All @@ -335,6 +344,7 @@ impl<'a> CreateIndexBuilder<'a> {
&index_id.to_string(),
vec_params,
fri,
self.progress.clone(),
))
.await?;
}
Expand Down
Loading