Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
82 changes: 61 additions & 21 deletions rust/lance-index/src/scalar/inverted/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,7 @@ impl InvertedIndexBuilder {
let fragment_mask = self.fragment_mask;
let token_set_format = self.token_set_format;
let tokenized_count = tokenized_count.clone();
let task = tokio::task::spawn(async move {
index_tasks.push(tokio::task::spawn(async move {
let mut worker = IndexWorker::new(
store,
tokenizer,
Expand All @@ -202,32 +202,29 @@ impl InvertedIndexBuilder {
.stage_progress("tokenize_docs", tokenized_count)
.await?;
}
let partitions = worker.finish().await?;
Result::Ok(partitions)
});
index_tasks.push(task);
worker.finish().await
}));
}
// Keep the channel lifetime tied to the worker tasks so senders observe
// worker exits instead of blocking on an orphaned receiver handle.
drop(receiver);

let sender = Arc::new(sender);

let mut stream = Box::pin(stream.then({
|batch_result| {
let sender = sender.clone();
async move {
let sender = sender.clone();
let batch = batch_result?;
let num_rows = batch.num_rows();
sender.send(batch).await.expect("failed to send batch");
Result::Ok(num_rows)
}
}
}));
let mut stream = Box::pin(stream);
log::info!("indexing FTS with {} workers", num_workers);

let mut last_num_rows = 0;
let mut total_num_rows = 0;
let start = std::time::Instant::now();
while let Some(num_rows) = stream.try_next().await? {
while let Some(batch) = stream.try_next().await? {
let num_rows = batch.num_rows();

if sender.send(batch).await.is_err() {
// this only happens if all workers have existed,
// so we don't return the send error here,
// avoiding hiding the real error from workers.
break;
}

total_num_rows += num_rows;
if total_num_rows >= last_num_rows + 1_000_000 {
log::debug!(
Expand All @@ -241,7 +238,6 @@ impl InvertedIndexBuilder {
}
// drop the sender to stop receivers
drop(stream);
debug_assert_eq!(sender.sender_count(), 1);
drop(sender);
log::info!("dispatching elapsed: {:?}", start.elapsed());

Expand Down Expand Up @@ -1302,6 +1298,7 @@ mod tests {
use lance_core::utils::tempfile::TempDir;
use std::any::Any;
use std::sync::atomic::{AtomicU64, AtomicUsize, Ordering};
use std::time::Duration;
use tokio::sync::Mutex;

fn make_doc_batch(doc: &str, row_id: u64) -> RecordBatch {
Expand Down Expand Up @@ -1583,6 +1580,24 @@ mod tests {
}
}

#[derive(Debug, Default)]
struct FailingProgress;

#[async_trait]
impl IndexBuildProgress for FailingProgress {
async fn stage_start(&self, _stage: &str, _total: Option<u64>, _unit: &str) -> Result<()> {
Ok(())
}

async fn stage_progress(&self, _stage: &str, _completed: u64) -> Result<()> {
Err(Error::io("injected progress failure"))
}

async fn stage_complete(&self, _stage: &str) -> Result<()> {
Ok(())
}
}

#[tokio::test]
async fn test_builder_reports_progress_stages() -> Result<()> {
let index_dir = TempDir::default();
Expand Down Expand Up @@ -1682,4 +1697,29 @@ mod tests {

Ok(())
}

#[tokio::test]
async fn test_update_index_returns_worker_error_when_workers_exit_during_dispatch() {
let num_batches = (*LANCE_FTS_NUM_SHARDS * 2 + 1) as u64;
let schema = make_doc_batch("hello world", 0).schema();
let stream = RecordBatchStreamAdapter::new(
schema,
stream::iter((0..num_batches).map(|row_id| Ok(make_doc_batch("hello world", row_id)))),
);
let stream = Box::pin(stream);

let mut builder =
InvertedIndexBuilder::new(InvertedIndexParams::default().skip_merge(true))
.with_progress(Arc::new(FailingProgress));

let result = tokio::time::timeout(Duration::from_secs(5), builder.update_index(stream))
.await
.expect("update_index should not hang")
.expect_err("worker failure should be returned");

assert!(
result.to_string().contains("injected progress failure"),
"unexpected error: {result}"
);
}
}