From 0df4c8ec82eabee34b8ca29ee382749a17e29fc4 Mon Sep 17 00:00:00 2001 From: BubbleCal Date: Sun, 8 Mar 2026 16:02:59 +0800 Subject: [PATCH] Refactor builder worker cleanup --- .../src/scalar/inverted/builder.rs | 82 ++++++++++++++----- 1 file changed, 61 insertions(+), 21 deletions(-) diff --git a/rust/lance-index/src/scalar/inverted/builder.rs b/rust/lance-index/src/scalar/inverted/builder.rs index 2b352c46bc4..b52a656628e 100644 --- a/rust/lance-index/src/scalar/inverted/builder.rs +++ b/rust/lance-index/src/scalar/inverted/builder.rs @@ -182,7 +182,7 @@ impl InvertedIndexBuilder { let fragment_mask = self.fragment_mask; let token_set_format = self.token_set_format; let tokenized_count = tokenized_count.clone(); - let task = tokio::task::spawn(async move { + index_tasks.push(tokio::task::spawn(async move { let mut worker = IndexWorker::new( store, tokenizer, @@ -202,32 +202,29 @@ impl InvertedIndexBuilder { .stage_progress("tokenize_docs", tokenized_count) .await?; } - let partitions = worker.finish().await?; - Result::Ok(partitions) - }); - index_tasks.push(task); + worker.finish().await + })); } + // Keep the channel lifetime tied to the worker tasks so senders observe + // worker exits instead of blocking on an orphaned receiver handle. + drop(receiver); - let sender = Arc::new(sender); - - let mut stream = Box::pin(stream.then({ - |batch_result| { - let sender = sender.clone(); - async move { - let sender = sender.clone(); - let batch = batch_result?; - let num_rows = batch.num_rows(); - sender.send(batch).await.expect("failed to send batch"); - Result::Ok(num_rows) - } - } - })); + let mut stream = Box::pin(stream); log::info!("indexing FTS with {} workers", num_workers); let mut last_num_rows = 0; let mut total_num_rows = 0; let start = std::time::Instant::now(); - while let Some(num_rows) = stream.try_next().await? { + while let Some(batch) = stream.try_next().await? 
{ + let num_rows = batch.num_rows(); + + if sender.send(batch).await.is_err() { + // this only happens if all workers have exited, + // so we don't return the send error here, + // avoiding hiding the real error from workers. + break; + } + total_num_rows += num_rows; if total_num_rows >= last_num_rows + 1_000_000 { log::debug!( @@ -241,7 +238,6 @@ impl InvertedIndexBuilder { } // drop the sender to stop receivers drop(stream); - debug_assert_eq!(sender.sender_count(), 1); drop(sender); log::info!("dispatching elapsed: {:?}", start.elapsed()); @@ -1302,6 +1298,7 @@ mod tests { use lance_core::utils::tempfile::TempDir; use std::any::Any; use std::sync::atomic::{AtomicU64, AtomicUsize, Ordering}; + use std::time::Duration; use tokio::sync::Mutex; fn make_doc_batch(doc: &str, row_id: u64) -> RecordBatch { @@ -1583,6 +1580,24 @@ mod tests { } } + #[derive(Debug, Default)] + struct FailingProgress; + + #[async_trait] + impl IndexBuildProgress for FailingProgress { + async fn stage_start(&self, _stage: &str, _total: Option<u64>, _unit: &str) -> Result<()> { + Ok(()) + } + + async fn stage_progress(&self, _stage: &str, _completed: u64) -> Result<()> { + Err(Error::io("injected progress failure")) + } + + async fn stage_complete(&self, _stage: &str) -> Result<()> { + Ok(()) + } + } + #[tokio::test] async fn test_builder_reports_progress_stages() -> Result<()> { let index_dir = TempDir::default(); @@ -1682,4 +1697,29 @@ mod tests { Ok(()) } + + #[tokio::test] + async fn test_update_index_returns_worker_error_when_workers_exit_during_dispatch() { + let num_batches = (*LANCE_FTS_NUM_SHARDS * 2 + 1) as u64; + let schema = make_doc_batch("hello world", 0).schema(); + let stream = RecordBatchStreamAdapter::new( + schema, + stream::iter((0..num_batches).map(|row_id| Ok(make_doc_batch("hello world", row_id)))), + ); + let stream = Box::pin(stream); + + let mut builder = + InvertedIndexBuilder::new(InvertedIndexParams::default().skip_merge(true)) + 
.with_progress(Arc::new(FailingProgress)); + + let result = tokio::time::timeout(Duration::from_secs(5), builder.update_index(stream)) + .await + .expect("update_index should not hang") + .expect_err("worker failure should be returned"); + + assert!( + result.to_string().contains("injected progress failure"), + "unexpected error: {result}" + ); + } }