diff --git a/rust/lance-namespace-impls/src/dir.rs b/rust/lance-namespace-impls/src/dir.rs
index 734cbc64c2b..afd9267dffc 100644
--- a/rust/lance-namespace-impls/src/dir.rs
+++ b/rust/lance-namespace-impls/src/dir.rs
@@ -15,6 +15,8 @@ use async_trait::async_trait;
 use bytes::Bytes;
 use futures::TryStreamExt;
 use lance::dataset::builder::DatasetBuilder;
+use lance::dataset::scanner::Scanner;
+use lance::dataset::statistics::DatasetStatisticsExt;
 use lance::dataset::transaction::{Operation, Transaction};
 use lance::dataset::{Dataset, WriteMode, WriteParams};
 use lance::index::{DatasetIndexExt, IndexParams, vector::VectorIndexParams};
@@ -38,23 +40,27 @@ use std::sync::{Arc, Mutex};
 
 use crate::context::DynamicContextProvider;
 use lance_namespace::models::{
-    BatchDeleteTableVersionsRequest, BatchDeleteTableVersionsResponse, CountTableRowsRequest,
-    CreateNamespaceRequest, CreateNamespaceResponse, CreateTableIndexRequest,
-    CreateTableIndexResponse, CreateTableRequest, CreateTableResponse,
-    CreateTableScalarIndexResponse, CreateTableVersionRequest, CreateTableVersionResponse,
-    DeclareTableRequest, DeclareTableResponse, DescribeNamespaceRequest, DescribeNamespaceResponse,
-    DescribeTableIndexStatsRequest, DescribeTableIndexStatsResponse, DescribeTableRequest,
-    DescribeTableResponse, DescribeTableVersionRequest, DescribeTableVersionResponse,
-    DescribeTransactionRequest, DescribeTransactionResponse, DropNamespaceRequest,
-    DropNamespaceResponse, DropTableIndexRequest, DropTableIndexResponse, DropTableRequest,
-    DropTableResponse, Identity, IndexContent, InsertIntoTableRequest, InsertIntoTableResponse,
+    AnalyzeTableQueryPlanRequest, BatchDeleteTableVersionsRequest,
+    BatchDeleteTableVersionsResponse, CountTableRowsRequest, CreateNamespaceRequest,
+    CreateNamespaceResponse, CreateTableIndexRequest, CreateTableIndexResponse, CreateTableRequest,
+    CreateTableResponse, CreateTableScalarIndexResponse, CreateTableVersionRequest,
+    CreateTableVersionResponse, DeclareTableRequest, DeclareTableResponse,
+    DescribeNamespaceRequest, DescribeNamespaceResponse, DescribeTableIndexStatsRequest,
+    DescribeTableIndexStatsResponse, DescribeTableRequest, DescribeTableResponse,
+    DescribeTableVersionRequest, DescribeTableVersionResponse, DescribeTransactionRequest,
+    DescribeTransactionResponse, DropNamespaceRequest, DropNamespaceResponse,
+    DropTableIndexRequest, DropTableIndexResponse, DropTableRequest, DropTableResponse,
+    ExplainTableQueryPlanRequest, FragmentStats, FragmentSummary, GetTableStatsRequest,
+    GetTableStatsResponse, Identity, IndexContent, InsertIntoTableRequest, InsertIntoTableResponse,
     ListNamespacesRequest, ListNamespacesResponse, ListTableIndicesRequest,
     ListTableIndicesResponse, ListTableVersionsRequest, ListTableVersionsResponse,
     ListTablesRequest, ListTablesResponse, NamespaceExistsRequest, QueryTableRequest,
-    TableExistsRequest, TableVersion,
+    QueryTableRequestColumns, QueryTableRequestVector, RestoreTableRequest, RestoreTableResponse,
+    TableExistsRequest, TableVersion, UpdateTableSchemaMetadataRequest,
+    UpdateTableSchemaMetadataResponse,
 };
-use lance_core::Result;
+use lance_core::{Error, Result};
 use lance_namespace::LanceNamespace;
 use lance_namespace::error::NamespaceError;
 use lance_namespace::schema::arrow_schema_to_json;
@@ -1793,6 +1799,142 @@ impl DirectoryNamespace {
         Ok(deleted_count)
     }
 
+    /// Apply all query parameters from a `QueryTableRequest`-like source onto a `Scanner`.
+    ///
+    /// This covers vector search, filters, column projection, limits, and ANN tuning knobs so
+    /// that `explain_table_query_plan` / `analyze_table_query_plan` produce an accurate plan.
+    #[allow(clippy::too_many_arguments)]
+    fn apply_query_params_to_scanner(
+        scanner: &mut Scanner,
+        filter: Option<&str>,
+        columns: Option<&QueryTableRequestColumns>,
+        vector_column: Option<&str>,
+        vector: &QueryTableRequestVector,
+        k: i32,
+        offset: Option<i32>,
+        prefilter: Option<bool>,
+        bypass_vector_index: Option<bool>,
+        nprobes: Option<i32>,
+        ef: Option<i32>,
+        refine_factor: Option<i32>,
+        distance_type: Option<&str>,
+        fast_search_flag: Option<bool>,
+        with_row_id: Option<bool>,
+        lower_bound: Option<f32>,
+        upper_bound: Option<f32>,
+        operation: &str,
+    ) -> Result<()> {
+        // prefilter must be set before nearest() so the fragment-scan guard sees it.
+        if let Some(pf) = prefilter {
+            scanner.prefilter(pf);
+        }
+
+        if let Some(filter) = filter {
+            scanner.filter(filter).map_err(|e| {
+                Error::invalid_input_source(
+                    format!("Invalid filter expression for {}: {}", operation, e).into(),
+                )
+            })?;
+        }
+
+        if let Some(cols) = columns {
+            if let Some(ref names) = cols.column_names {
+                scanner.project(names.as_slice()).map_err(|e| {
+                    Error::invalid_input_source(
+                        format!("Invalid column projection for {}: {}", operation, e).into(),
+                    )
+                })?;
+            } else if let Some(ref aliases) = cols.column_aliases {
+                // aliases maps output_alias -> source_column
+                let pairs: Vec<(&str, &str)> = aliases
+                    .iter()
+                    .map(|(alias, src)| (alias.as_str(), src.as_str()))
+                    .collect();
+                scanner.project_with_transform(&pairs).map_err(|e| {
+                    Error::invalid_input_source(
+                        format!("Invalid column aliases for {}: {}", operation, e).into(),
+                    )
+                })?;
+            }
+        }
+
+        // Resolve the query vector: prefer single_vector, fall back to the first row
+        // of multi_vector.
+        let query_vec: Option<Vec<f32>> = vector
+            .single_vector
+            .as_ref()
+            .filter(|v| !v.is_empty())
+            .cloned()
+            .or_else(|| {
+                vector
+                    .multi_vector
+                    .as_ref()
+                    .and_then(|mv| mv.first())
+                    .filter(|v| !v.is_empty())
+                    .cloned()
+            });
+
+        if let Some(q_vec) = query_vec {
+            let col = vector_column.unwrap_or("vector");
+            let q = Arc::new(Float32Array::from(q_vec));
+            scanner
+                .nearest(col, q.as_ref(), k.max(1) as usize)
+                .map_err(|e| {
+                    Error::invalid_input_source(
+                        format!("Invalid vector query for {}: {}", operation, e).into(),
+                    )
+                })?;
+
+            // ANN parameters must be applied after nearest().
+            if let Some(n) = nprobes {
+                scanner.nprobes(n.max(1) as usize);
+            }
+            if let Some(e) = ef {
+                scanner.ef(e.max(1) as usize);
+            }
+            if let Some(rf) = refine_factor {
+                scanner.refine(rf.max(0) as u32);
+            }
+            // bypass_vector_index and fast_search are mutually exclusive hints; both are
+            // forwarded as given and the scanner resolves the combination.
+            if let Some(true) = bypass_vector_index {
+                scanner.use_index(false);
+            }
+            if let Some(true) = fast_search_flag {
+                scanner.fast_search();
+            }
+            if lower_bound.is_some() || upper_bound.is_some() {
+                scanner.distance_range(lower_bound, upper_bound);
+            }
+            if let Some(dt) = distance_type {
+                let metric = Self::parse_metric_type(Some(dt))?;
+                scanner.distance_metric(metric);
+            }
+            // Apply offset on top of the k nearest results.
+            if let Some(off) = offset.filter(|&o| o > 0) {
+                scanner.limit(None, Some(off as i64)).map_err(|e| {
+                    Error::invalid_input_source(
+                        format!("Invalid offset for {}: {}", operation, e).into(),
+                    )
+                })?;
+            }
+        } else {
+            // Scalar (non-vector) query: treat k as a row LIMIT.
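+            // Hypothetical example: k = 10 with offset = Some(20) becomes
+            // limit(Some(10), Some(20)), i.e. rows 20..30 of the filtered scan;
+            // k <= 0 means no limit.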
+            let limit = if k > 0 { Some(k as i64) } else { None };
+            scanner
+                .limit(limit, offset.map(|o| o as i64))
+                .map_err(|e| {
+                    Error::invalid_input_source(
+                        format!("Invalid limit/offset for {}: {}", operation, e).into(),
+                    )
+                })?;
+        }
+
+        if let Some(true) = with_row_id {
+            scanner.with_row_id();
+        }
+
+        Ok(())
+    }
+
     /// Retrieve a snapshot of operation metrics.
     ///
     /// Returns a HashMap where keys are operation names (e.g., "list_tables", "describe_table")
@@ -2967,6 +3109,281 @@ impl LanceNamespace for DirectoryNamespace {
         Ok(DropTableIndexResponse { transaction_id })
     }
 
+    async fn list_all_tables(&self, request: ListTablesRequest) -> Result<ListTablesResponse> {
+        // In dir-only mode there are no child namespaces, so all tables live in the
+        // root directory. This is equivalent to listing the root namespace.
+        let mut tables = self.list_directory_tables().await?;
+        Self::apply_pagination(&mut tables, request.page_token, request.limit);
+        Ok(ListTablesResponse::new(tables))
+    }
+
+    async fn restore_table(&self, request: RestoreTableRequest) -> Result<RestoreTableResponse> {
+        let version = request.version;
+        if version < 0 {
+            return Err(Error::invalid_input_source(
+                format!(
+                    "Table version for restore_table must be non-negative, got {}",
+                    version
+                )
+                .into(),
+            ));
+        }
+
+        let table_uri = self.resolve_table_location(&request.id).await?;
+        let mut dataset = self.load_dataset(&table_uri, None, "restore_table").await?;
+
+        dataset = dataset
+            .checkout_version(version as u64)
+            .await
+            .map_err(|e| {
+                Error::namespace_source(
+                    format!(
+                        "Failed to checkout version {} for restore at '{}': {}",
+                        version, table_uri, e
+                    )
+                    .into(),
+                )
+            })?;
+
+        dataset.restore().await.map_err(|e| {
+            Error::namespace_source(
+                format!(
+                    "Failed to restore table at '{}' to version {}: {}",
+                    table_uri, version, e
+                )
+                .into(),
+            )
+        })?;
+
+        let transaction_id = dataset
+            .read_transaction()
+            .await
+            .map_err(|e| {
+                Error::namespace_source(
+                    format!(
+                        "Failed to read transaction after restoring '{}': {}",
+                        table_uri, e
+                    )
+                    .into(),
+                )
+            })?
+            .map(|t| t.uuid);
+
+        Ok(RestoreTableResponse { transaction_id })
+    }
+
+    async fn update_table_schema_metadata(
+        &self,
+        request: UpdateTableSchemaMetadataRequest,
+    ) -> Result<UpdateTableSchemaMetadataResponse> {
+        let table_uri = self.resolve_table_location(&request.id).await?;
+        let mut dataset = self
+            .load_dataset(&table_uri, None, "update_table_schema_metadata")
+            .await?;
+
+        let new_metadata = request.metadata.unwrap_or_default();
+        let updated_metadata = dataset
+            .update_schema_metadata(new_metadata.iter().map(|(k, v)| (k.as_str(), v.as_str())))
+            .await
+            .map_err(|e| {
+                Error::namespace_source(
+                    format!(
+                        "Failed to update schema metadata for table at '{}': {}",
+                        table_uri, e
+                    )
+                    .into(),
+                )
+            })?;
+
+        let transaction_id = dataset
+            .read_transaction()
+            .await
+            .map_err(|e| {
+                Error::namespace_source(
+                    format!(
+                        "Failed to read transaction after updating metadata for '{}': {}",
+                        table_uri, e
+                    )
+                    .into(),
+                )
+            })?
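+            // read_transaction yields an Option; when no transaction is recorded,
+            // the response simply carries transaction_id = None.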
+            .map(|t| t.uuid);
+
+        Ok(UpdateTableSchemaMetadataResponse {
+            metadata: Some(updated_metadata),
+            transaction_id,
+        })
+    }
+
+    async fn get_table_stats(
+        &self,
+        request: GetTableStatsRequest,
+    ) -> Result<GetTableStatsResponse> {
+        let table_uri = self.resolve_table_location(&request.id).await?;
+        let dataset = Arc::new(
+            self.load_dataset(&table_uri, None, "get_table_stats")
+                .await?,
+        );
+
+        // Compute total bytes on disk using field-level statistics.
+        let data_stats = dataset.calculate_data_stats().await.map_err(|e| {
+            Error::namespace_source(
+                format!(
+                    "Failed to calculate data statistics for table at '{}': {}",
+                    table_uri, e
+                )
+                .into(),
+            )
+        })?;
+        let total_bytes: i64 = data_stats
+            .fields
+            .iter()
+            .map(|f| f.bytes_on_disk as i64)
+            .sum();
+
+        // Collect per-fragment row counts.
+        let fragment_row_futures: Vec<_> = dataset
+            .get_fragments()
+            .into_iter()
+            .map(|f| async move { f.physical_rows().await })
+            .collect();
+        let fragment_row_results = futures::future::join_all(fragment_row_futures).await;
+        let mut fragment_row_counts: Vec<i64> = fragment_row_results
+            .into_iter()
+            .filter_map(|r| r.ok())
+            .map(|r| r as i64)
+            .collect();
+
+        let num_fragments = fragment_row_counts.len() as i64;
+        let num_rows: i64 = fragment_row_counts.iter().sum();
+
+        // Fragments with fewer rows than the compaction target are considered "small",
+        // consistent with the default CompactionOptions::target_rows_per_fragment.
+        const SMALL_FRAGMENT_THRESHOLD: i64 = 1024 * 1024;
+        let num_small_fragments = fragment_row_counts
+            .iter()
+            .filter(|&&r| r < SMALL_FRAGMENT_THRESHOLD)
+            .count() as i64;
+
+        // Compute length summary statistics.
+        fragment_row_counts.sort_unstable();
+        let lengths = if fragment_row_counts.is_empty() {
+            FragmentSummary::new(0, 0, 0, 0, 0, 0, 0)
+        } else {
+            let len = fragment_row_counts.len();
+            let min = fragment_row_counts[0];
+            let max = fragment_row_counts[len - 1];
+            let mean = num_rows / num_fragments;
+            let pct = |p: f64| fragment_row_counts[((len - 1) as f64 * p) as usize];
+            FragmentSummary::new(min, max, mean, pct(0.25), pct(0.50), pct(0.75), pct(0.99))
+        };
+
+        // Count non-system indices.
+        let indices = dataset.load_indices().await.map_err(|e| {
+            Error::namespace_source(
+                format!("Failed to load indices for table at '{}': {}", table_uri, e).into(),
+            )
+        })?;
+        let num_indices = indices.iter().filter(|m| !is_system_index(m)).count() as i64;
+
+        let fragment_stats = FragmentStats::new(num_fragments, num_small_fragments, lengths);
+        Ok(GetTableStatsResponse::new(
+            total_bytes,
+            num_rows,
+            num_indices,
+            fragment_stats,
+        ))
+    }
+
+    async fn explain_table_query_plan(
+        &self,
+        request: ExplainTableQueryPlanRequest,
+    ) -> Result<String> {
+        let table_uri = self.resolve_table_location(&request.id).await?;
+        let dataset = self
+            .load_dataset(
+                &table_uri,
+                request.query.version,
+                "explain_table_query_plan",
+            )
+            .await?;
+        let verbose = request.verbose.unwrap_or(false);
+
+        let mut scanner = dataset.scan();
+        Self::apply_query_params_to_scanner(
+            &mut scanner,
+            request.query.filter.as_deref(),
+            request.query.columns.as_deref(),
+            request.query.vector_column.as_deref(),
+            &request.query.vector,
+            request.query.k,
+            request.query.offset,
+            request.query.prefilter,
+            request.query.bypass_vector_index,
+            request.query.nprobes,
+            request.query.ef,
+            request.query.refine_factor,
+            request.query.distance_type.as_deref(),
+            request.query.fast_search,
+            request.query.with_row_id,
+            request.query.lower_bound,
+            request.query.upper_bound,
+            "explain_table_query_plan",
+        )?;
+
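+        // Note: explain_plan only renders the physical plan as text; analyze_plan
+        // (used by analyze_table_query_plan below) additionally executes the query
+        // to collect per-operator runtime metrics.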
+        scanner.explain_plan(verbose).await.map_err(|e| {
+            Error::namespace_source(
+                format!(
+                    "Failed to explain query plan for table at '{}': {}",
+                    table_uri, e
+                )
+                .into(),
+            )
+        })
+    }
+
+    async fn analyze_table_query_plan(
+        &self,
+        request: AnalyzeTableQueryPlanRequest,
+    ) -> Result<String> {
+        let table_uri = self.resolve_table_location(&request.id).await?;
+        let dataset = self
+            .load_dataset(&table_uri, request.version, "analyze_table_query_plan")
+            .await?;
+
+        let mut scanner = dataset.scan();
+        Self::apply_query_params_to_scanner(
+            &mut scanner,
+            request.filter.as_deref(),
+            request.columns.as_deref(),
+            request.vector_column.as_deref(),
+            &request.vector,
+            request.k,
+            request.offset,
+            request.prefilter,
+            request.bypass_vector_index,
+            request.nprobes,
+            request.ef,
+            request.refine_factor,
+            request.distance_type.as_deref(),
+            request.fast_search,
+            request.with_row_id,
+            request.lower_bound,
+            request.upper_bound,
+            "analyze_table_query_plan",
+        )?;
+
+        scanner.analyze_plan().await.map_err(|e| {
+            Error::namespace_source(
+                format!(
+                    "Failed to analyze query plan for table at '{}': {}",
+                    table_uri, e
+                )
+                .into(),
+            )
+        })
+    }
+
     async fn count_table_rows(&self, request: CountTableRowsRequest) -> Result<i64> {
         self.record_op("count_table_rows");
         let table_uri = self.resolve_table_location(&request.id).await?;
@@ -3312,11 +3729,24 @@ mod tests {
     use lance_core::utils::tempfile::{TempStdDir, TempStrDir};
     use lance_namespace::models::{
         CreateTableRequest, JsonArrowDataType, JsonArrowField, JsonArrowSchema, ListTablesRequest,
+        QueryTableRequestColumns,
     };
     use lance_namespace::schema::convert_json_arrow_schema;
     use std::io::Cursor;
     use std::sync::Arc;
 
+    fn assert_plan_contains_all(plan: &str, expected_fragments: &[&str], context: &str) {
+        for expected_fragment in expected_fragments {
+            assert!(
+                plan.contains(expected_fragment),
+                "{}. Missing fragment: '{}'. Plan:\n{}",
Plan:\n{}", + context, + expected_fragment, + plan + ); + } + } + /// Helper to create a test DirectoryNamespace with a temporary directory async fn create_test_namespace() -> (DirectoryNamespace, TempStdDir) { let temp_dir = TempStdDir::default(); @@ -7734,4 +8164,173 @@ mod tests { assert_eq!(*ver, 2, "Recorded version should be 2"); } } + + #[tokio::test] + async fn test_list_all_tables() { + use lance_namespace::models::ListTablesRequest; + + let (namespace, _temp_dir) = create_test_namespace().await; + create_scalar_table(&namespace, "alpha").await; + create_scalar_table(&namespace, "beta").await; + + let request = ListTablesRequest { + id: Some(vec![]), + page_token: None, + limit: None, + ..Default::default() + }; + let response = namespace.list_all_tables(request).await.unwrap(); + let mut tables = response.tables; + tables.sort(); + assert_eq!(tables, vec!["alpha", "beta"]); + } + + #[tokio::test] + async fn test_restore_table() { + use lance_namespace::models::RestoreTableRequest; + + let (namespace, _temp_dir) = create_test_namespace().await; + create_scalar_table(&namespace, "users").await; + + // Create a second version by creating a scalar index (this adds a new version) + create_scalar_index(&namespace, "users", "users_id_idx").await; + + let dataset = open_dataset(&namespace, "users").await; + let current_version = dataset.version().version; + assert!(current_version >= 2, "Should have at least 2 versions"); + + // Restore to version 1 + let mut restore_req = RestoreTableRequest::new(1); + restore_req.id = Some(vec!["users".to_string()]); + let response = namespace.restore_table(restore_req).await.unwrap(); + + // transaction_id should be present (the restore operation) + assert!( + response.transaction_id.is_some(), + "restore_table should return a transaction_id" + ); + + // Verify the dataset now has a new version (restore creates a new version) + let dataset_after = open_dataset(&namespace, "users").await; + assert!( + dataset_after.version().version > current_version, + "Restore should create a new version" + ); + } + + #[tokio::test] + async fn test_update_table_schema_metadata() { + use lance_namespace::models::UpdateTableSchemaMetadataRequest; + + let (namespace, _temp_dir) = create_test_namespace().await; + create_scalar_table(&namespace, "products").await; + + let mut metadata = HashMap::new(); + metadata.insert("owner".to_string(), "team_a".to_string()); + metadata.insert("version".to_string(), "1.0".to_string()); + + let mut req = UpdateTableSchemaMetadataRequest::new(); + req.id = Some(vec!["products".to_string()]); + req.metadata = Some(metadata.clone()); + + let response = namespace.update_table_schema_metadata(req).await.unwrap(); + + assert!(response.metadata.is_some()); + let returned = response.metadata.unwrap(); + assert_eq!(returned.get("owner"), Some(&"team_a".to_string())); + assert_eq!(returned.get("version"), Some(&"1.0".to_string())); + assert!( + response.transaction_id.is_some(), + "update_table_schema_metadata should return a transaction_id" + ); + } + + #[tokio::test] + async fn test_get_table_stats() { + use lance_namespace::models::GetTableStatsRequest; + + let (namespace, _temp_dir) = create_test_namespace().await; + create_scalar_table(&namespace, "items").await; + create_scalar_index(&namespace, "items", "items_id_idx").await; + + let mut req = GetTableStatsRequest::new(); + req.id = Some(vec!["items".to_string()]); + + let response = namespace.get_table_stats(req).await.unwrap(); + assert_eq!(response.num_rows, 3); + 
+        assert_eq!(response.num_indices, 1);
+    }
+
+    #[tokio::test]
+    async fn test_explain_table_query_plan() {
+        use lance_namespace::models::{
+            ExplainTableQueryPlanRequest, QueryTableRequest, QueryTableRequestVector,
+        };
+
+        let (namespace, _temp_dir) = create_test_namespace().await;
+        create_scalar_table(&namespace, "catalog").await;
+
+        let mut query = QueryTableRequest::new(1, QueryTableRequestVector::new());
+        query.filter = Some("id > 1".to_string());
+        query.columns = Some(Box::new(QueryTableRequestColumns {
+            column_names: Some(vec!["id".to_string(), "name".to_string()]),
+            column_aliases: None,
+        }));
+        query.with_row_id = Some(true);
+
+        let mut req = ExplainTableQueryPlanRequest::new(query);
+        req.id = Some(vec!["catalog".to_string()]);
+
+        let plan_str = namespace.explain_table_query_plan(req).await.unwrap();
+        assert_plan_contains_all(
+            &plan_str,
+            &[
+                "ProjectionExec: expr=[id@0 as id, name@2 as name",
+                "Take: columns=\"id, _rowid, (name)\"",
+                "LanceRead: uri=",
+                "projection=[id]",
+                "row_id=true, row_addr=false",
+                "full_filter=id > Int32(1)",
+                "refine_filter=id > Int32(1)",
+            ],
+            "Filtered explain plan should preserve late materialization and filter pushdown",
+        );
+    }
+
+    #[tokio::test]
+    async fn test_analyze_table_query_plan() {
+        use lance_namespace::models::{AnalyzeTableQueryPlanRequest, QueryTableRequestVector};
+
+        let (namespace, _temp_dir) = create_test_namespace().await;
+        create_scalar_table(&namespace, "catalog").await;
+
+        let mut req = AnalyzeTableQueryPlanRequest::new(1, QueryTableRequestVector::new());
+        req.id = Some(vec!["catalog".to_string()]);
+        req.filter = Some("id > 0".to_string());
+        req.columns = Some(Box::new(QueryTableRequestColumns {
+            column_names: Some(vec!["id".to_string(), "name".to_string()]),
+            column_aliases: None,
+        }));
+        req.with_row_id = Some(true);
+
+        let analysis_str = namespace.analyze_table_query_plan(req).await.unwrap();
+        assert_plan_contains_all(
+            &analysis_str,
+            &[
+                "AnalyzeExec verbose=true",
+                "ProjectionExec: elapsed=",
+                "expr=[id@0 as id, name@2 as name",
+                "Take: elapsed=",
+                "columns=\"id, _rowid, (name)\"",
+                "CoalesceBatchesExec: elapsed=",
+                "LanceRead: elapsed=",
+                "projection=[id]",
+                "row_id=true, row_addr=false",
+                "full_filter=id > Int32(0)",
+                "refine_filter=id > Int32(0)",
+                "metrics=[output_rows=",
+            ],
+            "Filtered analyze plan should preserve late materialization and filter pushdown",
+        );
+    }
 }