diff --git a/rust/lance-core/src/error.rs b/rust/lance-core/src/error.rs index 8e02d2964c7..48150db4354 100644 --- a/rust/lance-core/src/error.rs +++ b/rust/lance-core/src/error.rs @@ -63,6 +63,8 @@ pub enum Error { Internal { message: String, location: Location }, #[snafu(display("A prerequisite task failed: {message}, {location}"))] PrerequisiteFailed { message: String, location: Location }, + #[snafu(display("Unprocessable: {message}, {location}"))] + Unprocessable { message: String, location: Location }, #[snafu(display("LanceError(Arrow): {message}, {location}"))] Arrow { message: String, location: Location }, #[snafu(display("LanceError(Schema): {message}, {location}"))] diff --git a/rust/lance-index/src/vector/kmeans.rs b/rust/lance-index/src/vector/kmeans.rs index 48c61bcdbe4..be76fade6f6 100644 --- a/rust/lance-index/src/vector/kmeans.rs +++ b/rust/lance-index/src/vector/kmeans.rs @@ -1319,9 +1319,12 @@ where { let num_rows = array.len() / dimension; if num_rows < k { - return Err(Error::Index{message: format!( - "KMeans: can not train {k} centroids with {num_rows} vectors, choose a smaller K (< {num_rows}) instead" - ),location: location!()}); + return Err(Error::Unprocessable { + message: format!( + "KMeans cannot train {k} centroids with {num_rows} vectors; choose a smaller K (< {num_rows})" + ), + location: location!(), + }); } // Only sample sample_rate * num_clusters. See Faiss diff --git a/rust/lance-index/src/vector/pq/builder.rs b/rust/lance-index/src/vector/pq/builder.rs index 5b17e2f1224..d44d86e4f31 100644 --- a/rust/lance-index/src/vector/pq/builder.rs +++ b/rust/lance-index/src/vector/pq/builder.rs @@ -171,10 +171,9 @@ impl PQBuildParams { let num_centroids = 2_usize.pow(self.num_bits as u32); if data.len() < num_centroids { - return Err(Error::Index { + return Err(Error::Unprocessable { message: format!( - "Not enough rows to train PQ. Requires {:?} rows but only {:?} available", - num_centroids, + "Not enough rows to train PQ. 
Requires {num_centroids} rows but only {} available", data.len() ), location: location!(), diff --git a/rust/lance/src/index.rs b/rust/lance/src/index.rs index 60342a00d93..2a3ad508483 100644 --- a/rust/lance/src/index.rs +++ b/rust/lance/src/index.rs @@ -2491,8 +2491,8 @@ mod tests { .create_index(&["vector"], IndexType::Vector, None, &params, false) .await; - assert!(matches!(result, Err(Error::Index { .. }))); - if let Error::Index { message, .. } = result.unwrap_err() { + assert!(matches!(result, Err(Error::Unprocessable { .. }))); + if let Error::Unprocessable { message, .. } = result.unwrap_err() { assert_eq!( message, "Not enough rows to train PQ. Requires 256 rows but only 100 available", diff --git a/rust/lance/src/index/vector/pq.rs b/rust/lance/src/index/vector/pq.rs index c6167876455..6c55f50f7af 100644 --- a/rust/lance/src/index/vector/pq.rs +++ b/rust/lance/src/index/vector/pq.rs @@ -508,6 +508,8 @@ pub async fn build_pq_model( params: &PQBuildParams, ivf: Option<&IvfModel>, ) -> Result { + let num_codes = 2_usize.pow(params.num_bits as u32); + if let Some(codebook) = &params.codebook { let dt = if metric_type == MetricType::Cosine { info!("Normalize training data for PQ training: Cosine"); @@ -577,13 +579,16 @@ pub async fn build_pq_model( training_data }; - let num_codes = 2_usize.pow(params.num_bits as u32); if training_data.len() < num_codes { - return Err(Error::Index { + warn!( + "Skip PQ training: only {} rows available, needs >= {}", + training_data.len(), + num_codes + ); + return Err(Error::Unprocessable { message: format!( - "Not enough rows to train PQ. Requires {:?} rows but only {:?} available", - num_codes, - training_data.len() + "Not enough rows to train PQ. 
Requires {num_codes} rows but only {available} available", + available = training_data.len() ), location: location!(), }); @@ -637,7 +642,9 @@ mod tests { use crate::index::vector::ivf::build_ivf_model; use lance_core::utils::mask::RowIdMask; use lance_index::vector::ivf::IvfBuildParams; - use lance_testing::datagen::generate_random_array_with_range; + use lance_testing::datagen::{ + generate_random_array_with_range, generate_random_array_with_seed, + }; const DIM: usize = 128; async fn generate_dataset( @@ -761,6 +768,35 @@ mod tests { ); } + #[tokio::test] + async fn test_build_pq_model_insufficient_rows_returns_prereq() { + let test_dir = TempStrDir::default(); + let test_uri = test_dir.as_str(); + + let dim = 16; + let schema = Arc::new(Schema::new(vec![Field::new( + "vector", + DataType::FixedSizeList( + Arc::new(Field::new("item", DataType::Float32, true)), + dim as i32, + ), + false, + )])); + + let vectors = generate_random_array_with_seed::(dim * 10, [11u8; 32]); + let fsl = FixedSizeListArray::try_new_from_values(vectors, dim as i32).unwrap(); + let batch = RecordBatch::try_new(schema.clone(), vec![Arc::new(fsl)]).unwrap(); + let reader = RecordBatchIterator::new(vec![Ok(batch)], schema.clone()); + let dataset = Dataset::write(reader, test_uri, None).await.unwrap(); + + let params = PQBuildParams::new(16, 8); + let err = build_pq_model(&dataset, "vector", dim, MetricType::L2, &params, None) + .await + .unwrap_err(); + + assert!(matches!(err, Error::Unprocessable { .. })); + } + struct TestPreFilter { row_ids: Vec, }