From 987246f4105fc466a5b542b4a1889bdbbc74f830 Mon Sep 17 00:00:00 2001 From: jaystarshot Date: Wed, 15 Oct 2025 12:45:23 -0700 Subject: [PATCH 1/7] rebase with just bkd leaf --- protos/index.proto | 4 +- protos/index_old.proto | 1 + python/python/lance/dataset.py | 27 +- python/src/dataset.rs | 5 + rust/lance-datafusion/src/udf.rs | 108 +++ rust/lance-index/src/lib.rs | 7 +- rust/lance-index/src/scalar.rs | 43 + rust/lance-index/src/scalar/expression.rs | 120 ++- rust/lance-index/src/scalar/geoindex.rs | 1001 +++++++++++++++++++++ rust/lance-index/src/scalar/registry.rs | 7 +- rust/lance/src/dataset/sql.rs | 2 + rust/lance/src/index/create.rs | 3 +- test_geoarrow_geo_index.py | 253 ++++++ 13 files changed, 1568 insertions(+), 13 deletions(-) create mode 100644 rust/lance-index/src/scalar/geoindex.rs create mode 100644 test_geoarrow_geo_index.py diff --git a/protos/index.proto b/protos/index.proto index c6d6370f906..837e54abdcc 100644 --- a/protos/index.proto +++ b/protos/index.proto @@ -188,4 +188,6 @@ message JsonIndexDetails { string path = 1; google.protobuf.Any target_details = 2; } -message BloomFilterIndexDetails {} \ No newline at end of file +message BloomFilterIndexDetails {} + +message GeoIndexDetails {} \ No newline at end of file diff --git a/protos/index_old.proto b/protos/index_old.proto index 601aa2681da..5931f911380 100644 --- a/protos/index_old.proto +++ b/protos/index_old.proto @@ -25,6 +25,7 @@ message BitmapIndexDetails {} message LabelListIndexDetails {} message NGramIndexDetails {} message ZoneMapIndexDetails {} +message GeoIndexDetails {} message InvertedIndexDetails { // Marking this field as optional as old versions of the index store blank details and we // need to make sure we have a proper optional field to detect this. 
diff --git a/python/python/lance/dataset.py b/python/python/lance/dataset.py index e0b7f638846..b966d19fbfb 100644 --- a/python/python/lance/dataset.py +++ b/python/python/lance/dataset.py @@ -2345,6 +2345,7 @@ def create_scalar_index( "LABEL_LIST", "INVERTED", "BLOOMFILTER", + "GEO" ]: raise NotImplementedError( ( @@ -2392,11 +2393,27 @@ def create_scalar_index( f"INVERTED index column {column} must be string, large string" " or list of strings, but got {value_type}" ) - - if pa.types.is_duration(field_type): - raise TypeError( - f"Scalar index column {column} cannot currently be a duration" - ) + elif index_type == "GEO": + # Accept struct for GeoArrow point data + if pa.types.is_struct(field_type): + field_names = [field.name for field in field_type] + if set(field_names) == {"x", "y"}: + # This is geoarrow point data - allow it + pass + else: + raise TypeError( + f"GEO index column {column} must be a struct with x,y fields for point data. " + f"Got struct with fields: {field_names}" + ) + else: + raise TypeError( + f"GEO index column {column} must be a struct type. 
" + f"Got field type: {field_type}" + ) + if pa.types.is_duration(field_type): + raise TypeError( + f"Scalar index column {column} cannot currently be a duration" + ) elif isinstance(index_type, IndexConfig): config = json.dumps(index_type.parameters) kwargs["config"] = indices.IndexConfig(index_type.index_type, config) diff --git a/python/src/dataset.rs b/python/src/dataset.rs index 8c8f086ff9b..64c027dc984 100644 --- a/python/src/dataset.rs +++ b/python/src/dataset.rs @@ -1601,6 +1601,7 @@ impl Dataset { "BLOOMFILTER" => IndexType::BloomFilter, "LABEL_LIST" => IndexType::LabelList, "INVERTED" | "FTS" => IndexType::Inverted, + "GEO" => IndexType::Geo, "IVF_FLAT" | "IVF_PQ" | "IVF_SQ" | "IVF_RQ" | "IVF_HNSW_FLAT" | "IVF_HNSW_PQ" | "IVF_HNSW_SQ" => IndexType::Vector, _ => { @@ -1702,6 +1703,10 @@ impl Dataset { } Box::new(params) } + "GEO" => Box::new(ScalarIndexParams { + index_type: "geo".to_string(), + params: None, + }), _ => { let column_type = match self.ds.schema().field(columns[0]) { Some(f) => f.data_type().clone(), diff --git a/rust/lance-datafusion/src/udf.rs b/rust/lance-datafusion/src/udf.rs index 24366077c66..6af1a81f362 100644 --- a/rust/lance-datafusion/src/udf.rs +++ b/rust/lance-datafusion/src/udf.rs @@ -26,8 +26,116 @@ pub fn register_functions(ctx: &SessionContext) { ctx.register_udf(json::json_get_bool_udf()); ctx.register_udf(json::json_array_contains_udf()); ctx.register_udf(json::json_array_length_udf()); + + // GEO functions + ctx.register_udf(ST_INTERSECTS_UDF.clone()); + ctx.register_udf(ST_WITHIN_UDF.clone()); + ctx.register_udf(BBOX_UDF.clone()); +} + +static ST_WITHIN_UDF: LazyLock = LazyLock::new(st_within); +static BBOX_UDF: LazyLock = LazyLock::new(bbox); +static ST_INTERSECTS_UDF: LazyLock = LazyLock::new(st_intersects); + +fn st_intersects() -> ScalarUDF { + let function = Arc::new(make_scalar_function( + |_args: &[ArrayRef]| { + // Throw an error indicating that a spatial index is required + 
Err(datafusion::error::DataFusionError::Execution( + "st_intersects requires a spatial index. Please create a spatial index on the geometry column using dataset.create_scalar_index(column='your_column', index_type='GEO') before running spatial queries.".to_string(), + )) + }, + vec![], + )); + + create_udf( + "st_intersects", + vec![ + DataType::Struct(vec![ + Arc::new(arrow_schema::Field::new("x", DataType::Float64, false)), + Arc::new(arrow_schema::Field::new("y", DataType::Float64, false)), + ].into()), + DataType::Struct(vec![ + Arc::new(arrow_schema::Field::new("xmin", DataType::Float64, false)), + Arc::new(arrow_schema::Field::new("ymin", DataType::Float64, false)), + Arc::new(arrow_schema::Field::new("xmax", DataType::Float64, false)), + Arc::new(arrow_schema::Field::new("ymax", DataType::Float64, false)), + ].into()) + ], // GeoArrow Point struct, GeoArrow Box struct + DataType::Boolean, + Volatility::Immutable, + function, + ) +} + + +fn st_within() -> ScalarUDF { + let function = Arc::new(make_scalar_function( + |_args: &[ArrayRef]| { + // Throw an error indicating that a spatial index is required + Err(datafusion::error::DataFusionError::Execution( + "st_within requires a spatial index. 
Please create a spatial index on the geometry column using dataset.create_scalar_index(column='your_column', index_type='GEO') before running spatial queries.".to_string(), + )) + }, + vec![], + )); + + create_udf( + "st_within", + vec![ + DataType::Struct(vec![ + Arc::new(arrow_schema::Field::new("x", DataType::Float64, false)), + Arc::new(arrow_schema::Field::new("y", DataType::Float64, false)), + ].into()), + DataType::Struct(vec![ + Arc::new(arrow_schema::Field::new("xmin", DataType::Float64, false)), + Arc::new(arrow_schema::Field::new("ymin", DataType::Float64, false)), + Arc::new(arrow_schema::Field::new("xmax", DataType::Float64, false)), + Arc::new(arrow_schema::Field::new("ymax", DataType::Float64, false)), + ].into()) + ], // GeoArrow Point struct, GeoArrow Box struct + DataType::Boolean, + Volatility::Immutable, + function, + ) } + + +/// BBOX function that creates a bounding box from four numeric arguments. +/// This function is used internally by spatial queries and doesn't perform actual computation. +/// It's intercepted by Lance's geo query parser for index optimization. +/// +/// Usage in SQL: +/// ```sql +/// SELECT * FROM table WHERE ST_Intersects(geometry_column, BBOX(-180, -90, 180, 90)) +/// ``` +fn bbox() -> ScalarUDF { + let function = Arc::new(make_scalar_function( + |_args: &[ArrayRef]| { + // This UDF should never be called because BBOX functions are intercepted by the query parser + // If this executes, it means no spatial index exists + Err(datafusion::error::DataFusionError::Execution( + "BBOX function requires a spatial index. 
Please create a spatial index on the geometry column using dataset.create_scalar_index(column='your_column', index_type='GEO') before running spatial queries.".to_string(), + )) + }, + vec![], + )); + + create_udf( + "bbox", + vec![DataType::Float64, DataType::Float64, DataType::Float64, DataType::Float64], // min_x, min_y, max_x, max_y + DataType::Struct(vec![ + Arc::new(arrow_schema::Field::new("xmin", DataType::Float64, false)), + Arc::new(arrow_schema::Field::new("ymin", DataType::Float64, false)), + Arc::new(arrow_schema::Field::new("xmax", DataType::Float64, false)), + Arc::new(arrow_schema::Field::new("ymax", DataType::Float64, false)), + ].into()), // Returns a GeoArrow Box struct + Volatility::Immutable, + function, + ) } + + /// This method checks whether a string contains all specified tokens. The tokens are separated by /// punctuations and white spaces. /// diff --git a/rust/lance-index/src/lib.rs b/rust/lance-index/src/lib.rs index 26184cd47ff..7e2e2aae9ec 100644 --- a/rust/lance-index/src/lib.rs +++ b/rust/lance-index/src/lib.rs @@ -108,6 +108,8 @@ pub enum IndexType { BloomFilter = 9, // Bloom filter + Geo = 10, // Geo + // 100+ and up for vector index. /// Flat vector index. 
Vector = 100, // Legacy vector index, alias to IvfPq @@ -130,6 +132,7 @@ impl std::fmt::Display for IndexType { Self::NGram => write!(f, "NGram"), Self::FragmentReuse => write!(f, "FragmentReuse"), Self::MemWal => write!(f, "MemWal"), + Self::Geo => write!(f, "Geo"), Self::ZoneMap => write!(f, "ZoneMap"), Self::BloomFilter => write!(f, "BloomFilter"), Self::Vector | Self::IvfPq => write!(f, "IVF_PQ"), @@ -156,6 +159,7 @@ impl TryFrom for IndexType { v if v == Self::Inverted as i32 => Ok(Self::Inverted), v if v == Self::FragmentReuse as i32 => Ok(Self::FragmentReuse), v if v == Self::MemWal as i32 => Ok(Self::MemWal), + v if v == Self::Geo as i32 => Ok(Self::Geo), v if v == Self::ZoneMap as i32 => Ok(Self::ZoneMap), v if v == Self::BloomFilter as i32 => Ok(Self::BloomFilter), v if v == Self::Vector as i32 => Ok(Self::Vector), @@ -214,6 +218,7 @@ impl IndexType { | Self::NGram | Self::ZoneMap | Self::BloomFilter + | Self::Geo ) } @@ -252,7 +257,7 @@ impl IndexType { Self::MemWal => 0, Self::ZoneMap => 0, Self::BloomFilter => 0, - + Self::Geo => 0, // for now all vector indices are built by the same builder, // so they share the same version. 
Self::Vector diff --git a/rust/lance-index/src/scalar.rs b/rust/lance-index/src/scalar.rs index 69b5ee35cf0..2d78847a3e1 100644 --- a/rust/lance-index/src/scalar.rs +++ b/rust/lance-index/src/scalar.rs @@ -33,6 +33,7 @@ pub mod bloomfilter; pub mod btree; pub mod expression; pub mod flat; +pub mod geoindex; pub mod inverted; pub mod json; pub mod label_list; @@ -61,6 +62,7 @@ pub enum BuiltinIndexType { ZoneMap, BloomFilter, Inverted, + Geo, } impl BuiltinIndexType { @@ -73,6 +75,7 @@ impl BuiltinIndexType { Self::ZoneMap => "zonemap", Self::Inverted => "inverted", Self::BloomFilter => "bloomfilter", + Self::Geo => "geo", } } } @@ -89,6 +92,7 @@ impl TryFrom for BuiltinIndexType { IndexType::ZoneMap => Ok(Self::ZoneMap), IndexType::Inverted => Ok(Self::Inverted), IndexType::BloomFilter => Ok(Self::BloomFilter), + IndexType::Geo => Ok(Self::Geo), _ => Err(Error::Index { message: "Invalid index type".to_string(), location: location!(), @@ -587,6 +591,45 @@ pub enum TokenQuery { TokensContains(String), } +/// A query that a GeoIndex can satisfy +#[derive(Debug, Clone, PartialEq)] +pub enum GeoQuery { + /// Retrieve all row ids where the geometry intersects with the given bounding box + /// Format: (min_x, min_y, max_x, max_y) + Intersects(f64, f64, f64, f64), +} + +impl AnyQuery for GeoQuery { + fn as_any(&self) -> &dyn Any { + self + } + + fn format(&self, col: &str) -> String { + match self { + Self::Intersects(min_x, min_y, max_x, max_y) => { + format!("st_intersects({}, bbox({}, {}, {}, {}))", col, min_x, min_y, max_x, max_y) + } + } + } + + fn to_expr(&self, _col: String) -> Expr { + match self { + Self::Intersects(_min_x, _min_y, _max_x, _max_y) => { + // For now, return a placeholder expression + // This would need to be a proper st_intersects UDF call + Expr::Literal(ScalarValue::Boolean(Some(true)), None) + } + } + } + + fn dyn_eq(&self, other: &dyn AnyQuery) -> bool { + match other.as_any().downcast_ref::() { + Some(o) => self == o, + None => false, + } + } 
+} + /// A query that a BloomFilter index can satisfy /// /// This is a subset of SargableQuery that only includes operations that bloom filters diff --git a/rust/lance-index/src/scalar/expression.rs b/rust/lance-index/src/scalar/expression.rs index 047dddeaa93..26c000b1ec8 100644 --- a/rust/lance-index/src/scalar/expression.rs +++ b/rust/lance-index/src/scalar/expression.rs @@ -18,8 +18,8 @@ use datafusion_expr::{ }; use super::{ - AnyQuery, BloomFilterQuery, LabelListQuery, MetricsCollector, SargableQuery, ScalarIndex, - SearchResult, TextQuery, TokenQuery, + AnyQuery, BloomFilterQuery, GeoQuery, LabelListQuery, MetricsCollector, SargableQuery, + ScalarIndex, SearchResult, TextQuery, TokenQuery, }; use futures::join; use lance_core::{utils::mask::RowIdMask, Error, Result}; @@ -665,6 +665,122 @@ impl ScalarQueryParser for FtsQueryParser { } } +/// A parser for geo indices that handles spatial queries +#[derive(Debug, Clone)] +pub struct GeoQueryParser { + index_name: String, +} + +impl GeoQueryParser { + pub fn new(index_name: String) -> Self { + Self { index_name } + } + + /// Extract bounding box coordinates from a bbox() function call + /// Expected format: bbox(min_x, min_y, max_x, max_y) + fn extract_bbox(&self, expr: &Expr) -> Option<(f64, f64, f64, f64)> { + match expr { + Expr::ScalarFunction(ScalarFunction { func, args }) => { + if func.name() == "bbox" && args.len() == 4 { + // Extract the four coordinates + let min_x = maybe_scalar(&args[0], &DataType::Float64)?; + let min_y = maybe_scalar(&args[1], &DataType::Float64)?; + let max_x = maybe_scalar(&args[2], &DataType::Float64)?; + let max_y = maybe_scalar(&args[3], &DataType::Float64)?; + + // Convert to f64 + let min_x = match min_x { + ScalarValue::Float64(Some(v)) => v, + ScalarValue::Float32(Some(v)) => v as f64, + ScalarValue::Int64(Some(v)) => v as f64, + ScalarValue::Int32(Some(v)) => v as f64, + _ => return None, + }; + let min_y = match min_y { + ScalarValue::Float64(Some(v)) => v, + 
ScalarValue::Float32(Some(v)) => v as f64, + ScalarValue::Int64(Some(v)) => v as f64, + ScalarValue::Int32(Some(v)) => v as f64, + _ => return None, + }; + let max_x = match max_x { + ScalarValue::Float64(Some(v)) => v, + ScalarValue::Float32(Some(v)) => v as f64, + ScalarValue::Int64(Some(v)) => v as f64, + ScalarValue::Int32(Some(v)) => v as f64, + _ => return None, + }; + let max_y = match max_y { + ScalarValue::Float64(Some(v)) => v, + ScalarValue::Float32(Some(v)) => v as f64, + ScalarValue::Int64(Some(v)) => v as f64, + ScalarValue::Int32(Some(v)) => v as f64, + _ => return None, + }; + + return Some((min_x, min_y, max_x, max_y)); + } + None + } + _ => None, + } + } +} + +impl ScalarQueryParser for GeoQueryParser { + fn visit_between( + &self, + _: &str, + _: &Bound, + _: &Bound, + ) -> Option { + None + } + + fn visit_in_list(&self, _: &str, _: &[ScalarValue]) -> Option { + None + } + + fn visit_is_bool(&self, _: &str, _: bool) -> Option { + None + } + + fn visit_is_null(&self, _: &str) -> Option { + None + } + + fn visit_comparison( + &self, + _: &str, + _: &ScalarValue, + _: &Operator, + ) -> Option { + None + } + + fn visit_scalar_function( + &self, + column: &str, + _data_type: &DataType, + func: &ScalarUDF, + args: &[Expr], + ) -> Option { + // Handle st_intersects(geometry_column, bbox(...)) + if func.name() == "st_intersects" && args.len() == 2 { + // The second argument should be a bbox() call + if let Some((min_x, min_y, max_x, max_y)) = self.extract_bbox(&args[1]) { + let query = GeoQuery::Intersects(min_x, min_y, max_x, max_y); + return Some(IndexedExpression::index_query( + column.to_string(), + self.index_name.clone(), + Arc::new(query), + )); + } + } + None + } +} + impl IndexedExpression { /// Create an expression that only does refine fn refine_only(refine_expr: Expr) -> Self { diff --git a/rust/lance-index/src/scalar/geoindex.rs b/rust/lance-index/src/scalar/geoindex.rs new file mode 100644 index 00000000000..f8b4d7411cf --- /dev/null +++ 
b/rust/lance-index/src/scalar/geoindex.rs @@ -0,0 +1,1001 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright The Lance Authors + +//! Geo Index +//! +//! Geo indices are spatial database structures for efficient spatial queries. +//! They enable efficient filtering by location-based predicates. +//! +//! ## Requirements +//! +//! Geo indices can only be created on fields with GeoArrow metadata. The field must: +//! - Be a Struct data type +//! - Have `ARROW:extension:name` metadata starting with `geoarrow.` (e.g., `geoarrow.point`, `geoarrow.polygon`) +//! +//! ## Query Support +//! +//! Geo indices are "inexact" filters - they can definitively exclude regions but may include +//! false positives that require rechecking. +//! + +use crate::pbold; +use crate::scalar::expression::{GeoQueryParser, ScalarQueryParser}; +use crate::scalar::registry::{ + ScalarIndexPlugin, TrainingCriteria, TrainingOrdering, TrainingRequest, +}; +use crate::scalar::{ + BuiltinIndexType, CreatedIndex, GeoQuery, ScalarIndexParams, UpdateCriteria, +}; +use crate::Any; +use futures::TryStreamExt; +use lance_core::cache::{CacheKey, LanceCache, WeakLanceCache}; +use lance_core::{ROW_ADDR, ROW_ID}; +use serde::{Deserialize, Serialize}; + +use arrow_array::{Array, ArrayRef, Float64Array, RecordBatch, UInt32Array, UInt64Array, UInt8Array}; +use arrow_array::cast::AsArray; +use arrow_schema::{DataType, Field, Schema}; +use datafusion::execution::SendableRecordBatchStream; +use std::{collections::HashMap, sync::Arc}; + +use super::{AnyQuery, IndexReader, IndexStore, MetricsCollector, ScalarIndex, SearchResult}; +use crate::scalar::FragReuseIndex; +use crate::vector::VectorIndex; +use crate::{Index, IndexType}; +use async_trait::async_trait; +use deepsize::DeepSizeOf; +use lance_core::Result; +use lance_core::{utils::mask::RowIdTreeMap, Error}; +use roaring::RoaringBitmap; +use snafu::location; + +const BKD_TREE_FILENAME: &str = "bkd_tree.lance"; +const 
BKD_LEAVES_FILENAME: &str = "bkd_leaves.lance"; +const GEO_INDEX_VERSION: u32 = 0; +const LEAF_SIZE_META_KEY: &str = "leaf_size"; +const DEFAULT_LEAF_SIZE: u32 = 4096; + +/// BKD Tree node representing either an inner node or a leaf +#[derive(Debug, Clone, DeepSizeOf)] +struct BKDNode { + /// Bounding box: [min_x, min_y, max_x, max_y] + bounds: [f64; 4], + /// Split dimension: 0=X, 1=Y + split_dim: u8, + /// Split value along split_dim + split_value: f64, + /// Left child node index (for inner nodes) + left_child: Option, + /// Right child node index (for inner nodes) + right_child: Option, + /// Leaf ID in bkd_leaves.lance (for leaf nodes) + leaf_id: Option, +} + +/// In-memory BKD tree structure for efficient spatial queries +#[derive(Debug, DeepSizeOf)] +struct BKDTreeLookup { + nodes: Vec, + root_id: u32, + num_leaves: u32, +} + +impl BKDTreeLookup { + fn new(nodes: Vec, root_id: u32, num_leaves: u32) -> Self { + Self { + nodes, + root_id, + num_leaves, + } + } + + /// Find all leaf IDs that intersect with the query bounding box + fn find_intersecting_leaves(&self, query_bbox: [f64; 4]) -> Vec { + let mut leaf_ids = Vec::new(); + let mut stack = vec![self.root_id]; + + while let Some(node_id) = stack.pop() { + if node_id as usize >= self.nodes.len() { + continue; + } + + let node = &self.nodes[node_id as usize]; + + // Check if node's bounding box intersects with query bbox + if !bboxes_intersect(&node.bounds, &query_bbox) { + continue; + } + + // If this is a leaf node, add its leaf_id + if let Some(leaf_id) = node.leaf_id { + leaf_ids.push(leaf_id); + } else { + // Inner node - traverse children + if let Some(left) = node.left_child { + stack.push(left); + } + if let Some(right) = node.right_child { + stack.push(right); + } + } + } + + leaf_ids + } + + /// Deserialize from RecordBatch + fn from_record_batch(batch: RecordBatch) -> Result { + if batch.num_rows() == 0 { + return Ok(Self::new(vec![], 0, 0)); + } + + let min_x = batch.column(0).as_primitive::(); + 
let min_y = batch.column(1).as_primitive::(); + let max_x = batch.column(2).as_primitive::(); + let max_y = batch.column(3).as_primitive::(); + let split_dim = batch.column(4).as_primitive::(); + let split_value = batch.column(5).as_primitive::(); + let left_child = batch.column(6).as_primitive::(); + let right_child = batch.column(7).as_primitive::(); + let leaf_id = batch.column(8).as_primitive::(); + + let mut nodes = Vec::with_capacity(batch.num_rows()); + let mut num_leaves = 0; + + for i in 0..batch.num_rows() { + let leaf_id_val = if leaf_id.is_null(i) { + None + } else { + num_leaves += 1; + Some(leaf_id.value(i)) + }; + + nodes.push(BKDNode { + bounds: [ + min_x.value(i), + min_y.value(i), + max_x.value(i), + max_y.value(i), + ], + split_dim: split_dim.value(i), + split_value: split_value.value(i), + left_child: if left_child.is_null(i) { + None + } else { + Some(left_child.value(i)) + }, + right_child: if right_child.is_null(i) { + None + } else { + Some(right_child.value(i)) + }, + leaf_id: leaf_id_val, + }); + } + + Ok(Self::new(nodes, 0, num_leaves)) + } +} + +/// Check if two bounding boxes intersect +fn bboxes_intersect(bbox1: &[f64; 4], bbox2: &[f64; 4]) -> bool { + // bbox format: [min_x, min_y, max_x, max_y] + !(bbox1[2] < bbox2[0] || bbox1[0] > bbox2[2] || bbox1[3] < bbox2[1] || bbox1[1] > bbox2[3]) +} + +/// Check if a point is within a bounding box +fn point_in_bbox(x: f64, y: f64, bbox: &[f64; 4]) -> bool { + x >= bbox[0] && x <= bbox[2] && y >= bbox[1] && y <= bbox[3] +} + +/// Lazy reader for BKD leaf file +#[derive(Clone)] +struct LazyIndexReader { + index_reader: Arc>>>, + store: Arc, + filename: String, +} + +impl LazyIndexReader { + fn new(store: Arc, filename: &str) -> Self { + Self { + index_reader: Arc::new(tokio::sync::Mutex::new(None)), + store, + filename: filename.to_string(), + } + } + + async fn get(&self) -> Result> { + let mut reader = self.index_reader.lock().await; + if reader.is_none() { + let r = 
self.store.open_index_file(&self.filename).await?; + *reader = Some(r); + } + Ok(reader.as_ref().unwrap().clone()) + } +} + +/// Cache key for BKD leaf nodes +#[derive(Debug, Clone)] +struct BKDLeafKey { + leaf_id: u32, +} + +impl CacheKey for BKDLeafKey { + type ValueType = CachedLeafData; + + fn key(&self) -> std::borrow::Cow<'_, str> { + format!("bkd-leaf-{}", self.leaf_id).into() + } +} + +/// Cached leaf data +#[derive(Debug, Clone)] +struct CachedLeafData(RecordBatch); + +impl DeepSizeOf for CachedLeafData { + fn deep_size_of_children(&self, _context: &mut deepsize::Context) -> usize { + // Approximate size of RecordBatch + self.0.get_array_memory_size() + } +} + +impl CachedLeafData { + fn new(batch: RecordBatch) -> Self { + Self(batch) + } + + fn into_inner(self) -> RecordBatch { + self.0 + } +} + +/// Geo index +pub struct GeoIndex { + data_type: DataType, + store: Arc, + fri: Option>, + index_cache: WeakLanceCache, + bkd_tree: Arc, + leaf_size: u32, +} + +impl std::fmt::Debug for GeoIndex { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("GeoIndex") + .field("data_type", &self.data_type) + .field("store", &self.store) + .field("fri", &self.fri) + .field("index_cache", &self.index_cache) + .field("bkd_tree", &self.bkd_tree) + .field("leaf_size", &self.leaf_size) + .finish() + } +} + +impl DeepSizeOf for GeoIndex { + fn deep_size_of_children(&self, context: &mut deepsize::Context) -> usize { + self.bkd_tree.deep_size_of_children(context) + self.store.deep_size_of_children(context) + } +} + +impl GeoIndex { + /// Load the geo index from storage + async fn load( + store: Arc, + fri: Option>, + index_cache: &LanceCache, + ) -> Result> + where + Self: Sized, + { + // Load BKD tree structure (inner nodes) + let tree_file = store.open_index_file(BKD_TREE_FILENAME).await?; + let tree_data = tree_file + .read_range(0..tree_file.num_rows(), None) + .await?; + + // Deserialize tree structure + let bkd_tree = 
BKDTreeLookup::from_record_batch(tree_data)?; + + // Extract metadata + let schema = tree_file.schema(); + let leaf_size = schema + .metadata + .get(LEAF_SIZE_META_KEY) + .and_then(|s| s.parse().ok()) + .unwrap_or(DEFAULT_LEAF_SIZE); + + // Get data type from schema + let data_type = schema.fields[0].data_type().clone(); + + Ok(Arc::new(Self { + data_type, + store, + fri, + index_cache: WeakLanceCache::from(index_cache), + bkd_tree: Arc::new(bkd_tree), + leaf_size, + })) + } + + /// Load a specific leaf from storage + async fn load_leaf( + &self, + leaf_id: u32, + index_reader: LazyIndexReader, + metrics: &dyn MetricsCollector, + ) -> Result { + // Check cache first + let cache_key = BKDLeafKey { leaf_id }; + + let cached = self + .index_cache + .get_or_insert_with_key(cache_key, move || async move { + metrics.record_part_load(); + let reader = index_reader.get().await?; + let batch = reader + .read_record_batch(leaf_id as u64, self.leaf_size as u64) + .await?; + Ok(CachedLeafData::new(batch)) + }) + .await?; + + Ok(cached.as_ref().clone().into_inner()) + } + + /// Search a specific leaf for points within the query bbox + async fn search_leaf( + &self, + leaf_id: u32, + query_bbox: [f64; 4], + index_reader: LazyIndexReader, + metrics: &dyn MetricsCollector, + ) -> Result { + let leaf_data = self.load_leaf(leaf_id, index_reader, metrics).await?; + + // Filter points within this leaf + let mut row_ids = RowIdTreeMap::new(); + let x_array = leaf_data + .column(0) + .as_primitive::(); + let y_array = leaf_data + .column(1) + .as_primitive::(); + let row_id_array = leaf_data + .column(2) + .as_primitive::(); + + for i in 0..leaf_data.num_rows() { + let x = x_array.value(i); + let y = y_array.value(i); + + if point_in_bbox(x, y, &query_bbox) { + let row_id = row_id_array.value(i); + row_ids.insert(row_id); + } + } + + Ok(row_ids) + } +} + +#[async_trait] +impl Index for GeoIndex { + fn as_any(&self) -> &dyn Any { + self + } + + fn as_index(self: Arc) -> Arc { + self + } 
+ + fn as_vector_index(self: Arc) -> Result> { + Err(Error::InvalidInput { + source: "GeoIndex is not a vector index".into(), + location: location!(), + }) + } + + async fn prewarm(&self) -> Result<()> { + Ok(()) + } + + fn statistics(&self) -> Result { + Ok(serde_json::json!({ + "type": "geo", + })) + } + + fn index_type(&self) -> IndexType { + IndexType::Geo + } + + async fn calculate_included_frags(&self) -> Result { + let frag_ids = RoaringBitmap::new(); + Ok(frag_ids) + } +} + +#[async_trait] +impl ScalarIndex for GeoIndex { + async fn search( + &self, + query: &dyn AnyQuery, + metrics: &dyn MetricsCollector, + ) -> Result { + let geo_query = query.as_any().downcast_ref::() + .ok_or_else(|| Error::InvalidInput { + source: "Geo index only supports GeoQuery".into(), + location: location!(), + })?; + + match geo_query { + GeoQuery::Intersects(min_x, min_y, max_x, max_y) => { + let query_bbox = [*min_x, *min_y, *max_x, *max_y]; + + log::debug!( + "Geo index search: st_intersects with bbox({}, {}, {}, {})", + min_x, min_y, max_x, max_y + ); + + // Step 1: Find intersecting leaves using in-memory tree traversal + let leaf_ids = self.bkd_tree.find_intersecting_leaves(query_bbox); + + log::debug!( + "BKD tree found {} intersecting leaves", + leaf_ids.len() + ); + + // Step 2: Lazy-load and filter each leaf + let mut all_row_ids = RowIdTreeMap::new(); + let lazy_reader = LazyIndexReader::new(self.store.clone(), BKD_LEAVES_FILENAME); + + for leaf_id in leaf_ids { + let leaf_row_ids = self + .search_leaf(leaf_id, query_bbox, lazy_reader.clone(), metrics) + .await?; + // Collect row IDs from the leaf and add them to the result set + let row_ids: Option> = leaf_row_ids.row_ids() + .map(|iter| iter.map(|row_addr| u64::from(row_addr)).collect()); + if let Some(row_ids) = row_ids { + all_row_ids.extend(row_ids); + } + } + + log::debug!( + "Geo index returning {:?} row IDs", + all_row_ids.len() + ); + + // We return Exact because we already filtered points in search_leaf + 
Ok(SearchResult::Exact(all_row_ids)) + } + } + } + + fn can_remap(&self) -> bool { + false + } + + /// Remap the row ids, creating a new remapped version of this index in `dest_store` + async fn remap( + &self, + _mapping: &HashMap>, + _dest_store: &dyn IndexStore, + ) -> Result { + Err(Error::InvalidInput { + source: "GeoIndex does not support remap".into(), + location: location!(), + }) + } + + /// Add the new data , creating an updated version of the index in `dest_store` + async fn update( + &self, + _new_data: SendableRecordBatchStream, + _dest_store: &dyn IndexStore, + ) -> Result { + Err(Error::InvalidInput { + source: "GeoIndex does not support update".into(), + location: location!(), + }) + } + + fn update_criteria(&self) -> UpdateCriteria { + UpdateCriteria::only_new_data( + TrainingCriteria::new(TrainingOrdering::Addresses).with_row_addr(), + ) + } + + fn derive_index_params(&self) -> Result { + let params = serde_json::to_value(GeoIndexBuilderParams::default())?; + Ok(ScalarIndexParams::for_builtin(BuiltinIndexType::Geo).with_params(¶ms)) + } +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct GeoIndexBuilderParams { + #[serde(default = "default_leaf_size")] + pub leaf_size: u32, +} + +fn default_leaf_size() -> u32 { + DEFAULT_LEAF_SIZE +} + +impl Default for GeoIndexBuilderParams { + fn default() -> Self { + Self { + leaf_size: default_leaf_size(), + } + } +} + +impl GeoIndexBuilderParams { + pub fn new() -> Self { + Self::default() + } + + pub fn with_leaf_size(mut self, leaf_size: u32) -> Self { + self.leaf_size = leaf_size; + self + } +} + +// A builder for geo index +pub struct GeoIndexBuilder { + options: GeoIndexBuilderParams, + items_type: DataType, + // Accumulated points: (x, y, row_id) + points: Vec<(f64, f64, u64)>, +} + +impl GeoIndexBuilder { + pub fn try_new(options: GeoIndexBuilderParams, items_type: DataType) -> Result { + Ok(Self { + options, + items_type, + points: Vec::new(), + }) + } + + pub async fn train(&mut self, 
batches_source: SendableRecordBatchStream) -> Result<()> { + assert!(batches_source.schema().field_with_name(ROW_ADDR).is_ok()); + + let mut batches_source = batches_source; + + while let Some(batch) = batches_source.try_next().await? { + // Extract GeoArrow point coordinates + let geom_array = batch.column(0).as_any().downcast_ref::() + .ok_or_else(|| Error::InvalidInput { + source: "Expected Struct array for GeoArrow data".into(), + location: location!(), + })?; + + let x_array = geom_array + .column(0) + .as_primitive::(); + let y_array = geom_array + .column(1) + .as_primitive::(); + let row_ids = batch + .column_by_name(ROW_ADDR) + .unwrap() + .as_primitive::(); + + for i in 0..batch.num_rows() { + self.points.push(( + x_array.value(i), + y_array.value(i), + row_ids.value(i), + )); + } + } + + log::debug!("Accumulated {} points for BKD tree", self.points.len()); + + Ok(()) + } + + pub async fn write_index(mut self, index_store: &dyn IndexStore) -> Result<()> { + if self.points.is_empty() { + // Write empty index files + self.write_empty_index(index_store).await?; + return Ok(()); + } + + // Build BKD tree + let (tree_nodes, leaf_batches) = self.build_bkd_tree()?; + + // Write tree structure + let tree_batch = self.serialize_tree_nodes(&tree_nodes)?; + let mut tree_file = index_store + .new_index_file(BKD_TREE_FILENAME, tree_batch.schema()) + .await?; + tree_file.write_record_batch(tree_batch).await?; + tree_file + .finish_with_metadata(HashMap::from([( + LEAF_SIZE_META_KEY.to_string(), + self.options.leaf_size.to_string(), + )])) + .await?; + + // Write leaf data + let leaf_schema = Arc::new(Schema::new(vec![ + Field::new("x", DataType::Float64, false), + Field::new("y", DataType::Float64, false), + Field::new(ROW_ID, DataType::UInt64, false), + ])); + let mut leaf_file = index_store + .new_index_file(BKD_LEAVES_FILENAME, leaf_schema) + .await?; + for leaf_batch in leaf_batches { + leaf_file.write_record_batch(leaf_batch).await?; + } + 
leaf_file.finish().await?; + + log::debug!( + "Wrote BKD tree with {} nodes", + tree_nodes.len() + ); + + Ok(()) + } + + async fn write_empty_index(&self, index_store: &dyn IndexStore) -> Result<()> { + // Write empty tree file + let tree_schema = Arc::new(Schema::new(vec![ + Field::new("min_x", DataType::Float64, false), + Field::new("min_y", DataType::Float64, false), + Field::new("max_x", DataType::Float64, false), + Field::new("max_y", DataType::Float64, false), + Field::new("split_dim", DataType::UInt8, false), + Field::new("split_value", DataType::Float64, false), + Field::new("left_child", DataType::UInt32, true), + Field::new("right_child", DataType::UInt32, true), + Field::new("leaf_id", DataType::UInt32, true), + ])); + + let empty_batch = RecordBatch::new_empty(tree_schema); + let mut tree_file = index_store + .new_index_file(BKD_TREE_FILENAME, empty_batch.schema()) + .await?; + tree_file.write_record_batch(empty_batch).await?; + tree_file.finish().await?; + + // Write empty leaves file + let leaf_schema = Arc::new(Schema::new(vec![ + Field::new("x", DataType::Float64, false), + Field::new("y", DataType::Float64, false), + Field::new(ROW_ID, DataType::UInt64, false), + ])); + let empty_leaves = RecordBatch::new_empty(leaf_schema); + let mut leaf_file = index_store + .new_index_file(BKD_LEAVES_FILENAME, empty_leaves.schema()) + .await?; + leaf_file.write_record_batch(empty_leaves).await?; + leaf_file.finish().await?; + + Ok(()) + } + + fn serialize_tree_nodes(&self, nodes: &[BKDNode]) -> Result { + let mut min_x_vals = Vec::with_capacity(nodes.len()); + let mut min_y_vals = Vec::with_capacity(nodes.len()); + let mut max_x_vals = Vec::with_capacity(nodes.len()); + let mut max_y_vals = Vec::with_capacity(nodes.len()); + let mut split_dim_vals = Vec::with_capacity(nodes.len()); + let mut split_value_vals = Vec::with_capacity(nodes.len()); + let mut left_child_vals = Vec::with_capacity(nodes.len()); + let mut right_child_vals = 
Vec::with_capacity(nodes.len()); + let mut leaf_id_vals = Vec::with_capacity(nodes.len()); + + for node in nodes { + min_x_vals.push(node.bounds[0]); + min_y_vals.push(node.bounds[1]); + max_x_vals.push(node.bounds[2]); + max_y_vals.push(node.bounds[3]); + split_dim_vals.push(node.split_dim); + split_value_vals.push(node.split_value); + left_child_vals.push(node.left_child); + right_child_vals.push(node.right_child); + leaf_id_vals.push(node.leaf_id); + } + + let schema = Arc::new(Schema::new(vec![ + Field::new("min_x", DataType::Float64, false), + Field::new("min_y", DataType::Float64, false), + Field::new("max_x", DataType::Float64, false), + Field::new("max_y", DataType::Float64, false), + Field::new("split_dim", DataType::UInt8, false), + Field::new("split_value", DataType::Float64, false), + Field::new("left_child", DataType::UInt32, true), + Field::new("right_child", DataType::UInt32, true), + Field::new("leaf_id", DataType::UInt32, true), + ])); + + let columns: Vec = vec![ + Arc::new(Float64Array::from(min_x_vals)), + Arc::new(Float64Array::from(min_y_vals)), + Arc::new(Float64Array::from(max_x_vals)), + Arc::new(Float64Array::from(max_y_vals)), + Arc::new(UInt8Array::from(split_dim_vals)), + Arc::new(Float64Array::from(split_value_vals)), + Arc::new(UInt32Array::from(left_child_vals)), + Arc::new(UInt32Array::from(right_child_vals)), + Arc::new(UInt32Array::from(leaf_id_vals)), + ]; + + Ok(RecordBatch::try_new(schema, columns)?) 
+ } + + // Build BKD tree using Lucene's algorithm (deferred for next step) + fn build_bkd_tree(&mut self) -> Result<(Vec, Vec)> { + // For now, implement a simple single-leaf approach as placeholder + // This will be replaced with the full BKD tree building algorithm + log::warn!("Using simplified BKD tree builder (full implementation pending)"); + + let num_points = self.points.len(); + let leaf_size = self.options.leaf_size as usize; + + // Calculate bounding box for all points + let mut min_x = f64::MAX; + let mut min_y = f64::MAX; + let mut max_x = f64::MIN; + let mut max_y = f64::MIN; + + for (x, y, _) in &self.points { + min_x = min_x.min(*x); + min_y = min_y.min(*y); + max_x = max_x.max(*x); + max_y = max_y.max(*y); + } + + // Sort points by X coordinate for better spatial locality + self.points.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap()); + + // Split points into leaf batches + let mut leaf_batches = Vec::new(); + let mut nodes = Vec::new(); + + for (leaf_id, chunk) in self.points.chunks(leaf_size).enumerate() { + // Create leaf batch + let x_vals: Vec = chunk.iter().map(|(x, _, _)| *x).collect(); + let y_vals: Vec = chunk.iter().map(|(_, y, _)| *y).collect(); + let row_ids: Vec = chunk.iter().map(|(_, _, r)| *r).collect(); + + let schema = Arc::new(Schema::new(vec![ + Field::new("x", DataType::Float64, false), + Field::new("y", DataType::Float64, false), + Field::new(ROW_ID, DataType::UInt64, false), + ])); + + let batch = RecordBatch::try_new( + schema, + vec![ + Arc::new(Float64Array::from(x_vals)) as ArrayRef, + Arc::new(Float64Array::from(y_vals)) as ArrayRef, + Arc::new(UInt64Array::from(row_ids)) as ArrayRef, + ], + )?; + + leaf_batches.push(batch); + + // Create leaf node (simplified - one node per leaf) + nodes.push(BKDNode { + bounds: [min_x, min_y, max_x, max_y], + split_dim: 0, + split_value: 0.0, + left_child: None, + right_child: None, + leaf_id: Some(leaf_id as u32), + }); + } + + // If we have multiple leaves, create a simple root node + 
// (This is a placeholder - full tree construction will be more sophisticated) + if nodes.len() > 1 { + // For now, just use the first leaf node as root + // Full implementation will build proper hierarchical tree + } + + log::debug!( + "Built simplified BKD tree: {} points -> {} leaves", + num_points, + leaf_batches.len() + ); + + Ok((nodes, leaf_batches)) + } +} + +#[derive(Debug, Default)] +pub struct GeoIndexPlugin; + +impl GeoIndexPlugin { + async fn train_geo_index( + batches_source: SendableRecordBatchStream, + index_store: &dyn IndexStore, + options: Option, + ) -> Result<()> { + let value_type = batches_source.schema().field(0).data_type().clone(); + + let mut builder = GeoIndexBuilder::try_new(options.unwrap_or_default(), value_type)?; + + builder.train(batches_source).await?; + + builder.write_index(index_store).await?; + Ok(()) + } +} + +pub struct GeoIndexTrainingRequest { + pub params: GeoIndexBuilderParams, + pub criteria: TrainingCriteria, +} + +impl GeoIndexTrainingRequest { + pub fn new(params: GeoIndexBuilderParams) -> Self { + Self { + params, + criteria: TrainingCriteria::new(TrainingOrdering::Addresses).with_row_addr(), + } + } +} + +impl TrainingRequest for GeoIndexTrainingRequest { + fn as_any(&self) -> &dyn std::any::Any { + self + } + fn criteria(&self) -> &TrainingCriteria { + &self.criteria + } +} + +#[async_trait] +impl ScalarIndexPlugin for GeoIndexPlugin { + fn new_training_request( + &self, + params: &str, + field: &Field, + ) -> Result> { + // Check that the field is a Struct type + if !matches!(field.data_type(), DataType::Struct(_)) { + return Err(Error::InvalidInput { + source: "A geo index can only be created on a Struct field.".into(), + location: location!(), + }); + } + + // Check for GeoArrow metadata + let is_geoarrow = field + .metadata() + .get("ARROW:extension:name") + .map(|name| name.starts_with("geoarrow.")) + .unwrap_or(false); + + if !is_geoarrow { + return Err(Error::InvalidInput { + source: format!( + "Geo index 
requires GeoArrow metadata on field '{}'. \ + The field must have 'ARROW:extension:name' metadata starting with 'geoarrow.'", + field.name() + ) + .into(), + location: location!(), + }); + } + + let params = serde_json::from_str::(params)?; + + Ok(Box::new(GeoIndexTrainingRequest::new(params))) + } + + fn provides_exact_answer(&self) -> bool { + true // We do exact point-in-bbox filtering in search_leaf + } + + fn version(&self) -> u32 { + GEO_INDEX_VERSION + } + + fn new_query_parser( + &self, + index_name: String, + _index_details: &prost_types::Any, + ) -> Option> { + Some(Box::new(GeoQueryParser::new(index_name))) + } + + async fn train_index( + &self, + data: SendableRecordBatchStream, + index_store: &dyn IndexStore, + request: Box, + fragment_ids: Option>, + ) -> Result { + if fragment_ids.is_some() { + return Err(Error::InvalidInput { + source: "Geo index does not support fragment training".into(), + location: location!(), + }); + } + + let request = (request as Box) + .downcast::() + .map_err(|_| Error::InvalidInput { + source: "must provide training request created by new_training_request".into(), + location: location!(), + })?; + Self::train_geo_index(data, index_store, Some(request.params)).await?; + Ok(CreatedIndex { + index_details: prost_types::Any::from_msg(&pbold::GeoIndexDetails::default()) + .unwrap(), + index_version: GEO_INDEX_VERSION, + }) + } + + async fn load_index( + &self, + index_store: Arc, + _index_details: &prost_types::Any, + frag_reuse_index: Option>, + cache: &LanceCache, + ) -> Result> { + Ok(GeoIndex::load(index_store, frag_reuse_index, cache).await? 
as Arc) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use std::sync::Arc; + + use arrow_array::RecordBatch; + use arrow_schema::{DataType, Field, Fields, Schema}; + use datafusion::execution::SendableRecordBatchStream; + use datafusion::physical_plan::stream::RecordBatchStreamAdapter; + use futures::stream; + use lance_core::cache::LanceCache; + use lance_core::utils::tempfile::TempObjDir; + use lance_core::ROW_ADDR; + use lance_io::object_store::ObjectStore; + + use crate::scalar::lance_format::LanceIndexStore; + + #[tokio::test] + async fn test_empty_geo_index() { + let tmpdir = TempObjDir::default(); + let test_store = Arc::new(LanceIndexStore::new( + Arc::new(ObjectStore::local()), + tmpdir.clone(), + Arc::new(LanceCache::no_cache()), + )); + + let data = arrow_array::StructArray::from(vec![]); + let row_ids = arrow_array::UInt64Array::from(Vec::::new()); + let fields: Fields = Vec::::new().into(); + let schema = Arc::new(Schema::new(vec![ + Field::new("value", DataType::Struct(fields), false), + Field::new(ROW_ADDR, DataType::UInt64, false), + ])); + let data = + RecordBatch::try_new(schema.clone(), vec![Arc::new(data), Arc::new(row_ids)]).unwrap(); + + let data_stream: SendableRecordBatchStream = Box::pin(RecordBatchStreamAdapter::new( + schema, + stream::once(std::future::ready(Ok(data))), + )); + + GeoIndexPlugin::train_geo_index(data_stream, test_store.as_ref(), None) + .await + .unwrap(); + + // Read the index file back and check its contents + let _index = GeoIndex::load(test_store.clone(), None, &LanceCache::no_cache()) + .await + .expect("Failed to load GeoIndex"); + } +} + diff --git a/rust/lance-index/src/scalar/registry.rs b/rust/lance-index/src/scalar/registry.rs index a36e221f6a0..612eaf471dd 100644 --- a/rust/lance-index/src/scalar/registry.rs +++ b/rust/lance-index/src/scalar/registry.rs @@ -15,9 +15,9 @@ use crate::{ frag_reuse::FragReuseIndex, scalar::{ bitmap::BitmapIndexPlugin, bloomfilter::BloomFilterIndexPlugin, 
btree::BTreeIndexPlugin, - expression::ScalarQueryParser, inverted::InvertedIndexPlugin, json::JsonIndexPlugin, - label_list::LabelListIndexPlugin, ngram::NGramIndexPlugin, zonemap::ZoneMapIndexPlugin, - CreatedIndex, IndexStore, ScalarIndex, + expression::ScalarQueryParser, geoindex::GeoIndexPlugin, inverted::InvertedIndexPlugin, + json::JsonIndexPlugin, label_list::LabelListIndexPlugin, ngram::NGramIndexPlugin, + zonemap::ZoneMapIndexPlugin, CreatedIndex, IndexStore, ScalarIndex, }, }; @@ -201,6 +201,7 @@ impl ScalarIndexPluginRegistry { registry.add_plugin::(); registry.add_plugin::(); registry.add_plugin::(); + registry.add_plugin::(); registry.add_plugin::(); registry.add_plugin::(); registry.add_plugin::(); diff --git a/rust/lance/src/dataset/sql.rs b/rust/lance/src/dataset/sql.rs index 4c58375619e..898042f514c 100644 --- a/rust/lance/src/dataset/sql.rs +++ b/rust/lance/src/dataset/sql.rs @@ -65,6 +65,8 @@ impl SqlQueryBuilder { pub async fn build(self) -> lance_core::Result { let ctx = SessionContext::new(); + // Register Lance UDFs + lance_datafusion::udf::register_functions(&ctx); let row_id = self.with_row_id; let row_addr = self.with_row_addr; ctx.register_table( diff --git a/rust/lance/src/index/create.rs b/rust/lance/src/index/create.rs index 76f1fba3d34..c9be843dd5e 100644 --- a/rust/lance/src/index/create.rs +++ b/rust/lance/src/index/create.rs @@ -151,7 +151,8 @@ impl<'a> CreateIndexBuilder<'a> { | IndexType::NGram | IndexType::ZoneMap | IndexType::BloomFilter - | IndexType::LabelList, + | IndexType::LabelList + | IndexType::Geo, LANCE_SCALAR_INDEX, ) => { let base_params = ScalarIndexParams::for_builtin(self.index_type.try_into()?); diff --git a/test_geoarrow_geo_index.py b/test_geoarrow_geo_index.py new file mode 100644 index 00000000000..253c2755b01 --- /dev/null +++ b/test_geoarrow_geo_index.py @@ -0,0 +1,253 @@ +#!/usr/bin/env python3 +""" +Test script for GeoArrow Point geo index functionality in Lance. + +This script tests: +1. 
Creating GeoArrow Point data +2. Writing to Lance dataset +3. Creating a geo index on GeoArrow Point column +4. Querying with spatial filters +5. Verifying the geo index is used +""" + + +import numpy as np +import pyarrow as pa +import lance +import os +import shutil +from geoarrow.pyarrow import point + + +def main(): + print("šŸŒ Testing GeoArrow Point Geo Index in Lance") + print("=" * 50) + + + # Clean slate + dataset_path = "/Users/jay.narale/work/Uber/geo_index_test" + if os.path.exists(dataset_path): + shutil.rmtree(dataset_path) + print(f"āœ… Cleaned up existing dataset: {dataset_path}") + + + # Step 1: Create GeoArrow Point data + print("\nšŸ”µ Step 1: Creating GeoArrow Point data") + lat_np = np.array([37.7749, 34.0522, 40.7128], dtype="float64") # SF, LA, NYC + lng_np = np.array([-122.4194, -118.2437, -74.0060], dtype="float64") + + + start_location = point().from_geobuffers(None, lng_np, lat_np) + + + table = pa.table({ + "id": [1, 2, 3], + "city": ["San Francisco", "Los Angeles", "New York"], + "start_location": start_location, + "population": [883305, 3898747, 8336817] + }) + + + print("āœ… Created GeoArrow Point data") + print("šŸ“Š Table schema:") + print(table.schema) + print(f"šŸ“ Point column type: {table.schema.field('start_location').type}") + print(f"šŸ“ Point column metadata: {table.schema.field('start_location').metadata}") + + + # Step 2: Write to Lance dataset + print("\nšŸ”µ Step 2: Writing to Lance dataset") + try: + geo_ds = lance.write_dataset(table, dataset_path) + print("āœ… Successfully wrote GeoArrow data to Lance dataset") + + + # Verify data was written correctly + loaded_table = geo_ds.to_table() + print(f"šŸ“Š Dataset has {len(loaded_table)} rows") + print("šŸ“Š Dataset schema:") + print(loaded_table.schema) + + + except Exception as e: + print(f"āŒ Failed to write dataset: {e}") + return + + + # Step 3: Create geo index + print("\nšŸ”µ Step 3: Creating geo index on GeoArrow Point column") + try: + 
geo_ds.create_scalar_index(column="start_location", index_type="GEO") + print("āœ… Successfully created geo index") + + + # Check what indexes exist + indexes = geo_ds.list_indices() + print("šŸ“Š Available indexes:") + for idx in indexes: + print(f" - {idx}") + + + except Exception as e: + print(f"āŒ Failed to create geo index: {e}") + return + + + # Step 4: Test st_intersects spatial query with broad bbox (both cities) + print("\nšŸ”µ Step 4: Testing st_intersects spatial query with broad bbox (both cities)") + + + + + # First, run EXPLAIN ANALYZE to see the execution plan + explain_sql = """ + EXPLAIN ANALYZE SELECT id, city, population + FROM dataset + WHERE st_intersects(start_location, bbox(-125, 30, -115, 45)) + """ + + + print("\nšŸ“‹ Running EXPLAIN ANALYZE...") + explain_query = geo_ds.sql(explain_sql).build() + explain_result = explain_query.to_batch_records() + + + if explain_result: + explain_table = pa.Table.from_batches(explain_result) + print("šŸ” EXPLAIN ANALYZE Result:") + print(f"Schema: {explain_table.schema}") + print(f"Rows: {len(explain_table)}") + + + # Print the execution plan + for i in range(len(explain_table)): + for j, column in enumerate(explain_table.columns): + col_name = explain_table.schema.field(j).name + value = column.to_pylist()[i] + print(f"šŸ“Š {col_name}: {value}") + + + # Check if geo index was used + if len(explain_table) > 0: + # Column 1 contains the actual plan, column 0 is just the plan type + plan_text = str(explain_table.column(1).to_pylist()[0]) + if "ScalarIndexQuery" in plan_text or "start_location_idx" in plan_text: + print("āœ… šŸŒ GEO INDEX WAS USED!") + if "start_location_idx" in plan_text: + print("āœ… šŸŒ Found geo index reference: start_location_idx") + if "ST_Intersects" in plan_text: + print("āœ… šŸŒ Spatial query detected: ST_Intersects") + # Extract performance metrics + import re + if "output_rows=" in plan_text: + rows_match = re.search(r'output_rows=(\d+)', plan_text) + if rows_match: + 
print(f"āœ… šŸŒ Index returned {rows_match.group(1)} rows") + if "search_time=" in plan_text: + time_match = re.search(r'search_time=([^,\]]+)', plan_text) + if time_match: + print(f"āœ… šŸŒ Index search time: {time_match.group(1)}") + else: + print("āš ļø Geo index was not detected in execution plan") + print(f"šŸ“‹ Full plan: {plan_text}") + + + # Now run the actual query and get complete results + print("\nšŸ“‹ Running actual query...") + actual_sql = """ + SELECT id, city, population + FROM dataset + WHERE st_intersects(start_location, bbox(-125, 30, -115, 45)) + """ + query = geo_ds.sql(actual_sql).build() + result = query.to_batch_records() + + + if result: + table = pa.Table.from_batches(result) + print("āœ… Query Results:") + print(f"šŸ“Š Schema: {table.schema}") + print(f"šŸ“Š Number of rows: {len(table)}") + + + # Print complete results + for i in range(len(table)): + row_data = {} + for j, column in enumerate(table.columns): + col_name = table.schema.field(j).name + value = column.to_pylist()[i] + row_data[col_name] = value + print(f"šŸ“ Row {i}: {row_data}") + + + cities = table.column('city').to_pylist() + print(f"\nāœ… Found {len(cities)} cities with broad bbox: {cities}") + assert len(cities) == 2, f"Expected 2 cities, got {len(cities)}" + assert 'San Francisco' in cities, "Expected San Francisco in results" + assert 'Los Angeles' in cities, "Expected Los Angeles in results" + else: + print("āš ļø No results returned") + + + # Step 4b: Test with tight bbox (only San Francisco) + print("\nšŸ”µ Step 4b: Testing st_intersects with tight bbox (only San Francisco)") + # SF is at (-122.4194, 37.7749), so use a tight box around it + tight_sql = """ + SELECT id, city, population + FROM dataset + WHERE st_intersects(start_location, bbox(-123, 37, -122, 38)) + """ + tight_query = geo_ds.sql(tight_sql).build() + tight_result = tight_query.to_batch_records() + + if tight_result: + tight_table = pa.Table.from_batches(tight_result) + print("āœ… Query 
Results:") + print(f"šŸ“Š Number of rows: {len(tight_table)}") + + for i in range(len(tight_table)): + row_data = {} + for j, column in enumerate(tight_table.columns): + col_name = tight_table.schema.field(j).name + value = column.to_pylist()[i] + row_data[col_name] = value + print(f"šŸ“ Row {i}: {row_data}") + + cities = tight_table.column('city').to_pylist() + print(f"\nāœ… Found {len(cities)} city with tight bbox: {cities}") + assert len(cities) == 1, f"Expected 1 city, got {len(cities)}" + assert cities[0] == 'San Francisco', f"Expected San Francisco, got {cities[0]}" + else: + print("āš ļø No results returned") + + + + + + + # Step 5: Check index files + print("\nšŸ”µ Step 5: Verifying index files") + try: + import glob + index_files = glob.glob(f"{dataset_path}/_indices/*") + print(f"šŸ“‚ Index directories: {len(index_files)}") + + + for idx_dir in index_files: + files = glob.glob(f"{idx_dir}/*") + print(f"šŸ“‚ Files in {idx_dir}:") + for f in files: + file_size = os.path.getsize(f) + print(f" - {os.path.basename(f)} ({file_size} bytes)") + + + except Exception as e: + print(f"āŒ Failed to check index files: {e}") + + + print("\nšŸŽ‰ Test completed!") + print("=" * 50) + + +if __name__ == "__main__": + main() \ No newline at end of file From 55abfb20ad8b7e234ac36e933926a7e5bc7a6a00 Mon Sep 17 00:00:00 2001 From: jaystarshot Date: Wed, 15 Oct 2025 16:02:56 -0700 Subject: [PATCH 2/7] Cursor generated bkd tree poc --- rust/lance-index/src/scalar.rs | 1 + rust/lance-index/src/scalar/bkd.rs | 373 ++++++++++++++++++++++++ rust/lance-index/src/scalar/geoindex.rs | 325 +++++---------------- test_geoarrow_geo_index.py | 90 ++++-- 4 files changed, 517 insertions(+), 272 deletions(-) create mode 100644 rust/lance-index/src/scalar/bkd.rs diff --git a/rust/lance-index/src/scalar.rs b/rust/lance-index/src/scalar.rs index 2d78847a3e1..a9ae40cabf0 100644 --- a/rust/lance-index/src/scalar.rs +++ b/rust/lance-index/src/scalar.rs @@ -28,6 +28,7 @@ use 
crate::metrics::MetricsCollector; use crate::scalar::registry::TrainingCriteria; use crate::{Index, IndexParams, IndexType}; +pub mod bkd; pub mod bitmap; pub mod bloomfilter; pub mod btree; diff --git a/rust/lance-index/src/scalar/bkd.rs b/rust/lance-index/src/scalar/bkd.rs new file mode 100644 index 00000000000..c99422288cd --- /dev/null +++ b/rust/lance-index/src/scalar/bkd.rs @@ -0,0 +1,373 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright The Lance Authors + +//! BKD Tree (Block K-Dimensional Tree) implementation +//! +//! A BKD tree is a spatial index structure for efficiently indexing and querying +//! multi-dimensional points. It's similar to a KD-tree but optimized for disk storage +//! by grouping multiple points into leaf blocks. +//! +//! ## Algorithm +//! +//! Based on Lucene's BKD tree implementation: +//! - Recursively splits points by alternating dimensions (X, Y) +//! - Splits at median to create balanced tree +//! - Groups points into leaves of configurable size +//! 
- Stores tree structure separately from leaf data for lazy loading + +use arrow_array::{Array, ArrayRef, Float64Array, RecordBatch, UInt64Array}; +use arrow_array::cast::AsArray; +use arrow_schema::{DataType, Field, Schema}; +use deepsize::DeepSizeOf; +use lance_core::{Result, ROW_ID}; +use std::sync::Arc; + +/// BKD Tree node representing either an inner node or a leaf +#[derive(Debug, Clone, DeepSizeOf)] +pub struct BKDNode { + /// Bounding box: [min_x, min_y, max_x, max_y] + pub bounds: [f64; 4], + /// Split dimension: 0=X, 1=Y + pub split_dim: u8, + /// Split value along split_dim + pub split_value: f64, + /// Left child node index (for inner nodes) + pub left_child: Option, + /// Right child node index (for inner nodes) + pub right_child: Option, + /// Leaf ID - corresponds to leaf_{id}.lance file (for leaf nodes) + pub leaf_id: Option, +} + +/// In-memory BKD tree structure for efficient spatial queries +#[derive(Debug, DeepSizeOf)] +pub struct BKDTreeLookup { + pub nodes: Vec, + pub root_id: u32, + pub num_leaves: u32, +} + +impl BKDTreeLookup { + pub fn new(nodes: Vec, root_id: u32, num_leaves: u32) -> Self { + Self { + nodes, + root_id, + num_leaves, + } + } + + /// Find all leaf IDs that intersect with the query bounding box + pub fn find_intersecting_leaves(&self, query_bbox: [f64; 4]) -> Vec { + let mut leaf_ids = Vec::new(); + let mut stack = vec![self.root_id]; + let mut nodes_visited = 0; + + while let Some(node_id) = stack.pop() { + if node_id as usize >= self.nodes.len() { + continue; + } + + let node = &self.nodes[node_id as usize]; + nodes_visited += 1; + + // Check if node's bounding box intersects with query bbox + if !bboxes_intersect(&node.bounds, &query_bbox) { + continue; + } + + // If this is a leaf node, add its leaf_id + if let Some(leaf_id) = node.leaf_id { + println!( + " šŸƒ Found intersecting leaf_{}.lance: bounds {:?}", + leaf_id, + node.bounds + ); + leaf_ids.push(leaf_id); + } else { + // Inner node - traverse children + if let 
Some(left) = node.left_child { + stack.push(left); + } + if let Some(right) = node.right_child { + stack.push(right); + } + } + } + + println!( + "🌲 Tree traversal: visited {} nodes, found {} intersecting leaves", + nodes_visited, + leaf_ids.len() + ); + + leaf_ids + } + + /// Deserialize from RecordBatch + pub fn from_record_batch(batch: RecordBatch) -> Result { + if batch.num_rows() == 0 { + return Ok(Self::new(vec![], 0, 0)); + } + + let min_x = batch + .column(0) + .as_primitive::(); + let min_y = batch + .column(1) + .as_primitive::(); + let max_x = batch + .column(2) + .as_primitive::(); + let max_y = batch + .column(3) + .as_primitive::(); + let split_dim = batch + .column(4) + .as_primitive::(); + let split_value = batch + .column(5) + .as_primitive::(); + let left_child = batch + .column(6) + .as_primitive::(); + let right_child = batch + .column(7) + .as_primitive::(); + let leaf_id = batch + .column(8) + .as_primitive::(); + + let mut nodes = Vec::with_capacity(batch.num_rows()); + let mut num_leaves = 0; + + for i in 0..batch.num_rows() { + let leaf_id_val = if leaf_id.is_null(i) { + None + } else { + num_leaves += 1; + Some(leaf_id.value(i)) + }; + + nodes.push(BKDNode { + bounds: [ + min_x.value(i), + min_y.value(i), + max_x.value(i), + max_y.value(i), + ], + split_dim: split_dim.value(i), + split_value: split_value.value(i), + left_child: if left_child.is_null(i) { + None + } else { + Some(left_child.value(i)) + }, + right_child: if right_child.is_null(i) { + None + } else { + Some(right_child.value(i)) + }, + leaf_id: leaf_id_val, + }); + } + + Ok(Self::new(nodes, 0, num_leaves)) + } +} + +/// Check if two bounding boxes intersect +pub fn bboxes_intersect(bbox1: &[f64; 4], bbox2: &[f64; 4]) -> bool { + // bbox format: [min_x, min_y, max_x, max_y] + !(bbox1[2] < bbox2[0] || bbox1[0] > bbox2[2] || bbox1[3] < bbox2[1] || bbox1[1] > bbox2[3]) +} + +/// Check if a point is within a bounding box +pub fn point_in_bbox(x: f64, y: f64, bbox: &[f64; 4]) -> 
bool { + x >= bbox[0] && x <= bbox[2] && y >= bbox[1] && y <= bbox[3] +} + +/// BKD Tree builder following Lucene's bulk-loading algorithm +pub struct BKDTreeBuilder { + leaf_size: usize, +} + +impl BKDTreeBuilder { + pub fn new(leaf_size: usize) -> Self { + Self { leaf_size } + } + + /// Build a BKD tree from points + /// Returns (tree_nodes, leaf_batches) + pub fn build(&self, points: &mut [(f64, f64, u64)]) -> Result<(Vec, Vec)> { + if points.is_empty() { + return Ok((vec![], vec![])); + } + + println!( + "\nšŸ—ļø Building BKD tree for {} points with leaf size {}", + points.len(), + self.leaf_size + ); + + // Log first few points for debugging + println!("šŸ“ First 5 points:"); + for i in 0..std::cmp::min(5, points.len()) { + println!(" Point {}: x={}, y={}, row_id={}", i, points[i].0, points[i].1, points[i].2); + } + + let mut leaf_counter = 0u32; + let mut all_nodes = Vec::new(); + let mut all_leaf_batches = Vec::new(); + + self.build_recursive( + points, + 0, // depth + &mut leaf_counter, + &mut all_nodes, + &mut all_leaf_batches, + )?; + + println!( + "āœ… Built BKD tree: {} nodes ({} leaves)\n", + all_nodes.len(), + leaf_counter + ); + + Ok((all_nodes, all_leaf_batches)) + } + + /// Recursively build BKD tree following Lucene's algorithm + fn build_recursive( + &self, + points: &mut [(f64, f64, u64)], + depth: u32, + leaf_counter: &mut u32, + all_nodes: &mut Vec, + all_leaf_batches: &mut Vec, + ) -> Result { + // Base case: create leaf node + if points.len() <= self.leaf_size { + let node_id = all_nodes.len() as u32; + let leaf_id = *leaf_counter; + *leaf_counter += 1; + + // Calculate bounding box for this leaf + let (min_x, min_y, max_x, max_y) = calculate_bounds(points); + + // Create leaf batch + let leaf_batch = create_leaf_batch(points)?; + all_leaf_batches.push(leaf_batch); + + // Create leaf node + all_nodes.push(BKDNode { + bounds: [min_x, min_y, max_x, max_y], + split_dim: 0, + split_value: 0.0, + left_child: None, + right_child: None, + 
leaf_id: Some(leaf_id), + }); + + return Ok(node_id); + } + + // Recursive case: split and build subtrees + let split_dim = (depth % 2) as u8; // Alternate between X (0) and Y (1) + + // Sort points by the split dimension + if split_dim == 0 { + points.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap_or(std::cmp::Ordering::Equal)); + } else { + points.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal)); + } + + // Split at median + let mid = points.len() / 2; + let split_value = if split_dim == 0 { + points[mid].0 + } else { + points[mid].1 + }; + + // Calculate bounds for this node (before splitting the slice) + let (min_x, min_y, max_x, max_y) = calculate_bounds(points); + + // Reserve space for this inner node + let node_id = all_nodes.len() as u32; + all_nodes.push(BKDNode { + bounds: [min_x, min_y, max_x, max_y], + split_dim, + split_value, + left_child: None, + right_child: None, + leaf_id: None, + }); + + // Recursively build left and right subtrees + let (left_points, right_points) = points.split_at_mut(mid); + + let left_child_id = self.build_recursive( + left_points, + depth + 1, + leaf_counter, + all_nodes, + all_leaf_batches, + )?; + + let right_child_id = self.build_recursive( + right_points, + depth + 1, + leaf_counter, + all_nodes, + all_leaf_batches, + )?; + + // Update the inner node with child pointers + all_nodes[node_id as usize].left_child = Some(left_child_id); + all_nodes[node_id as usize].right_child = Some(right_child_id); + + Ok(node_id) + } +} + +/// Calculate bounding box for a set of points +fn calculate_bounds(points: &[(f64, f64, u64)]) -> (f64, f64, f64, f64) { + let mut min_x = f64::MAX; + let mut min_y = f64::MAX; + let mut max_x = f64::MIN; + let mut max_y = f64::MIN; + + for (x, y, _) in points { + min_x = min_x.min(*x); + min_y = min_y.min(*y); + max_x = max_x.max(*x); + max_y = max_y.max(*y); + } + + (min_x, min_y, max_x, max_y) +} + +/// Create a leaf batch from points +fn create_leaf_batch(points: &[(f64, 
f64, u64)]) -> Result { + let x_vals: Vec = points.iter().map(|(x, _, _)| *x).collect(); + let y_vals: Vec = points.iter().map(|(_, y, _)| *y).collect(); + let row_ids: Vec = points.iter().map(|(_, _, r)| *r).collect(); + + let schema = Arc::new(Schema::new(vec![ + Field::new("x", DataType::Float64, false), + Field::new("y", DataType::Float64, false), + Field::new(ROW_ID, DataType::UInt64, false), + ])); + + let batch = RecordBatch::try_new( + schema, + vec![ + Arc::new(Float64Array::from(x_vals)) as ArrayRef, + Arc::new(Float64Array::from(y_vals)) as ArrayRef, + Arc::new(UInt64Array::from(row_ids)) as ArrayRef, + ], + )?; + + Ok(batch) +} + diff --git a/rust/lance-index/src/scalar/geoindex.rs b/rust/lance-index/src/scalar/geoindex.rs index f8b4d7411cf..49fcd83832a 100644 --- a/rust/lance-index/src/scalar/geoindex.rs +++ b/rust/lance-index/src/scalar/geoindex.rs @@ -19,6 +19,7 @@ //! use crate::pbold; +use crate::scalar::bkd::{BKDTreeBuilder, BKDTreeLookup, point_in_bbox}; use crate::scalar::expression::{GeoQueryParser, ScalarQueryParser}; use crate::scalar::registry::{ ScalarIndexPlugin, TrainingCriteria, TrainingOrdering, TrainingRequest, @@ -32,7 +33,7 @@ use lance_core::cache::{CacheKey, LanceCache, WeakLanceCache}; use lance_core::{ROW_ADDR, ROW_ID}; use serde::{Deserialize, Serialize}; -use arrow_array::{Array, ArrayRef, Float64Array, RecordBatch, UInt32Array, UInt64Array, UInt8Array}; +use arrow_array::{Array, ArrayRef, Float64Array, RecordBatch, UInt32Array, UInt8Array}; use arrow_array::cast::AsArray; use arrow_schema::{DataType, Field, Schema}; use datafusion::execution::SendableRecordBatchStream; @@ -50,142 +51,13 @@ use roaring::RoaringBitmap; use snafu::location; const BKD_TREE_FILENAME: &str = "bkd_tree.lance"; -const BKD_LEAVES_FILENAME: &str = "bkd_leaves.lance"; +const LEAF_FILENAME_PREFIX: &str = "leaf_"; const GEO_INDEX_VERSION: u32 = 0; const LEAF_SIZE_META_KEY: &str = "leaf_size"; -const DEFAULT_LEAF_SIZE: u32 = 4096; - -/// BKD Tree node 
representing either an inner node or a leaf -#[derive(Debug, Clone, DeepSizeOf)] -struct BKDNode { - /// Bounding box: [min_x, min_y, max_x, max_y] - bounds: [f64; 4], - /// Split dimension: 0=X, 1=Y - split_dim: u8, - /// Split value along split_dim - split_value: f64, - /// Left child node index (for inner nodes) - left_child: Option, - /// Right child node index (for inner nodes) - right_child: Option, - /// Leaf ID in bkd_leaves.lance (for leaf nodes) - leaf_id: Option, -} +const DEFAULT_LEAF_SIZE: u32 = 100; -/// In-memory BKD tree structure for efficient spatial queries -#[derive(Debug, DeepSizeOf)] -struct BKDTreeLookup { - nodes: Vec, - root_id: u32, - num_leaves: u32, -} - -impl BKDTreeLookup { - fn new(nodes: Vec, root_id: u32, num_leaves: u32) -> Self { - Self { - nodes, - root_id, - num_leaves, - } - } - - /// Find all leaf IDs that intersect with the query bounding box - fn find_intersecting_leaves(&self, query_bbox: [f64; 4]) -> Vec { - let mut leaf_ids = Vec::new(); - let mut stack = vec![self.root_id]; - - while let Some(node_id) = stack.pop() { - if node_id as usize >= self.nodes.len() { - continue; - } - - let node = &self.nodes[node_id as usize]; - - // Check if node's bounding box intersects with query bbox - if !bboxes_intersect(&node.bounds, &query_bbox) { - continue; - } - - // If this is a leaf node, add its leaf_id - if let Some(leaf_id) = node.leaf_id { - leaf_ids.push(leaf_id); - } else { - // Inner node - traverse children - if let Some(left) = node.left_child { - stack.push(left); - } - if let Some(right) = node.right_child { - stack.push(right); - } - } - } - - leaf_ids - } - - /// Deserialize from RecordBatch - fn from_record_batch(batch: RecordBatch) -> Result { - if batch.num_rows() == 0 { - return Ok(Self::new(vec![], 0, 0)); - } - - let min_x = batch.column(0).as_primitive::(); - let min_y = batch.column(1).as_primitive::(); - let max_x = batch.column(2).as_primitive::(); - let max_y = batch.column(3).as_primitive::(); - let 
split_dim = batch.column(4).as_primitive::(); - let split_value = batch.column(5).as_primitive::(); - let left_child = batch.column(6).as_primitive::(); - let right_child = batch.column(7).as_primitive::(); - let leaf_id = batch.column(8).as_primitive::(); - - let mut nodes = Vec::with_capacity(batch.num_rows()); - let mut num_leaves = 0; - - for i in 0..batch.num_rows() { - let leaf_id_val = if leaf_id.is_null(i) { - None - } else { - num_leaves += 1; - Some(leaf_id.value(i)) - }; - - nodes.push(BKDNode { - bounds: [ - min_x.value(i), - min_y.value(i), - max_x.value(i), - max_y.value(i), - ], - split_dim: split_dim.value(i), - split_value: split_value.value(i), - left_child: if left_child.is_null(i) { - None - } else { - Some(left_child.value(i)) - }, - right_child: if right_child.is_null(i) { - None - } else { - Some(right_child.value(i)) - }, - leaf_id: leaf_id_val, - }); - } - - Ok(Self::new(nodes, 0, num_leaves)) - } -} - -/// Check if two bounding boxes intersect -fn bboxes_intersect(bbox1: &[f64; 4], bbox2: &[f64; 4]) -> bool { - // bbox format: [min_x, min_y, max_x, max_y] - !(bbox1[2] < bbox2[0] || bbox1[0] > bbox2[2] || bbox1[3] < bbox2[1] || bbox1[1] > bbox2[3]) -} - -/// Check if a point is within a bounding box -fn point_in_bbox(x: f64, y: f64, bbox: &[f64; 4]) -> bool { - x >= bbox[0] && x <= bbox[2] && y >= bbox[1] && y <= bbox[3] +fn leaf_filename(leaf_id: u32) -> String { + format!("{}{}.lance", LEAF_FILENAME_PREFIX, leaf_id) } /// Lazy reader for BKD leaf file @@ -319,24 +191,25 @@ impl GeoIndex { })) } - /// Load a specific leaf from storage + /// Load a specific leaf from storage (from leaf_{id}.lance file) async fn load_leaf( &self, leaf_id: u32, - index_reader: LazyIndexReader, metrics: &dyn MetricsCollector, ) -> Result { // Check cache first let cache_key = BKDLeafKey { leaf_id }; + let store = self.store.clone(); let cached = self .index_cache .get_or_insert_with_key(cache_key, move || async move { metrics.record_part_load(); - let reader = 
index_reader.get().await?; - let batch = reader - .read_record_batch(leaf_id as u64, self.leaf_size as u64) - .await?; + let filename = leaf_filename(leaf_id); + let reader = store.open_index_file(&filename).await?; + // Read the entire leaf file + let num_rows = reader.num_rows(); + let batch = reader.read_range(0..num_rows, None).await?; Ok(CachedLeafData::new(batch)) }) .await?; @@ -349,10 +222,16 @@ impl GeoIndex { &self, leaf_id: u32, query_bbox: [f64; 4], - index_reader: LazyIndexReader, metrics: &dyn MetricsCollector, ) -> Result { - let leaf_data = self.load_leaf(leaf_id, index_reader, metrics).await?; + let leaf_data = self.load_leaf(leaf_id, metrics).await?; + + println!( + "šŸ” Searching leaf {} with {} points, query_bbox: {:?}", + leaf_id, + leaf_data.num_rows(), + query_bbox + ); // Filter points within this leaf let mut row_ids = RowIdTreeMap::new(); @@ -366,16 +245,38 @@ impl GeoIndex { .column(2) .as_primitive::(); + let mut matched_count = 0; for i in 0..leaf_data.num_rows() { let x = x_array.value(i); let y = y_array.value(i); + let row_id = row_id_array.value(i); + + // Debug: dump all points in leaf 16 to find San Francisco + if leaf_id == 16 && i < 10 { + println!(" šŸ”Ž Leaf 16 point {}: ({}, {}) row_id={}", i, x, y, row_id); + } if point_in_bbox(x, y, &query_bbox) { - let row_id = row_id_array.value(i); row_ids.insert(row_id); + matched_count += 1; + + // Log first few matches for debugging + if matched_count <= 3 { + println!( + " āœ… Match {}: point({}, {}) -> row_id {}", + matched_count, x, y, row_id + ); + } } } + println!( + "šŸ“Š Leaf {} matched {} out of {} points", + leaf_id, + matched_count, + leaf_data.num_rows() + ); + Ok(row_ids) } } @@ -434,26 +335,26 @@ impl ScalarIndex for GeoIndex { GeoQuery::Intersects(min_x, min_y, max_x, max_y) => { let query_bbox = [*min_x, *min_y, *max_x, *max_y]; - log::debug!( - "Geo index search: st_intersects with bbox({}, {}, {}, {})", + println!( + "\nšŸ” Geo index search: st_intersects with 
bbox({}, {}, {}, {})", min_x, min_y, max_x, max_y ); // Step 1: Find intersecting leaves using in-memory tree traversal let leaf_ids = self.bkd_tree.find_intersecting_leaves(query_bbox); - log::debug!( - "BKD tree found {} intersecting leaves", - leaf_ids.len() + println!( + "šŸ“Š BKD tree traversal found {} intersecting leaves out of {} total leaves", + leaf_ids.len(), + self.bkd_tree.num_leaves ); // Step 2: Lazy-load and filter each leaf let mut all_row_ids = RowIdTreeMap::new(); - let lazy_reader = LazyIndexReader::new(self.store.clone(), BKD_LEAVES_FILENAME); - for leaf_id in leaf_ids { + for leaf_id in &leaf_ids { let leaf_row_ids = self - .search_leaf(leaf_id, query_bbox, lazy_reader.clone(), metrics) + .search_leaf(*leaf_id, query_bbox, metrics) .await?; // Collect row IDs from the leaf and add them to the result set let row_ids: Option> = leaf_row_ids.row_ids() @@ -463,9 +364,10 @@ impl ScalarIndex for GeoIndex { } } - log::debug!( - "Geo index returning {:?} row IDs", - all_row_ids.len() + println!( + "āœ… Geo index searched {} leaves and returning {} row IDs\n", + leaf_ids.len(), + all_row_ids.len().unwrap_or(0) ); // We return Exact because we already filtered points in search_leaf @@ -621,19 +523,24 @@ impl GeoIndexBuilder { )])) .await?; - // Write leaf data + // Write each leaf to a separate file let leaf_schema = Arc::new(Schema::new(vec![ Field::new("x", DataType::Float64, false), Field::new("y", DataType::Float64, false), Field::new(ROW_ID, DataType::UInt64, false), ])); - let mut leaf_file = index_store - .new_index_file(BKD_LEAVES_FILENAME, leaf_schema) - .await?; - for leaf_batch in leaf_batches { - leaf_file.write_record_batch(leaf_batch).await?; + + println!("šŸ“ Writing {} leaf files", leaf_batches.len()); + for (leaf_id, leaf_batch) in leaf_batches.iter().enumerate() { + let filename = leaf_filename(leaf_id as u32); + println!(" Writing {}: {} rows", filename, leaf_batch.num_rows()); + let mut leaf_file = index_store + 
.new_index_file(&filename, leaf_schema.clone()) + .await?; + leaf_file.write_record_batch(leaf_batch.clone()).await?; + leaf_file.finish().await?; } - leaf_file.finish().await?; + println!("āœ… Finished writing {} leaf files\n", leaf_batches.len()); log::debug!( "Wrote BKD tree with {} nodes", @@ -664,23 +571,12 @@ impl GeoIndexBuilder { tree_file.write_record_batch(empty_batch).await?; tree_file.finish().await?; - // Write empty leaves file - let leaf_schema = Arc::new(Schema::new(vec![ - Field::new("x", DataType::Float64, false), - Field::new("y", DataType::Float64, false), - Field::new(ROW_ID, DataType::UInt64, false), - ])); - let empty_leaves = RecordBatch::new_empty(leaf_schema); - let mut leaf_file = index_store - .new_index_file(BKD_LEAVES_FILENAME, empty_leaves.schema()) - .await?; - leaf_file.write_record_batch(empty_leaves).await?; - leaf_file.finish().await?; + // No leaf files needed for empty index Ok(()) } - fn serialize_tree_nodes(&self, nodes: &[BKDNode]) -> Result { + fn serialize_tree_nodes(&self, nodes: &[crate::scalar::bkd::BKDNode]) -> Result { let mut min_x_vals = Vec::with_capacity(nodes.len()); let mut min_y_vals = Vec::with_capacity(nodes.len()); let mut max_x_vals = Vec::with_capacity(nodes.len()); @@ -730,83 +626,10 @@ impl GeoIndexBuilder { Ok(RecordBatch::try_new(schema, columns)?) 
} - // Build BKD tree using Lucene's algorithm (deferred for next step) - fn build_bkd_tree(&mut self) -> Result<(Vec, Vec)> { - // For now, implement a simple single-leaf approach as placeholder - // This will be replaced with the full BKD tree building algorithm - log::warn!("Using simplified BKD tree builder (full implementation pending)"); - - let num_points = self.points.len(); - let leaf_size = self.options.leaf_size as usize; - - // Calculate bounding box for all points - let mut min_x = f64::MAX; - let mut min_y = f64::MAX; - let mut max_x = f64::MIN; - let mut max_y = f64::MIN; - - for (x, y, _) in &self.points { - min_x = min_x.min(*x); - min_y = min_y.min(*y); - max_x = max_x.max(*x); - max_y = max_y.max(*y); - } - - // Sort points by X coordinate for better spatial locality - self.points.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap()); - - // Split points into leaf batches - let mut leaf_batches = Vec::new(); - let mut nodes = Vec::new(); - - for (leaf_id, chunk) in self.points.chunks(leaf_size).enumerate() { - // Create leaf batch - let x_vals: Vec = chunk.iter().map(|(x, _, _)| *x).collect(); - let y_vals: Vec = chunk.iter().map(|(_, y, _)| *y).collect(); - let row_ids: Vec = chunk.iter().map(|(_, _, r)| *r).collect(); - - let schema = Arc::new(Schema::new(vec![ - Field::new("x", DataType::Float64, false), - Field::new("y", DataType::Float64, false), - Field::new(ROW_ID, DataType::UInt64, false), - ])); - - let batch = RecordBatch::try_new( - schema, - vec![ - Arc::new(Float64Array::from(x_vals)) as ArrayRef, - Arc::new(Float64Array::from(y_vals)) as ArrayRef, - Arc::new(UInt64Array::from(row_ids)) as ArrayRef, - ], - )?; - - leaf_batches.push(batch); - - // Create leaf node (simplified - one node per leaf) - nodes.push(BKDNode { - bounds: [min_x, min_y, max_x, max_y], - split_dim: 0, - split_value: 0.0, - left_child: None, - right_child: None, - leaf_id: Some(leaf_id as u32), - }); - } - - // If we have multiple leaves, create a simple root node - // 
(This is a placeholder - full tree construction will be more sophisticated) - if nodes.len() > 1 { - // For now, just use the first leaf node as root - // Full implementation will build proper hierarchical tree - } - - log::debug!( - "Built simplified BKD tree: {} points -> {} leaves", - num_points, - leaf_batches.len() - ); - - Ok((nodes, leaf_batches)) + // Build BKD tree using the BKDTreeBuilder + fn build_bkd_tree(&mut self) -> Result<(Vec, Vec)> { + let builder = BKDTreeBuilder::new(self.options.leaf_size as usize); + builder.build(&mut self.points) } } diff --git a/test_geoarrow_geo_index.py b/test_geoarrow_geo_index.py index 253c2755b01..67538437a22 100644 --- a/test_geoarrow_geo_index.py +++ b/test_geoarrow_geo_index.py @@ -16,8 +16,13 @@ import lance import os import shutil +import logging from geoarrow.pyarrow import point +# Enable Rust logging +os.environ['RUST_LOG'] = 'lance_index=debug' +logging.basicConfig(level=logging.DEBUG) + def main(): print("šŸŒ Testing GeoArrow Point Geo Index in Lance") @@ -31,28 +36,54 @@ def main(): print(f"āœ… Cleaned up existing dataset: {dataset_path}") - # Step 1: Create GeoArrow Point data - print("\nšŸ”µ Step 1: Creating GeoArrow Point data") - lat_np = np.array([37.7749, 34.0522, 40.7128], dtype="float64") # SF, LA, NYC - lng_np = np.array([-122.4194, -118.2437, -74.0060], dtype="float64") - - - start_location = point().from_geobuffers(None, lng_np, lat_np) - - + # Step 1: Create GeoArrow Point data with enough points to test tree structure + print("\nšŸ”µ Step 1: Creating GeoArrow Point data (5000+ points)") + + # Generate random points across the US + # US bounding box approximately: lng [-125, -65], lat [25, 50] + np.random.seed(42) # For reproducibility + num_points = 5000 + + lng_vals = np.random.uniform(-125, -65, num_points) + lat_vals = np.random.uniform(25, 50, num_points) + + # Add some known cities at the beginning for testing + known_cities = [ + {"id": 1, "city": "San Francisco", "lng": -122.4194, 
"lat": 37.7749, "population": 883305}, + {"id": 2, "city": "Los Angeles", "lng": -118.2437, "lat": 34.0522, "population": 3898747}, + {"id": 3, "city": "New York", "lng": -74.0060, "lat": 40.7128, "population": 8336817}, + {"id": 4, "city": "Chicago", "lng": -87.6298, "lat": 41.8781, "population": 2746388}, + {"id": 5, "city": "Houston", "lng": -95.3698, "lat": 29.7604, "population": 2304580}, + ] + + # Replace first 5 points with known cities + for i, city in enumerate(known_cities): + lng_vals[i] = city["lng"] + lat_vals[i] = city["lat"] + + start_location = point().from_geobuffers(None, lng_vals, lat_vals) + + # Create IDs and city names + ids = list(range(1, num_points + 1)) + cities = [known_cities[i]["city"] if i < len(known_cities) else f"Point_{i+1}" + for i in range(num_points)] + populations = [known_cities[i]["population"] if i < len(known_cities) else np.random.randint(10000, 1000000) + for i in range(num_points)] + table = pa.table({ - "id": [1, 2, 3], - "city": ["San Francisco", "Los Angeles", "New York"], + "id": ids, + "city": cities, "start_location": start_location, - "population": [883305, 3898747, 8336817] + "population": populations }) - print("āœ… Created GeoArrow Point data") + print(f"āœ… Created GeoArrow Point data with {num_points} points") print("šŸ“Š Table schema:") print(table.schema) print(f"šŸ“ Point column type: {table.schema.field('start_location').type}") print(f"šŸ“ Point column metadata: {table.schema.field('start_location').metadata}") + print(f"šŸ“ Known cities: {[c['city'] for c in known_cities]}") # Step 2: Write to Lance dataset @@ -170,21 +201,28 @@ def main(): print(f"šŸ“Š Number of rows: {len(table)}") - # Print complete results - for i in range(len(table)): + # Print first few results + max_rows_to_print = min(10, len(table)) + for i in range(max_rows_to_print): row_data = {} for j, column in enumerate(table.columns): col_name = table.schema.field(j).name value = column.to_pylist()[i] row_data[col_name] = value 
print(f"šŸ“ Row {i}: {row_data}") + if len(table) > max_rows_to_print: + print(f"... and {len(table) - max_rows_to_print} more rows") cities = table.column('city').to_pylist() - print(f"\nāœ… Found {len(cities)} cities with broad bbox: {cities}") - assert len(cities) == 2, f"Expected 2 cities, got {len(cities)}" + print(f"\nāœ… Found {len(cities)} results with broad bbox") + print(f"šŸ“Š Known cities in results: {[c for c in cities if c in ['San Francisco', 'Los Angeles', 'New York', 'Chicago', 'Houston']]}") + + # With 5000 random points and a broad western US bbox, we should get hundreds/thousands of results + assert len(cities) > 100, f"Expected many results (>100) from broad bbox, got {len(cities)}" assert 'San Francisco' in cities, "Expected San Francisco in results" assert 'Los Angeles' in cities, "Expected Los Angeles in results" + print(f"āœ… Verified SF and LA are in the {len(cities)} results") else: print("āš ļø No results returned") @@ -205,18 +243,28 @@ def main(): print("āœ… Query Results:") print(f"šŸ“Š Number of rows: {len(tight_table)}") - for i in range(len(tight_table)): + # Print first few results + max_rows_to_print = min(10, len(tight_table)) + for i in range(max_rows_to_print): row_data = {} for j, column in enumerate(tight_table.columns): col_name = tight_table.schema.field(j).name value = column.to_pylist()[i] row_data[col_name] = value print(f"šŸ“ Row {i}: {row_data}") + if len(tight_table) > max_rows_to_print: + print(f"... 
and {len(tight_table) - max_rows_to_print} more rows") cities = tight_table.column('city').to_pylist() - print(f"\nāœ… Found {len(cities)} city with tight bbox: {cities}") - assert len(cities) == 1, f"Expected 1 city, got {len(cities)}" - assert cities[0] == 'San Francisco', f"Expected San Francisco, got {cities[0]}" + print(f"\nāœ… Found {len(cities)} results with tight bbox") + known_cities_found = [c for c in cities if c in ['San Francisco', 'Los Angeles', 'New York', 'Chicago', 'Houston']] + print(f"šŸ“Š Known cities in results: {known_cities_found}") + + # The tight bbox around SF should include SF, and might include some random points + assert 'San Francisco' in cities, "Expected San Francisco in results" + assert len(known_cities_found) == 1 and known_cities_found[0] == 'San Francisco', \ + f"Expected only San Francisco from known cities, got {known_cities_found}" + print(f"āœ… Verified only SF is in the known cities, total results: {len(cities)}") else: print("āš ļø No results returned") From 89004552aaca4a5c86d5e2e53d921afe281ca927 Mon Sep 17 00:00:00 2001 From: jaystarshot Date: Wed, 15 Oct 2025 22:46:16 -0700 Subject: [PATCH 3/7] Some improvements --- rust/lance-index/src/scalar/bkd.rs | 766 ++++++++++++++++++++---- rust/lance-index/src/scalar/geoindex.rs | 315 +++++++--- 2 files changed, 876 insertions(+), 205 deletions(-) diff --git a/rust/lance-index/src/scalar/bkd.rs b/rust/lance-index/src/scalar/bkd.rs index c99422288cd..b885bbe6898 100644 --- a/rust/lance-index/src/scalar/bkd.rs +++ b/rust/lance-index/src/scalar/bkd.rs @@ -15,28 +15,118 @@ //! - Groups points into leaves of configurable size //! 
- Stores tree structure separately from leaf data for lazy loading -use arrow_array::{Array, ArrayRef, Float64Array, RecordBatch, UInt64Array}; +use arrow_array::{Array, ArrayRef, Float64Array, RecordBatch, UInt64Array, UInt32Array, UInt8Array}; use arrow_array::cast::AsArray; use arrow_schema::{DataType, Field, Schema}; use deepsize::DeepSizeOf; use lance_core::{Result, ROW_ID}; use std::sync::Arc; +use snafu::location; + +// Schema field names +const NODE_ID: &str = "node_id"; +const MIN_X: &str = "min_x"; +const MIN_Y: &str = "min_y"; +const MAX_X: &str = "max_x"; +const MAX_Y: &str = "max_y"; +const SPLIT_DIM: &str = "split_dim"; +const SPLIT_VALUE: &str = "split_value"; +const LEFT_CHILD: &str = "left_child"; +const RIGHT_CHILD: &str = "right_child"; +const FILE_ID: &str = "file_id"; +const ROW_OFFSET: &str = "row_offset"; +const NUM_ROWS: &str = "num_rows"; + +/// Schema for inner node metadata +pub fn inner_node_schema() -> Arc { + Arc::new(Schema::new(vec![ + Field::new(NODE_ID, DataType::UInt32, false), + Field::new(MIN_X, DataType::Float64, false), + Field::new(MIN_Y, DataType::Float64, false), + Field::new(MAX_X, DataType::Float64, false), + Field::new(MAX_Y, DataType::Float64, false), + Field::new(SPLIT_DIM, DataType::UInt8, false), + Field::new(SPLIT_VALUE, DataType::Float64, false), + Field::new(LEFT_CHILD, DataType::UInt32, false), + Field::new(RIGHT_CHILD, DataType::UInt32, false), + ])) +} + +/// Schema for leaf node metadata +pub fn leaf_node_schema() -> Arc { + Arc::new(Schema::new(vec![ + Field::new(NODE_ID, DataType::UInt32, false), + Field::new(MIN_X, DataType::Float64, false), + Field::new(MIN_Y, DataType::Float64, false), + Field::new(MAX_X, DataType::Float64, false), + Field::new(MAX_Y, DataType::Float64, false), + Field::new(FILE_ID, DataType::UInt32, false), + Field::new(ROW_OFFSET, DataType::UInt64, false), + Field::new(NUM_ROWS, DataType::UInt64, false), + ])) +} -/// BKD Tree node representing either an inner node or a leaf +/// BKD 
Tree node - either an inner node (with children) or a leaf node (with data location) #[derive(Debug, Clone, DeepSizeOf)] -pub struct BKDNode { +pub enum BKDNode { + Inner(BKDInnerNode), + Leaf(BKDLeafNode), +} + +impl BKDNode { + pub fn bounds(&self) -> [f64; 4] { + match self { + BKDNode::Inner(inner) => inner.bounds, + BKDNode::Leaf(leaf) => leaf.bounds, + } + } + + pub fn is_leaf(&self) -> bool { + matches!(self, BKDNode::Leaf(_)) + } + + pub fn as_inner(&self) -> Option<&BKDInnerNode> { + match self { + BKDNode::Inner(inner) => Some(inner), + _ => None, + } + } + + pub fn as_leaf(&self) -> Option<&BKDLeafNode> { + match self { + BKDNode::Leaf(leaf) => Some(leaf), + _ => None, + } + } +} + +/// Inner node in BKD tree - contains split information and child pointers +#[derive(Debug, Clone, DeepSizeOf)] +pub struct BKDInnerNode { /// Bounding box: [min_x, min_y, max_x, max_y] pub bounds: [f64; 4], /// Split dimension: 0=X, 1=Y pub split_dim: u8, /// Split value along split_dim pub split_value: f64, - /// Left child node index (for inner nodes) - pub left_child: Option, - /// Right child node index (for inner nodes) - pub right_child: Option, - /// Leaf ID - corresponds to leaf_{id}.lance file (for leaf nodes) - pub leaf_id: Option, + /// Left child node index + pub left_child: u32, + /// Right child node index + pub right_child: u32, +} + +/// Leaf node in BKD tree - contains location of actual point data +#[derive(Debug, Clone, DeepSizeOf)] +pub struct BKDLeafNode { + /// Bounding box: [min_x, min_y, max_x, max_y] + pub bounds: [f64; 4], + /// Which leaf group file this leaf is in + /// Corresponds to leaf_group_{file_id}.lance + pub file_id: u32, + /// Row offset within the leaf group file + pub row_offset: u64, + /// Number of rows in this leaf batch + pub num_rows: u64, } /// In-memory BKD tree structure for efficient spatial queries @@ -56,9 +146,10 @@ impl BKDTreeLookup { } } - /// Find all leaf IDs that intersect with the query bounding box - pub fn 
find_intersecting_leaves(&self, query_bbox: [f64; 4]) -> Vec { - let mut leaf_ids = Vec::new(); + /// Find all leaf nodes that intersect with the query bounding box + /// Returns references to the intersecting leaf nodes + pub fn find_intersecting_leaves(&self, query_bbox: [f64; 4]) -> Result> { + let mut leaves = Vec::new(); let mut stack = vec![self.root_id]; let mut nodes_visited = 0; @@ -71,25 +162,20 @@ impl BKDTreeLookup { nodes_visited += 1; // Check if node's bounding box intersects with query bbox - if !bboxes_intersect(&node.bounds, &query_bbox) { + let intersects = bboxes_intersect(&node.bounds(), &query_bbox); + if !intersects { continue; } - // If this is a leaf node, add its leaf_id - if let Some(leaf_id) = node.leaf_id { - println!( - " šŸƒ Found intersecting leaf_{}.lance: bounds {:?}", - leaf_id, - node.bounds - ); - leaf_ids.push(leaf_id); - } else { - // Inner node - traverse children - if let Some(left) = node.left_child { - stack.push(left); + match node { + BKDNode::Leaf(leaf) => { + // Leaf node - add to results + leaves.push(leaf); } - if let Some(right) = node.right_child { - stack.push(right); + BKDNode::Inner(inner) => { + // Inner node - traverse children + stack.push(inner.left_child); + stack.push(inner.right_child); } } } @@ -97,82 +183,125 @@ impl BKDTreeLookup { println!( "🌲 Tree traversal: visited {} nodes, found {} intersecting leaves", nodes_visited, - leaf_ids.len() + leaves.len() ); - leaf_ids + Ok(leaves) } - /// Deserialize from RecordBatch - pub fn from_record_batch(batch: RecordBatch) -> Result { - if batch.num_rows() == 0 { + /// Deserialize from separate inner and leaf RecordBatches + /// Both batches include node_id field to preserve tree structure + pub fn from_record_batches(inner_batch: RecordBatch, leaf_batch: RecordBatch) -> Result { + if inner_batch.num_rows() == 0 && leaf_batch.num_rows() == 0 { return Ok(Self::new(vec![], 0, 0)); } - let min_x = batch - .column(0) - .as_primitive::(); - let min_y = batch - 
.column(1) - .as_primitive::(); - let max_x = batch - .column(2) - .as_primitive::(); - let max_y = batch - .column(3) - .as_primitive::(); - let split_dim = batch - .column(4) - .as_primitive::(); - let split_value = batch - .column(5) - .as_primitive::(); - let left_child = batch - .column(6) - .as_primitive::(); - let right_child = batch - .column(7) - .as_primitive::(); - let leaf_id = batch - .column(8) - .as_primitive::(); - - let mut nodes = Vec::with_capacity(batch.num_rows()); + // Helper to get column by name + let get_col = |batch: &RecordBatch, name: &str| -> Result { + batch.schema().column_with_name(name) + .map(|(idx, _)| idx) + .ok_or_else(|| lance_core::Error::Internal { + message: format!("Missing column '{}' in BKD tree batch", name), + location: location!(), + }) + }; + + // Determine total number of nodes (max node_id + 1) + let max_node_id = { + let mut max_id = 0u32; + + if inner_batch.num_rows() > 0 { + let col_idx = get_col(&inner_batch, NODE_ID)?; + let node_ids = inner_batch.column(col_idx).as_primitive::(); + for i in 0..inner_batch.num_rows() { + max_id = max_id.max(node_ids.value(i)); + } + } + + if leaf_batch.num_rows() > 0 { + let col_idx = get_col(&leaf_batch, NODE_ID)?; + let node_ids = leaf_batch.column(col_idx).as_primitive::(); + for i in 0..leaf_batch.num_rows() { + max_id = max_id.max(node_ids.value(i)); + } + } + + max_id + }; + + // Create sparse array of nodes (filled with dummy data initially) + let mut nodes = vec![ + BKDNode::Leaf(BKDLeafNode { + bounds: [0.0, 0.0, 0.0, 0.0], + file_id: 0, + row_offset: 0, + num_rows: 0, + }); + (max_node_id + 1) as usize + ]; + let mut num_leaves = 0; - - for i in 0..batch.num_rows() { - let leaf_id_val = if leaf_id.is_null(i) { - None - } else { + + // Fill in inner nodes + if inner_batch.num_rows() > 0 { + let node_ids = inner_batch.column(get_col(&inner_batch, NODE_ID)?).as_primitive::(); + let min_x = inner_batch.column(get_col(&inner_batch, MIN_X)?).as_primitive::(); + let min_y = 
inner_batch.column(get_col(&inner_batch, MIN_Y)?).as_primitive::(); + let max_x = inner_batch.column(get_col(&inner_batch, MAX_X)?).as_primitive::(); + let max_y = inner_batch.column(get_col(&inner_batch, MAX_Y)?).as_primitive::(); + let split_dim = inner_batch.column(get_col(&inner_batch, SPLIT_DIM)?).as_primitive::(); + let split_value = inner_batch.column(get_col(&inner_batch, SPLIT_VALUE)?).as_primitive::(); + let left_child = inner_batch.column(get_col(&inner_batch, LEFT_CHILD)?).as_primitive::(); + let right_child = inner_batch.column(get_col(&inner_batch, RIGHT_CHILD)?).as_primitive::(); + + for i in 0..inner_batch.num_rows() { + let node_id = node_ids.value(i) as usize; + nodes[node_id] = BKDNode::Inner(BKDInnerNode { + bounds: [ + min_x.value(i), + min_y.value(i), + max_x.value(i), + max_y.value(i), + ], + split_dim: split_dim.value(i), + split_value: split_value.value(i), + left_child: left_child.value(i), + right_child: right_child.value(i), + }); + } + } + + // Fill in leaf nodes + if leaf_batch.num_rows() > 0 { + let node_ids = leaf_batch.column(get_col(&leaf_batch, NODE_ID)?).as_primitive::(); + let min_x = leaf_batch.column(get_col(&leaf_batch, MIN_X)?).as_primitive::(); + let min_y = leaf_batch.column(get_col(&leaf_batch, MIN_Y)?).as_primitive::(); + let max_x = leaf_batch.column(get_col(&leaf_batch, MAX_X)?).as_primitive::(); + let max_y = leaf_batch.column(get_col(&leaf_batch, MAX_Y)?).as_primitive::(); + let file_id = leaf_batch.column(get_col(&leaf_batch, FILE_ID)?).as_primitive::(); + let row_offset = leaf_batch.column(get_col(&leaf_batch, ROW_OFFSET)?).as_primitive::(); + let num_rows = leaf_batch.column(get_col(&leaf_batch, NUM_ROWS)?).as_primitive::(); + + for i in 0..leaf_batch.num_rows() { + let node_id = node_ids.value(i) as usize; + nodes[node_id] = BKDNode::Leaf(BKDLeafNode { + bounds: [ + min_x.value(i), + min_y.value(i), + max_x.value(i), + max_y.value(i), + ], + file_id: file_id.value(i), + row_offset: row_offset.value(i), + 
num_rows: num_rows.value(i), + }); num_leaves += 1; - Some(leaf_id.value(i)) - }; - - nodes.push(BKDNode { - bounds: [ - min_x.value(i), - min_y.value(i), - max_x.value(i), - max_y.value(i), - ], - split_dim: split_dim.value(i), - split_value: split_value.value(i), - left_child: if left_child.is_null(i) { - None - } else { - Some(left_child.value(i)) - }, - right_child: if right_child.is_null(i) { - None - } else { - Some(right_child.value(i)) - }, - leaf_id: leaf_id_val, - }); + } } - + Ok(Self::new(nodes, 0, num_leaves)) } + } /// Check if two bounding boxes intersect @@ -198,7 +327,7 @@ impl BKDTreeBuilder { /// Build a BKD tree from points /// Returns (tree_nodes, leaf_batches) - pub fn build(&self, points: &mut [(f64, f64, u64)]) -> Result<(Vec, Vec)> { + pub fn build(&self, points: &mut [(f64, f64, u64)], batches_per_file: u32) -> Result<(Vec, Vec)> { if points.is_empty() { return Ok((vec![], vec![])); } @@ -223,10 +352,44 @@ impl BKDTreeBuilder { points, 0, // depth &mut leaf_counter, + batches_per_file, &mut all_nodes, &mut all_leaf_batches, )?; + // Post-process: Update leaf nodes with correct row offsets + // Leaves are created in order during build_recursive, so we can + // calculate cumulative offsets based on actual batch sizes + let mut current_file_id = 0u32; + let mut row_offset_in_file = 0u64; + let mut batches_in_current_file = 0u32; + let mut leaf_idx = 0; + + for node in all_nodes.iter_mut() { + if let BKDNode::Leaf(leaf) = node { + if leaf_idx < all_leaf_batches.len() { + let batch_num_rows = all_leaf_batches[leaf_idx].num_rows() as u64; + + // Check if we need to move to next file + if batches_in_current_file >= batches_per_file && batches_per_file > 0 { + current_file_id += 1; + row_offset_in_file = 0; + batches_in_current_file = 0; + } + + // Update leaf with correct metadata + leaf.file_id = current_file_id; + leaf.row_offset = row_offset_in_file; + leaf.num_rows = batch_num_rows; + + // Advance for next leaf + row_offset_in_file += 
batch_num_rows; + batches_in_current_file += 1; + leaf_idx += 1; + } + } + } + println!( "āœ… Built BKD tree: {} nodes ({} leaves)\n", all_nodes.len(), @@ -242,6 +405,7 @@ impl BKDTreeBuilder { points: &mut [(f64, f64, u64)], depth: u32, leaf_counter: &mut u32, + batches_per_file: u32, all_nodes: &mut Vec, all_leaf_batches: &mut Vec, ) -> Result { @@ -256,17 +420,22 @@ impl BKDTreeBuilder { // Create leaf batch let leaf_batch = create_leaf_batch(points)?; + let num_rows = leaf_batch.num_rows() as u64; all_leaf_batches.push(leaf_batch); - // Create leaf node - all_nodes.push(BKDNode { + // Create leaf node (file_id, row_offset will be set in post-processing) + all_nodes.push(BKDNode::Leaf(BKDLeafNode { bounds: [min_x, min_y, max_x, max_y], - split_dim: 0, - split_value: 0.0, - left_child: None, - right_child: None, - leaf_id: Some(leaf_id), - }); + file_id: 0, // Will be updated in post-processing + row_offset: 0, // Will be updated in post-processing + num_rows, + })); + + // Debug: Check if SF (row_id=0) is in this leaf + if points.iter().any(|(_, _, rid)| *rid == 0) { + println!("šŸŽÆ SF (row_id=0) in leaf node_id={}, leaf_id={}, num_rows={}, bounds=[{}, {}, {}, {}]", + node_id, leaf_id, num_rows, min_x, min_y, max_x, max_y); + } return Ok(node_id); } @@ -274,6 +443,11 @@ impl BKDTreeBuilder { // Recursive case: split and build subtrees let split_dim = (depth % 2) as u8; // Alternate between X (0) and Y (1) + // TODO: Replace with radix selection for O(n) median finding (Lucene's approach) + // Current: O(n log n) sorting at each level = O(n log² n) total + // Target: O(n) radix select at each level = O(n log n) total + // See: https://github.com/apache/lucene/blob/main/lucene/core/src/java/org/apache/lucene/util/bkd/BKDRadixSelector.java + // Sort points by the split dimension if split_dim == 0 { points.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap_or(std::cmp::Ordering::Equal)); @@ -291,17 +465,25 @@ impl BKDTreeBuilder { // Calculate bounds for this node (before 
splitting the slice) let (min_x, min_y, max_x, max_y) = calculate_bounds(points); + + // Debug: Log first inner node to verify bounds + if all_nodes.is_empty() { + println!("šŸ” Root node bounds: [{}, {}, {}, {}]", min_x, min_y, max_x, max_y); + println!(" Split on dim {} at value {}", split_dim, split_value); + println!(" Contains SF (-122.4194, 37.7749)? x_ok={}, y_ok={}", + min_x <= -122.4194 && -122.4194 <= max_x, + min_y <= 37.7749 && 37.7749 <= max_y); + } - // Reserve space for this inner node + // Reserve space for this inner node (placeholder - we'll update it after building children) let node_id = all_nodes.len() as u32; - all_nodes.push(BKDNode { + all_nodes.push(BKDNode::Inner(BKDInnerNode { bounds: [min_x, min_y, max_x, max_y], split_dim, split_value, - left_child: None, - right_child: None, - leaf_id: None, - }); + left_child: 0, // Placeholder + right_child: 0, // Placeholder + })); // Recursively build left and right subtrees let (left_points, right_points) = points.split_at_mut(mid); @@ -310,6 +492,7 @@ impl BKDTreeBuilder { left_points, depth + 1, leaf_counter, + batches_per_file, all_nodes, all_leaf_batches, )?; @@ -318,13 +501,16 @@ impl BKDTreeBuilder { right_points, depth + 1, leaf_counter, + batches_per_file, all_nodes, all_leaf_batches, )?; - // Update the inner node with child pointers - all_nodes[node_id as usize].left_child = Some(left_child_id); - all_nodes[node_id as usize].right_child = Some(right_child_id); + // Update the inner node with actual child pointers + if let BKDNode::Inner(inner) = &mut all_nodes[node_id as usize] { + inner.left_child = left_child_id; + inner.right_child = right_child_id; + } Ok(node_id) } @@ -371,3 +557,365 @@ fn create_leaf_batch(points: &[(f64, f64, u64)]) -> Result { Ok(batch) } +// #[cfg(test)] +// mod tests { +// use super::*; +// use arrow_array::{UInt32Array, UInt8Array}; + +// /// Helper to serialize nodes (mirrors logic from geoindex.rs) +// fn serialize_nodes(nodes: &[BKDNode]) -> Result { +// 
let mut min_x_vals = Vec::with_capacity(nodes.len()); +// let mut min_y_vals = Vec::with_capacity(nodes.len()); +// let mut max_x_vals = Vec::with_capacity(nodes.len()); +// let mut max_y_vals = Vec::with_capacity(nodes.len()); +// let mut split_dim_vals = Vec::with_capacity(nodes.len()); +// let mut split_value_vals = Vec::with_capacity(nodes.len()); +// let mut left_child_vals = Vec::with_capacity(nodes.len()); +// let mut right_child_vals = Vec::with_capacity(nodes.len()); +// let mut leaf_id_vals = Vec::with_capacity(nodes.len()); + +// for node in nodes { +// min_x_vals.push(node.bounds[0]); +// min_y_vals.push(node.bounds[1]); +// max_x_vals.push(node.bounds[2]); +// max_y_vals.push(node.bounds[3]); +// split_dim_vals.push(node.split_dim); +// split_value_vals.push(node.split_value); +// left_child_vals.push(node.left_child); +// right_child_vals.push(node.right_child); +// leaf_id_vals.push(node.leaf_id); +// } + +// let schema = Arc::new(Schema::new(vec![ +// Field::new("min_x", DataType::Float64, false), +// Field::new("min_y", DataType::Float64, false), +// Field::new("max_x", DataType::Float64, false), +// Field::new("max_y", DataType::Float64, false), +// Field::new("split_dim", DataType::UInt8, false), +// Field::new("split_value", DataType::Float64, false), +// Field::new("left_child", DataType::UInt32, true), +// Field::new("right_child", DataType::UInt32, true), +// Field::new("leaf_id", DataType::UInt32, true), +// ])); + +// let columns: Vec = vec![ +// Arc::new(Float64Array::from(min_x_vals)), +// Arc::new(Float64Array::from(min_y_vals)), +// Arc::new(Float64Array::from(max_x_vals)), +// Arc::new(Float64Array::from(max_y_vals)), +// Arc::new(UInt8Array::from(split_dim_vals)), +// Arc::new(Float64Array::from(split_value_vals)), +// Arc::new(UInt32Array::from(left_child_vals)), +// Arc::new(UInt32Array::from(right_child_vals)), +// Arc::new(UInt32Array::from(leaf_id_vals)), +// ]; + +// Ok(RecordBatch::try_new(schema, columns)?) 
+// } + +// #[test] +// fn test_empty_tree_roundtrip() { +// // Create empty tree +// let tree = BKDTreeLookup::new(vec![], 0, 0); + +// // Serialize +// let batch = serialize_nodes(&tree.nodes).unwrap(); +// assert_eq!(batch.num_rows(), 0); + +// // Deserialize +// let deserialized = BKDTreeLookup::from_record_batch(batch).unwrap(); + +// // Verify +// assert_eq!(deserialized.nodes.len(), 0); +// assert_eq!(deserialized.num_leaves, 0); +// assert_eq!(deserialized.root_id, 0); +// } + +// #[test] +// fn test_single_leaf_roundtrip() { +// // Create single leaf node +// let nodes = vec![BKDNode { +// bounds: [1.0, 2.0, 3.0, 4.0], +// split_dim: 0, +// split_value: 0.0, +// left_child: None, +// right_child: None, +// leaf_id: Some(0), +// }]; + +// let tree = BKDTreeLookup::new(nodes.clone(), 0, 1); + +// // Serialize +// let batch = serialize_nodes(&tree.nodes).unwrap(); +// assert_eq!(batch.num_rows(), 1); + +// // Deserialize +// let deserialized = BKDTreeLookup::from_record_batch(batch).unwrap(); + +// // Verify structure +// assert_eq!(deserialized.nodes.len(), 1); +// assert_eq!(deserialized.num_leaves, 1); +// assert_eq!(deserialized.root_id, 0); + +// // Verify node fields +// let node = &deserialized.nodes[0]; +// assert_eq!(node.bounds, [1.0, 2.0, 3.0, 4.0]); +// assert_eq!(node.split_dim, 0); +// assert_eq!(node.split_value, 0.0); +// assert_eq!(node.left_child, None); +// assert_eq!(node.right_child, None); +// assert_eq!(node.leaf_id, Some(0)); +// } + +// #[test] +// fn test_simple_tree_roundtrip() { +// // Create tree: 1 root (inner) + 2 leaves +// let nodes = vec![ +// // Node 0: Root (inner node) +// BKDNode { +// bounds: [0.0, 0.0, 10.0, 10.0], +// split_dim: 0, // Split on X +// split_value: 5.0, +// left_child: Some(1), +// right_child: Some(2), +// leaf_id: None, +// }, +// // Node 1: Left leaf +// BKDNode { +// bounds: [0.0, 0.0, 5.0, 10.0], +// split_dim: 0, +// split_value: 0.0, +// left_child: None, +// right_child: None, +// leaf_id: 
Some(0), +// }, +// // Node 2: Right leaf +// BKDNode { +// bounds: [5.0, 0.0, 10.0, 10.0], +// split_dim: 0, +// split_value: 0.0, +// left_child: None, +// right_child: None, +// leaf_id: Some(1), +// }, +// ]; + +// let tree = BKDTreeLookup::new(nodes.clone(), 0, 2); + +// // Serialize +// let batch = serialize_nodes(&tree.nodes).unwrap(); +// assert_eq!(batch.num_rows(), 3); + +// // Deserialize +// let deserialized = BKDTreeLookup::from_record_batch(batch).unwrap(); + +// // Verify structure +// assert_eq!(deserialized.nodes.len(), 3); +// assert_eq!(deserialized.num_leaves, 2); +// assert_eq!(deserialized.root_id, 0); + +// // Verify root node (inner) +// let root = &deserialized.nodes[0]; +// assert_eq!(root.bounds, [0.0, 0.0, 10.0, 10.0]); +// assert_eq!(root.split_dim, 0); +// assert_eq!(root.split_value, 5.0); +// assert_eq!(root.left_child, Some(1)); +// assert_eq!(root.right_child, Some(2)); +// assert_eq!(root.leaf_id, None); + +// // Verify left leaf +// let left = &deserialized.nodes[1]; +// assert_eq!(left.bounds, [0.0, 0.0, 5.0, 10.0]); +// assert_eq!(left.leaf_id, Some(0)); +// assert_eq!(left.left_child, None); +// assert_eq!(left.right_child, None); + +// // Verify right leaf +// let right = &deserialized.nodes[2]; +// assert_eq!(right.bounds, [5.0, 0.0, 10.0, 10.0]); +// assert_eq!(right.leaf_id, Some(1)); +// assert_eq!(right.left_child, None); +// assert_eq!(right.right_child, None); +// } + +// #[test] +// fn test_multi_level_tree_roundtrip() { +// // Build a real tree from points +// let mut points = vec![ +// (1.0, 1.0, 0), +// (2.0, 2.0, 1), +// (3.0, 3.0, 2), +// (4.0, 4.0, 3), +// (5.0, 5.0, 4), +// (6.0, 6.0, 5), +// (7.0, 7.0, 6), +// (8.0, 8.0, 7), +// (9.0, 9.0, 8), +// (10.0, 10.0, 9), +// ]; + +// let builder = BKDTreeBuilder::new(3); // leaf_size = 3 +// let (nodes, _leaf_batches) = builder.build(&mut points).unwrap(); + +// let original_tree = BKDTreeLookup::new(nodes.clone(), 0, 4); + +// // Serialize +// let batch = 
serialize_nodes(&original_tree.nodes).unwrap(); + +// // Deserialize +// let deserialized = BKDTreeLookup::from_record_batch(batch).unwrap(); + +// // Verify counts +// assert_eq!(deserialized.nodes.len(), original_tree.nodes.len()); +// assert_eq!(deserialized.num_leaves, original_tree.num_leaves); +// assert_eq!(deserialized.root_id, 0); + +// // Verify each node +// for (i, (orig, deser)) in original_tree.nodes.iter().zip(deserialized.nodes.iter()).enumerate() { +// assert_eq!(deser.bounds, orig.bounds, "Node {} bounds mismatch", i); +// assert_eq!(deser.split_dim, orig.split_dim, "Node {} split_dim mismatch", i); +// assert_eq!(deser.split_value, orig.split_value, "Node {} split_value mismatch", i); +// assert_eq!(deser.left_child, orig.left_child, "Node {} left_child mismatch", i); +// assert_eq!(deser.right_child, orig.right_child, "Node {} right_child mismatch", i); +// assert_eq!(deser.leaf_id, orig.leaf_id, "Node {} leaf_id mismatch", i); +// } +// } + +// #[test] +// fn test_field_precision() { +// // Test edge values and precision +// let nodes = vec![ +// BKDNode { +// bounds: [f64::MIN, f64::MAX, -1e-10, 1e10], +// split_dim: 1, +// split_value: std::f64::consts::PI, +// left_child: None, +// right_child: None, +// leaf_id: Some(42), +// }, +// ]; + +// let tree = BKDTreeLookup::new(nodes.clone(), 0, 1); + +// // Serialize and deserialize +// let batch = serialize_nodes(&tree.nodes).unwrap(); +// let deserialized = BKDTreeLookup::from_record_batch(batch).unwrap(); + +// // Verify exact values (no precision loss) +// let node = &deserialized.nodes[0]; +// assert_eq!(node.bounds[0], f64::MIN); +// assert_eq!(node.bounds[1], f64::MAX); +// assert_eq!(node.bounds[2], -1e-10); +// assert_eq!(node.bounds[3], 1e10); +// assert_eq!(node.split_value, std::f64::consts::PI); +// assert_eq!(node.leaf_id, Some(42)); +// } + +// #[test] +// fn test_tree_structure_validation() { +// // Create tree with invalid structure (child pointer out of bounds) +// // This 
should be caught during traversal, not deserialization +// let nodes = vec![ +// BKDNode { +// bounds: [0.0, 0.0, 10.0, 10.0], +// split_dim: 0, +// split_value: 5.0, +// left_child: Some(1), +// right_child: Some(2), +// leaf_id: None, +// }, +// BKDNode { +// bounds: [0.0, 0.0, 5.0, 10.0], +// split_dim: 0, +// split_value: 0.0, +// left_child: None, +// right_child: None, +// leaf_id: Some(0), +// }, +// ]; + +// let tree = BKDTreeLookup::new(nodes.clone(), 0, 1); + +// // Serialize and deserialize +// let batch = serialize_nodes(&tree.nodes).unwrap(); +// let deserialized = BKDTreeLookup::from_record_batch(batch).unwrap(); + +// // Deserialization succeeds +// assert_eq!(deserialized.nodes.len(), 2); + +// // But traversal would fail because right_child=2 is out of bounds +// // This is expected - validation happens at query time +// let query_bbox = [0.0, 0.0, 10.0, 10.0]; +// let leaves = deserialized.find_intersecting_leaves(query_bbox); + +// // Should only find the valid left leaf +// assert_eq!(leaves.len(), 1); +// assert_eq!(leaves[0], 0); +// } + +// #[test] +// fn test_nullable_fields() { +// // Test that nullable fields work correctly +// let nodes = vec![ +// // Inner node: has children, no leaf_id +// BKDNode { +// bounds: [0.0, 0.0, 10.0, 10.0], +// split_dim: 0, +// split_value: 5.0, +// left_child: Some(1), +// right_child: Some(2), +// leaf_id: None, // NULL +// }, +// // Leaf node: has leaf_id, no children +// BKDNode { +// bounds: [0.0, 0.0, 5.0, 10.0], +// split_dim: 0, +// split_value: 0.0, +// left_child: None, // NULL +// right_child: None, // NULL +// leaf_id: Some(0), +// }, +// // Another leaf +// BKDNode { +// bounds: [5.0, 0.0, 10.0, 10.0], +// split_dim: 0, +// split_value: 0.0, +// left_child: None, +// right_child: None, +// leaf_id: Some(1), +// }, +// ]; + +// let tree = BKDTreeLookup::new(nodes.clone(), 0, 2); + +// // Serialize +// let batch = serialize_nodes(&tree.nodes).unwrap(); + +// // Check that nulls are properly 
represented +// let left_child_col = batch.column(6).as_primitive::(); +// let right_child_col = batch.column(7).as_primitive::(); +// let leaf_id_col = batch.column(8).as_primitive::(); + +// // Row 0 (inner): children NOT null, leaf_id IS null +// assert!(!left_child_col.is_null(0)); +// assert!(!right_child_col.is_null(0)); +// assert!(leaf_id_col.is_null(0)); + +// // Row 1 (leaf): children ARE null, leaf_id NOT null +// assert!(left_child_col.is_null(1)); +// assert!(right_child_col.is_null(1)); +// assert!(!leaf_id_col.is_null(1)); + +// // Deserialize and verify +// let deserialized = BKDTreeLookup::from_record_batch(batch).unwrap(); + +// assert_eq!(deserialized.nodes[0].left_child, Some(1)); +// assert_eq!(deserialized.nodes[0].right_child, Some(2)); +// assert_eq!(deserialized.nodes[0].leaf_id, None); + +// assert_eq!(deserialized.nodes[1].left_child, None); +// assert_eq!(deserialized.nodes[1].right_child, None); +// assert_eq!(deserialized.nodes[1].leaf_id, Some(0)); +// } +// } + diff --git a/rust/lance-index/src/scalar/geoindex.rs b/rust/lance-index/src/scalar/geoindex.rs index 49fcd83832a..e3119565676 100644 --- a/rust/lance-index/src/scalar/geoindex.rs +++ b/rust/lance-index/src/scalar/geoindex.rs @@ -33,7 +33,7 @@ use lance_core::cache::{CacheKey, LanceCache, WeakLanceCache}; use lance_core::{ROW_ADDR, ROW_ID}; use serde::{Deserialize, Serialize}; -use arrow_array::{Array, ArrayRef, Float64Array, RecordBatch, UInt32Array, UInt8Array}; +use arrow_array::{ArrayRef, Float64Array, RecordBatch, UInt32Array, UInt64Array, UInt8Array}; use arrow_array::cast::AsArray; use arrow_schema::{DataType, Field, Schema}; use datafusion::execution::SendableRecordBatchStream; @@ -50,14 +50,19 @@ use lance_core::{utils::mask::RowIdTreeMap, Error}; use roaring::RoaringBitmap; use snafu::location; -const BKD_TREE_FILENAME: &str = "bkd_tree.lance"; -const LEAF_FILENAME_PREFIX: &str = "leaf_"; +const BKD_TREE_INNER_FILENAME: &str = "bkd_tree_inner.lance"; +const 
BKD_TREE_LEAF_FILENAME: &str = "bkd_tree_leaf.lance"; +const LEAF_GROUP_PREFIX: &str = "leaf_group_"; +const DEFAULT_BATCHES_PER_LEAF_FILE: u32 = 5; // Default number of leaf batches per file const GEO_INDEX_VERSION: u32 = 0; const LEAF_SIZE_META_KEY: &str = "leaf_size"; -const DEFAULT_LEAF_SIZE: u32 = 100; +const BATCHES_PER_FILE_META_KEY: &str = "batches_per_file"; +const DEFAULT_LEAF_SIZE: u32 = 100; // for test +// const DEFAULT_LEAF_SIZE: u32 = 1024; // for production -fn leaf_filename(leaf_id: u32) -> String { - format!("{}{}.lance", LEAF_FILENAME_PREFIX, leaf_id) +/// Get the file name for a leaf group +fn leaf_group_filename(group_id: u32) -> String { + format!("{}{}.lance", LEAF_GROUP_PREFIX, group_id) } /// Lazy reader for BKD leaf file @@ -161,17 +166,23 @@ impl GeoIndex { where Self: Sized, { - // Load BKD tree structure (inner nodes) - let tree_file = store.open_index_file(BKD_TREE_FILENAME).await?; - let tree_data = tree_file - .read_range(0..tree_file.num_rows(), None) + // Load inner nodes + let inner_file = store.open_index_file(BKD_TREE_INNER_FILENAME).await?; + let inner_data = inner_file + .read_range(0..inner_file.num_rows(), None) .await?; - // Deserialize tree structure - let bkd_tree = BKDTreeLookup::from_record_batch(tree_data)?; + // Load leaf metadata + let leaf_file = store.open_index_file(BKD_TREE_LEAF_FILENAME).await?; + let leaf_data = leaf_file + .read_range(0..leaf_file.num_rows(), None) + .await?; + + // Deserialize tree structure from both files + let bkd_tree = BKDTreeLookup::from_record_batches(inner_data, leaf_data)?; - // Extract metadata - let schema = tree_file.schema(); + // Extract metadata from inner file + let schema = inner_file.schema(); let leaf_size = schema .metadata .get(LEAF_SIZE_META_KEY) @@ -191,25 +202,35 @@ impl GeoIndex { })) } - /// Load a specific leaf from storage (from leaf_{id}.lance file) + /// Load a specific leaf from storage async fn load_leaf( &self, - leaf_id: u32, + leaf: 
&crate::scalar::bkd::BKDLeafNode, metrics: &dyn MetricsCollector, ) -> Result { - // Check cache first - let cache_key = BKDLeafKey { leaf_id }; + let file_id = leaf.file_id; + let row_offset = leaf.row_offset; + let num_rows = leaf.num_rows; + + // Use (file_id, row_offset) as cache key + // Combine file_id and row_offset into a single u32 (file_id should be small) + let cache_key = BKDLeafKey { leaf_id: file_id * 100_000 + (row_offset as u32) }; let store = self.store.clone(); let cached = self .index_cache .get_or_insert_with_key(cache_key, move || async move { metrics.record_part_load(); - let filename = leaf_filename(leaf_id); + + let filename = leaf_group_filename(file_id); + + // Open the leaf group file and read the specific row range let reader = store.open_index_file(&filename).await?; - // Read the entire leaf file - let num_rows = reader.num_rows(); - let batch = reader.read_range(0..num_rows, None).await?; + let batch = reader.read_range( + row_offset as usize..(row_offset + num_rows) as usize, + None + ).await?; + Ok(CachedLeafData::new(batch)) }) .await?; @@ -220,15 +241,21 @@ impl GeoIndex { /// Search a specific leaf for points within the query bbox async fn search_leaf( &self, - leaf_id: u32, + leaf: &crate::scalar::bkd::BKDLeafNode, query_bbox: [f64; 4], metrics: &dyn MetricsCollector, ) -> Result { - let leaf_data = self.load_leaf(leaf_id, metrics).await?; + let leaf_data = self.load_leaf(leaf, metrics).await?; + let file_id = leaf.file_id; + let row_offset = leaf.row_offset; + let num_rows = leaf.num_rows; + println!( - "šŸ” Searching leaf {} with {} points, query_bbox: {:?}", - leaf_id, + "šŸ” Searching leaf (file={}, offset={}, rows={}) with {} points, query_bbox: {:?}", + file_id, + row_offset, + num_rows, leaf_data.num_rows(), query_bbox ); @@ -246,16 +273,27 @@ impl GeoIndex { .as_primitive::(); let mut matched_count = 0; + + // Debug: Check if SF (row_id=0) is in this leaf's actual data + let contains_sf = 
(0..leaf_data.num_rows()).any(|i| row_id_array.value(i) == 0); + if contains_sf { + println!(" šŸ”Ž Leaf (file={}, offset={}) contains SF (row_id=0)!", file_id, row_offset); + for i in 0..leaf_data.num_rows() { + if row_id_array.value(i) == 0 { + let x = x_array.value(i); + let y = y_array.value(i); + println!(" šŸŽÆ Found SF at index {}: ({}, {}) row_id=0", i, x, y); + println!(" šŸŽÆ point_in_bbox result: {}", point_in_bbox(x, y, &query_bbox)); + break; + } + } + } + for i in 0..leaf_data.num_rows() { let x = x_array.value(i); let y = y_array.value(i); let row_id = row_id_array.value(i); - // Debug: dump all points in leaf 16 to find San Francisco - if leaf_id == 16 && i < 10 { - println!(" šŸ”Ž Leaf 16 point {}: ({}, {}) row_id={}", i, x, y, row_id); - } - if point_in_bbox(x, y, &query_bbox) { row_ids.insert(row_id); matched_count += 1; @@ -271,8 +309,9 @@ impl GeoIndex { } println!( - "šŸ“Š Leaf {} matched {} out of {} points", - leaf_id, + "šŸ“Š Leaf (file={}, offset={}) matched {} out of {} points", + file_id, + row_offset, matched_count, leaf_data.num_rows() ); @@ -341,20 +380,20 @@ impl ScalarIndex for GeoIndex { ); // Step 1: Find intersecting leaves using in-memory tree traversal - let leaf_ids = self.bkd_tree.find_intersecting_leaves(query_bbox); + let leaves = self.bkd_tree.find_intersecting_leaves(query_bbox)?; println!( "šŸ“Š BKD tree traversal found {} intersecting leaves out of {} total leaves", - leaf_ids.len(), + leaves.len(), self.bkd_tree.num_leaves ); // Step 2: Lazy-load and filter each leaf let mut all_row_ids = RowIdTreeMap::new(); - for leaf_id in &leaf_ids { + for leaf_node in &leaves { let leaf_row_ids = self - .search_leaf(*leaf_id, query_bbox, metrics) + .search_leaf(leaf_node, query_bbox, metrics) .await?; // Collect row IDs from the leaf and add them to the result set let row_ids: Option> = leaf_row_ids.row_ids() @@ -366,7 +405,7 @@ impl ScalarIndex for GeoIndex { println!( "āœ… Geo index searched {} leaves and returning {} row 
IDs\n", - leaf_ids.len(), + leaves.len(), all_row_ids.len().unwrap_or(0) ); @@ -510,37 +549,63 @@ impl GeoIndexBuilder { // Build BKD tree let (tree_nodes, leaf_batches) = self.build_bkd_tree()?; - // Write tree structure - let tree_batch = self.serialize_tree_nodes(&tree_nodes)?; - let mut tree_file = index_store - .new_index_file(BKD_TREE_FILENAME, tree_batch.schema()) + // Write tree structure to separate inner and leaf files + let (inner_batch, leaf_metadata_batch) = self.serialize_tree_nodes(&tree_nodes)?; + + // Write inner nodes + let mut inner_file = index_store + .new_index_file(BKD_TREE_INNER_FILENAME, inner_batch.schema()) .await?; - tree_file.write_record_batch(tree_batch).await?; - tree_file + inner_file.write_record_batch(inner_batch).await?; + inner_file .finish_with_metadata(HashMap::from([( LEAF_SIZE_META_KEY.to_string(), self.options.leaf_size.to_string(), )])) .await?; - // Write each leaf to a separate file + // Write leaf metadata + let mut leaf_meta_file = index_store + .new_index_file(BKD_TREE_LEAF_FILENAME, leaf_metadata_batch.schema()) + .await?; + leaf_meta_file.write_record_batch(leaf_metadata_batch).await?; + leaf_meta_file.finish().await?; + + // Write actual leaf data grouped into files (multiple batches per file) let leaf_schema = Arc::new(Schema::new(vec![ Field::new("x", DataType::Float64, false), Field::new("y", DataType::Float64, false), Field::new(ROW_ID, DataType::UInt64, false), ])); - println!("šŸ“ Writing {} leaf files", leaf_batches.len()); - for (leaf_id, leaf_batch) in leaf_batches.iter().enumerate() { - let filename = leaf_filename(leaf_id as u32); - println!(" Writing {}: {} rows", filename, leaf_batch.num_rows()); + let num_groups = (leaf_batches.len() as u32 + DEFAULT_BATCHES_PER_LEAF_FILE - 1) / DEFAULT_BATCHES_PER_LEAF_FILE; + println!("šŸ“ Writing {} leaf batches into {} group files ({} batches per file)", + leaf_batches.len(), num_groups, DEFAULT_BATCHES_PER_LEAF_FILE); + + for group_id in 0..num_groups { + let 
start_idx = (group_id * DEFAULT_BATCHES_PER_LEAF_FILE) as usize; + let end_idx = ((group_id + 1) * DEFAULT_BATCHES_PER_LEAF_FILE).min(leaf_batches.len() as u32) as usize; + let group_batches = &leaf_batches[start_idx..end_idx]; + + let filename = leaf_group_filename(group_id); + println!(" Writing {}: {} batches ({} rows total)", + filename, + group_batches.len(), + group_batches.iter().map(|b| b.num_rows()).sum::()); + let mut leaf_file = index_store .new_index_file(&filename, leaf_schema.clone()) .await?; - leaf_file.write_record_batch(leaf_batch.clone()).await?; + + for (batch_idx, leaf_batch) in group_batches.iter().enumerate() { + let batch_id = leaf_file.write_record_batch(leaf_batch.clone()).await?; + println!(" Batch {}: {} rows (batch_id={})", + start_idx + batch_idx, leaf_batch.num_rows(), batch_id); + } + leaf_file.finish().await?; } - println!("āœ… Finished writing {} leaf files\n", leaf_batches.len()); + println!("āœ… Finished writing {} group files\n", num_groups); log::debug!( "Wrote BKD tree with {} nodes", @@ -551,32 +616,57 @@ impl GeoIndexBuilder { } async fn write_empty_index(&self, index_store: &dyn IndexStore) -> Result<()> { - // Write empty tree file - let tree_schema = Arc::new(Schema::new(vec![ - Field::new("min_x", DataType::Float64, false), - Field::new("min_y", DataType::Float64, false), - Field::new("max_x", DataType::Float64, false), - Field::new("max_y", DataType::Float64, false), - Field::new("split_dim", DataType::UInt8, false), - Field::new("split_value", DataType::Float64, false), - Field::new("left_child", DataType::UInt32, true), - Field::new("right_child", DataType::UInt32, true), - Field::new("leaf_id", DataType::UInt32, true), - ])); - - let empty_batch = RecordBatch::new_empty(tree_schema); - let mut tree_file = index_store - .new_index_file(BKD_TREE_FILENAME, empty_batch.schema()) + // Write empty inner node file + let inner_schema = crate::scalar::bkd::inner_node_schema(); + let empty_inner = 
RecordBatch::new_empty(inner_schema); + let mut inner_file = index_store + .new_index_file(BKD_TREE_INNER_FILENAME, empty_inner.schema()) + .await?; + inner_file.write_record_batch(empty_inner).await?; + inner_file.finish().await?; + + // Write empty leaf metadata file + let leaf_schema = crate::scalar::bkd::leaf_node_schema(); + let empty_leaf = RecordBatch::new_empty(leaf_schema); + let mut leaf_file = index_store + .new_index_file(BKD_TREE_LEAF_FILENAME, empty_leaf.schema()) .await?; - tree_file.write_record_batch(empty_batch).await?; - tree_file.finish().await?; + leaf_file.write_record_batch(empty_leaf).await?; + leaf_file.finish().await?; - // No leaf files needed for empty index + // No actual leaf data files needed for empty index Ok(()) } - fn serialize_tree_nodes(&self, nodes: &[crate::scalar::bkd::BKDNode]) -> Result { + // Serialize tree nodes to separate inner and leaf RecordBatches + fn serialize_tree_nodes(&self, nodes: &[crate::scalar::bkd::BKDNode]) -> Result<(RecordBatch, RecordBatch)> { + use crate::scalar::bkd::BKDNode; + + // Separate inner and leaf nodes with their indices + let mut inner_nodes = Vec::new(); + let mut leaf_nodes = Vec::new(); + + for (idx, node) in nodes.iter().enumerate() { + match node { + BKDNode::Inner(_) => inner_nodes.push((idx as u32, node)), + BKDNode::Leaf(_) => leaf_nodes.push((idx as u32, node)), + } + } + + // Serialize inner nodes + let inner_batch = Self::serialize_inner_nodes(&inner_nodes)?; + + // Serialize leaf nodes + let leaf_batch = Self::serialize_leaf_nodes(&leaf_nodes)?; + + Ok((inner_batch, leaf_batch)) + } + + fn serialize_inner_nodes(nodes: &[(u32, &crate::scalar::bkd::BKDNode)]) -> Result { + use crate::scalar::bkd::BKDNode; + + let mut node_id_vals = Vec::with_capacity(nodes.len()); let mut min_x_vals = Vec::with_capacity(nodes.len()); let mut min_y_vals = Vec::with_capacity(nodes.len()); let mut max_x_vals = Vec::with_capacity(nodes.len()); @@ -585,33 +675,25 @@ impl GeoIndexBuilder { let mut 
split_value_vals = Vec::with_capacity(nodes.len()); let mut left_child_vals = Vec::with_capacity(nodes.len()); let mut right_child_vals = Vec::with_capacity(nodes.len()); - let mut leaf_id_vals = Vec::with_capacity(nodes.len()); - - for node in nodes { - min_x_vals.push(node.bounds[0]); - min_y_vals.push(node.bounds[1]); - max_x_vals.push(node.bounds[2]); - max_y_vals.push(node.bounds[3]); - split_dim_vals.push(node.split_dim); - split_value_vals.push(node.split_value); - left_child_vals.push(node.left_child); - right_child_vals.push(node.right_child); - leaf_id_vals.push(node.leaf_id); + + for (idx, node) in nodes { + if let BKDNode::Inner(inner) = node { + node_id_vals.push(*idx); + min_x_vals.push(inner.bounds[0]); + min_y_vals.push(inner.bounds[1]); + max_x_vals.push(inner.bounds[2]); + max_y_vals.push(inner.bounds[3]); + split_dim_vals.push(inner.split_dim); + split_value_vals.push(inner.split_value); + left_child_vals.push(inner.left_child); + right_child_vals.push(inner.right_child); + } } - let schema = Arc::new(Schema::new(vec![ - Field::new("min_x", DataType::Float64, false), - Field::new("min_y", DataType::Float64, false), - Field::new("max_x", DataType::Float64, false), - Field::new("max_y", DataType::Float64, false), - Field::new("split_dim", DataType::UInt8, false), - Field::new("split_value", DataType::Float64, false), - Field::new("left_child", DataType::UInt32, true), - Field::new("right_child", DataType::UInt32, true), - Field::new("leaf_id", DataType::UInt32, true), - ])); + let schema = crate::scalar::bkd::inner_node_schema(); let columns: Vec = vec![ + Arc::new(UInt32Array::from(node_id_vals)), Arc::new(Float64Array::from(min_x_vals)), Arc::new(Float64Array::from(min_y_vals)), Arc::new(Float64Array::from(max_x_vals)), @@ -620,7 +702,48 @@ impl GeoIndexBuilder { Arc::new(Float64Array::from(split_value_vals)), Arc::new(UInt32Array::from(left_child_vals)), Arc::new(UInt32Array::from(right_child_vals)), - Arc::new(UInt32Array::from(leaf_id_vals)), 
+ ]; + + Ok(RecordBatch::try_new(schema, columns)?) + } + + fn serialize_leaf_nodes(nodes: &[(u32, &crate::scalar::bkd::BKDNode)]) -> Result { + use crate::scalar::bkd::BKDNode; + use arrow_array::UInt64Array; + + let mut node_id_vals = Vec::with_capacity(nodes.len()); + let mut min_x_vals = Vec::with_capacity(nodes.len()); + let mut min_y_vals = Vec::with_capacity(nodes.len()); + let mut max_x_vals = Vec::with_capacity(nodes.len()); + let mut max_y_vals = Vec::with_capacity(nodes.len()); + let mut file_id_vals = Vec::with_capacity(nodes.len()); + let mut row_offset_vals = Vec::with_capacity(nodes.len()); + let mut num_rows_vals = Vec::with_capacity(nodes.len()); + + for (idx, node) in nodes { + if let BKDNode::Leaf(leaf) = node { + node_id_vals.push(*idx); + min_x_vals.push(leaf.bounds[0]); + min_y_vals.push(leaf.bounds[1]); + max_x_vals.push(leaf.bounds[2]); + max_y_vals.push(leaf.bounds[3]); + file_id_vals.push(leaf.file_id); + row_offset_vals.push(leaf.row_offset); + num_rows_vals.push(leaf.num_rows); + } + } + + let schema = crate::scalar::bkd::leaf_node_schema(); + + let columns: Vec = vec![ + Arc::new(UInt32Array::from(node_id_vals)), + Arc::new(Float64Array::from(min_x_vals)), + Arc::new(Float64Array::from(min_y_vals)), + Arc::new(Float64Array::from(max_x_vals)), + Arc::new(Float64Array::from(max_y_vals)), + Arc::new(UInt32Array::from(file_id_vals)), + Arc::new(UInt64Array::from(row_offset_vals)), + Arc::new(UInt64Array::from(num_rows_vals)), ]; Ok(RecordBatch::try_new(schema, columns)?) 
@@ -629,7 +752,7 @@ impl GeoIndexBuilder { // Build BKD tree using the BKDTreeBuilder fn build_bkd_tree(&mut self) -> Result<(Vec, Vec)> { let builder = BKDTreeBuilder::new(self.options.leaf_size as usize); - builder.build(&mut self.points) + builder.build(&mut self.points, DEFAULT_BATCHES_PER_LEAF_FILE) } } From 54beb19b1eefd678461e87cc21718cacb43f6415 Mon Sep 17 00:00:00 2001 From: jaystarshot Date: Thu, 16 Oct 2025 09:47:41 -0700 Subject: [PATCH 4/7] Add unit tests for BKD tree build, serialization, and bbox queries --- rust/lance-index/src/scalar.rs | 3 +- rust/lance-index/src/scalar/{ => geo}/bkd.rs | 682 ++++--- rust/lance-index/src/scalar/geo/geoindex.rs | 1677 ++++++++++++++++++ rust/lance-index/src/scalar/geo/mod.rs | 15 + rust/lance-index/src/scalar/geoindex.rs | 947 ---------- rust/lance-index/src/scalar/registry.rs | 2 +- 6 files changed, 2015 insertions(+), 1311 deletions(-) rename rust/lance-index/src/scalar/{ => geo}/bkd.rs (59%) create mode 100644 rust/lance-index/src/scalar/geo/geoindex.rs create mode 100644 rust/lance-index/src/scalar/geo/mod.rs delete mode 100644 rust/lance-index/src/scalar/geoindex.rs diff --git a/rust/lance-index/src/scalar.rs b/rust/lance-index/src/scalar.rs index a9ae40cabf0..0af533411af 100644 --- a/rust/lance-index/src/scalar.rs +++ b/rust/lance-index/src/scalar.rs @@ -28,13 +28,12 @@ use crate::metrics::MetricsCollector; use crate::scalar::registry::TrainingCriteria; use crate::{Index, IndexParams, IndexType}; -pub mod bkd; pub mod bitmap; pub mod bloomfilter; pub mod btree; pub mod expression; pub mod flat; -pub mod geoindex; +pub mod geo; pub mod inverted; pub mod json; pub mod label_list; diff --git a/rust/lance-index/src/scalar/bkd.rs b/rust/lance-index/src/scalar/geo/bkd.rs similarity index 59% rename from rust/lance-index/src/scalar/bkd.rs rename to rust/lance-index/src/scalar/geo/bkd.rs index b885bbe6898..948c27afc70 100644 --- a/rust/lance-index/src/scalar/bkd.rs +++ b/rust/lance-index/src/scalar/geo/bkd.rs @@ -557,365 +557,325 @@ fn create_leaf_batch(points:
&[(f64, f64, u64)]) -> Result { Ok(batch) } -// #[cfg(test)] -// mod tests { -// use super::*; -// use arrow_array::{UInt32Array, UInt8Array}; - -// /// Helper to serialize nodes (mirrors logic from geoindex.rs) -// fn serialize_nodes(nodes: &[BKDNode]) -> Result { -// let mut min_x_vals = Vec::with_capacity(nodes.len()); -// let mut min_y_vals = Vec::with_capacity(nodes.len()); -// let mut max_x_vals = Vec::with_capacity(nodes.len()); -// let mut max_y_vals = Vec::with_capacity(nodes.len()); -// let mut split_dim_vals = Vec::with_capacity(nodes.len()); -// let mut split_value_vals = Vec::with_capacity(nodes.len()); -// let mut left_child_vals = Vec::with_capacity(nodes.len()); -// let mut right_child_vals = Vec::with_capacity(nodes.len()); -// let mut leaf_id_vals = Vec::with_capacity(nodes.len()); - -// for node in nodes { -// min_x_vals.push(node.bounds[0]); -// min_y_vals.push(node.bounds[1]); -// max_x_vals.push(node.bounds[2]); -// max_y_vals.push(node.bounds[3]); -// split_dim_vals.push(node.split_dim); -// split_value_vals.push(node.split_value); -// left_child_vals.push(node.left_child); -// right_child_vals.push(node.right_child); -// leaf_id_vals.push(node.leaf_id); -// } - -// let schema = Arc::new(Schema::new(vec![ -// Field::new("min_x", DataType::Float64, false), -// Field::new("min_y", DataType::Float64, false), -// Field::new("max_x", DataType::Float64, false), -// Field::new("max_y", DataType::Float64, false), -// Field::new("split_dim", DataType::UInt8, false), -// Field::new("split_value", DataType::Float64, false), -// Field::new("left_child", DataType::UInt32, true), -// Field::new("right_child", DataType::UInt32, true), -// Field::new("leaf_id", DataType::UInt32, true), -// ])); - -// let columns: Vec = vec![ -// Arc::new(Float64Array::from(min_x_vals)), -// Arc::new(Float64Array::from(min_y_vals)), -// Arc::new(Float64Array::from(max_x_vals)), -// Arc::new(Float64Array::from(max_y_vals)), -// Arc::new(UInt8Array::from(split_dim_vals)), -// 
Arc::new(Float64Array::from(split_value_vals)), -// Arc::new(UInt32Array::from(left_child_vals)), -// Arc::new(UInt32Array::from(right_child_vals)), -// Arc::new(UInt32Array::from(leaf_id_vals)), -// ]; - -// Ok(RecordBatch::try_new(schema, columns)?) -// } - -// #[test] -// fn test_empty_tree_roundtrip() { -// // Create empty tree -// let tree = BKDTreeLookup::new(vec![], 0, 0); - -// // Serialize -// let batch = serialize_nodes(&tree.nodes).unwrap(); -// assert_eq!(batch.num_rows(), 0); - -// // Deserialize -// let deserialized = BKDTreeLookup::from_record_batch(batch).unwrap(); - -// // Verify -// assert_eq!(deserialized.nodes.len(), 0); -// assert_eq!(deserialized.num_leaves, 0); -// assert_eq!(deserialized.root_id, 0); -// } - -// #[test] -// fn test_single_leaf_roundtrip() { -// // Create single leaf node -// let nodes = vec![BKDNode { -// bounds: [1.0, 2.0, 3.0, 4.0], -// split_dim: 0, -// split_value: 0.0, -// left_child: None, -// right_child: None, -// leaf_id: Some(0), -// }]; - -// let tree = BKDTreeLookup::new(nodes.clone(), 0, 1); - -// // Serialize -// let batch = serialize_nodes(&tree.nodes).unwrap(); -// assert_eq!(batch.num_rows(), 1); - -// // Deserialize -// let deserialized = BKDTreeLookup::from_record_batch(batch).unwrap(); - -// // Verify structure -// assert_eq!(deserialized.nodes.len(), 1); -// assert_eq!(deserialized.num_leaves, 1); -// assert_eq!(deserialized.root_id, 0); - -// // Verify node fields -// let node = &deserialized.nodes[0]; -// assert_eq!(node.bounds, [1.0, 2.0, 3.0, 4.0]); -// assert_eq!(node.split_dim, 0); -// assert_eq!(node.split_value, 0.0); -// assert_eq!(node.left_child, None); -// assert_eq!(node.right_child, None); -// assert_eq!(node.leaf_id, Some(0)); -// } - -// #[test] -// fn test_simple_tree_roundtrip() { -// // Create tree: 1 root (inner) + 2 leaves -// let nodes = vec![ -// // Node 0: Root (inner node) -// BKDNode { -// bounds: [0.0, 0.0, 10.0, 10.0], -// split_dim: 0, // Split on X -// split_value: 5.0, -// 
left_child: Some(1), -// right_child: Some(2), -// leaf_id: None, -// }, -// // Node 1: Left leaf -// BKDNode { -// bounds: [0.0, 0.0, 5.0, 10.0], -// split_dim: 0, -// split_value: 0.0, -// left_child: None, -// right_child: None, -// leaf_id: Some(0), -// }, -// // Node 2: Right leaf -// BKDNode { -// bounds: [5.0, 0.0, 10.0, 10.0], -// split_dim: 0, -// split_value: 0.0, -// left_child: None, -// right_child: None, -// leaf_id: Some(1), -// }, -// ]; - -// let tree = BKDTreeLookup::new(nodes.clone(), 0, 2); - -// // Serialize -// let batch = serialize_nodes(&tree.nodes).unwrap(); -// assert_eq!(batch.num_rows(), 3); - -// // Deserialize -// let deserialized = BKDTreeLookup::from_record_batch(batch).unwrap(); - -// // Verify structure -// assert_eq!(deserialized.nodes.len(), 3); -// assert_eq!(deserialized.num_leaves, 2); -// assert_eq!(deserialized.root_id, 0); - -// // Verify root node (inner) -// let root = &deserialized.nodes[0]; -// assert_eq!(root.bounds, [0.0, 0.0, 10.0, 10.0]); -// assert_eq!(root.split_dim, 0); -// assert_eq!(root.split_value, 5.0); -// assert_eq!(root.left_child, Some(1)); -// assert_eq!(root.right_child, Some(2)); -// assert_eq!(root.leaf_id, None); - -// // Verify left leaf -// let left = &deserialized.nodes[1]; -// assert_eq!(left.bounds, [0.0, 0.0, 5.0, 10.0]); -// assert_eq!(left.leaf_id, Some(0)); -// assert_eq!(left.left_child, None); -// assert_eq!(left.right_child, None); - -// // Verify right leaf -// let right = &deserialized.nodes[2]; -// assert_eq!(right.bounds, [5.0, 0.0, 10.0, 10.0]); -// assert_eq!(right.leaf_id, Some(1)); -// assert_eq!(right.left_child, None); -// assert_eq!(right.right_child, None); -// } - -// #[test] -// fn test_multi_level_tree_roundtrip() { -// // Build a real tree from points -// let mut points = vec![ -// (1.0, 1.0, 0), -// (2.0, 2.0, 1), -// (3.0, 3.0, 2), -// (4.0, 4.0, 3), -// (5.0, 5.0, 4), -// (6.0, 6.0, 5), -// (7.0, 7.0, 6), -// (8.0, 8.0, 7), -// (9.0, 9.0, 8), -// (10.0, 10.0, 9), -// 
]; - -// let builder = BKDTreeBuilder::new(3); // leaf_size = 3 -// let (nodes, _leaf_batches) = builder.build(&mut points).unwrap(); - -// let original_tree = BKDTreeLookup::new(nodes.clone(), 0, 4); - -// // Serialize -// let batch = serialize_nodes(&original_tree.nodes).unwrap(); - -// // Deserialize -// let deserialized = BKDTreeLookup::from_record_batch(batch).unwrap(); - -// // Verify counts -// assert_eq!(deserialized.nodes.len(), original_tree.nodes.len()); -// assert_eq!(deserialized.num_leaves, original_tree.num_leaves); -// assert_eq!(deserialized.root_id, 0); - -// // Verify each node -// for (i, (orig, deser)) in original_tree.nodes.iter().zip(deserialized.nodes.iter()).enumerate() { -// assert_eq!(deser.bounds, orig.bounds, "Node {} bounds mismatch", i); -// assert_eq!(deser.split_dim, orig.split_dim, "Node {} split_dim mismatch", i); -// assert_eq!(deser.split_value, orig.split_value, "Node {} split_value mismatch", i); -// assert_eq!(deser.left_child, orig.left_child, "Node {} left_child mismatch", i); -// assert_eq!(deser.right_child, orig.right_child, "Node {} right_child mismatch", i); -// assert_eq!(deser.leaf_id, orig.leaf_id, "Node {} leaf_id mismatch", i); -// } -// } - -// #[test] -// fn test_field_precision() { -// // Test edge values and precision -// let nodes = vec![ -// BKDNode { -// bounds: [f64::MIN, f64::MAX, -1e-10, 1e10], -// split_dim: 1, -// split_value: std::f64::consts::PI, -// left_child: None, -// right_child: None, -// leaf_id: Some(42), -// }, -// ]; - -// let tree = BKDTreeLookup::new(nodes.clone(), 0, 1); - -// // Serialize and deserialize -// let batch = serialize_nodes(&tree.nodes).unwrap(); -// let deserialized = BKDTreeLookup::from_record_batch(batch).unwrap(); - -// // Verify exact values (no precision loss) -// let node = &deserialized.nodes[0]; -// assert_eq!(node.bounds[0], f64::MIN); -// assert_eq!(node.bounds[1], f64::MAX); -// assert_eq!(node.bounds[2], -1e-10); -// assert_eq!(node.bounds[3], 1e10); -// 
assert_eq!(node.split_value, std::f64::consts::PI); -// assert_eq!(node.leaf_id, Some(42)); -// } - -// #[test] -// fn test_tree_structure_validation() { -// // Create tree with invalid structure (child pointer out of bounds) -// // This should be caught during traversal, not deserialization -// let nodes = vec![ -// BKDNode { -// bounds: [0.0, 0.0, 10.0, 10.0], -// split_dim: 0, -// split_value: 5.0, -// left_child: Some(1), -// right_child: Some(2), -// leaf_id: None, -// }, -// BKDNode { -// bounds: [0.0, 0.0, 5.0, 10.0], -// split_dim: 0, -// split_value: 0.0, -// left_child: None, -// right_child: None, -// leaf_id: Some(0), -// }, -// ]; - -// let tree = BKDTreeLookup::new(nodes.clone(), 0, 1); - -// // Serialize and deserialize -// let batch = serialize_nodes(&tree.nodes).unwrap(); -// let deserialized = BKDTreeLookup::from_record_batch(batch).unwrap(); - -// // Deserialization succeeds -// assert_eq!(deserialized.nodes.len(), 2); - -// // But traversal would fail because right_child=2 is out of bounds -// // This is expected - validation happens at query time -// let query_bbox = [0.0, 0.0, 10.0, 10.0]; -// let leaves = deserialized.find_intersecting_leaves(query_bbox); - -// // Should only find the valid left leaf -// assert_eq!(leaves.len(), 1); -// assert_eq!(leaves[0], 0); -// } - -// #[test] -// fn test_nullable_fields() { -// // Test that nullable fields work correctly -// let nodes = vec![ -// // Inner node: has children, no leaf_id -// BKDNode { -// bounds: [0.0, 0.0, 10.0, 10.0], -// split_dim: 0, -// split_value: 5.0, -// left_child: Some(1), -// right_child: Some(2), -// leaf_id: None, // NULL -// }, -// // Leaf node: has leaf_id, no children -// BKDNode { -// bounds: [0.0, 0.0, 5.0, 10.0], -// split_dim: 0, -// split_value: 0.0, -// left_child: None, // NULL -// right_child: None, // NULL -// leaf_id: Some(0), -// }, -// // Another leaf -// BKDNode { -// bounds: [5.0, 0.0, 10.0, 10.0], -// split_dim: 0, -// split_value: 0.0, -// left_child: 
None, -// right_child: None, -// leaf_id: Some(1), -// }, -// ]; - -// let tree = BKDTreeLookup::new(nodes.clone(), 0, 2); - -// // Serialize -// let batch = serialize_nodes(&tree.nodes).unwrap(); - -// // Check that nulls are properly represented -// let left_child_col = batch.column(6).as_primitive::(); -// let right_child_col = batch.column(7).as_primitive::(); -// let leaf_id_col = batch.column(8).as_primitive::(); - -// // Row 0 (inner): children NOT null, leaf_id IS null -// assert!(!left_child_col.is_null(0)); -// assert!(!right_child_col.is_null(0)); -// assert!(leaf_id_col.is_null(0)); - -// // Row 1 (leaf): children ARE null, leaf_id NOT null -// assert!(left_child_col.is_null(1)); -// assert!(right_child_col.is_null(1)); -// assert!(!leaf_id_col.is_null(1)); - -// // Deserialize and verify -// let deserialized = BKDTreeLookup::from_record_batch(batch).unwrap(); - -// assert_eq!(deserialized.nodes[0].left_child, Some(1)); -// assert_eq!(deserialized.nodes[0].right_child, Some(2)); -// assert_eq!(deserialized.nodes[0].leaf_id, None); - -// assert_eq!(deserialized.nodes[1].left_child, None); -// assert_eq!(deserialized.nodes[1].right_child, None); -// assert_eq!(deserialized.nodes[1].leaf_id, Some(0)); -// } -// } + + +#[cfg(test)] +mod tests { + use super::*; + use arrow_array::{Float64Array, UInt32Array, UInt64Array, UInt8Array}; + + /// Test serialization and deserialization of a simple BKD tree + #[test] + fn test_serialize_deserialize_simple_tree() { + // Create a simple tree: 1 inner node with 2 leaf children + let nodes = vec![ + BKDNode::Inner(BKDInnerNode { + bounds: [-10.0, -10.0, 10.0, 10.0], + split_dim: 0, + split_value: 0.0, + left_child: 1, + right_child: 2, + }), + BKDNode::Leaf(BKDLeafNode { + bounds: [-10.0, -10.0, 0.0, 10.0], + file_id: 0, + row_offset: 0, + num_rows: 100, + }), + BKDNode::Leaf(BKDLeafNode { + bounds: [0.0, -10.0, 10.0, 10.0], + file_id: 0, + row_offset: 100, + num_rows: 100, + }), + ]; + + // Separate inner and leaf 
nodes + let inner_nodes: Vec<(u32, &BKDNode)> = nodes + .iter() + .enumerate() + .filter(|(_, n)| matches!(n, BKDNode::Inner(_))) + .map(|(i, n)| (i as u32, n)) + .collect(); + + let leaf_nodes: Vec<(u32, &BKDNode)> = nodes + .iter() + .enumerate() + .filter(|(_, n)| matches!(n, BKDNode::Leaf(_))) + .map(|(i, n)| (i as u32, n)) + .collect(); + + // Serialize inner nodes + let mut inner_node_ids = Vec::new(); + let mut inner_min_x = Vec::new(); + let mut inner_min_y = Vec::new(); + let mut inner_max_x = Vec::new(); + let mut inner_max_y = Vec::new(); + let mut inner_split_dim = Vec::new(); + let mut inner_split_value = Vec::new(); + let mut inner_left_child = Vec::new(); + let mut inner_right_child = Vec::new(); + + for (idx, node) in &inner_nodes { + if let BKDNode::Inner(inner) = node { + inner_node_ids.push(*idx); + inner_min_x.push(inner.bounds[0]); + inner_min_y.push(inner.bounds[1]); + inner_max_x.push(inner.bounds[2]); + inner_max_y.push(inner.bounds[3]); + inner_split_dim.push(inner.split_dim); + inner_split_value.push(inner.split_value); + inner_left_child.push(inner.left_child); + inner_right_child.push(inner.right_child); + } + } + + let inner_batch = RecordBatch::try_new( + inner_node_schema(), + vec![ + Arc::new(UInt32Array::from(inner_node_ids)), + Arc::new(Float64Array::from(inner_min_x)), + Arc::new(Float64Array::from(inner_min_y)), + Arc::new(Float64Array::from(inner_max_x)), + Arc::new(Float64Array::from(inner_max_y)), + Arc::new(UInt8Array::from(inner_split_dim)), + Arc::new(Float64Array::from(inner_split_value)), + Arc::new(UInt32Array::from(inner_left_child)), + Arc::new(UInt32Array::from(inner_right_child)), + ], + ) + .unwrap(); + + // Serialize leaf nodes + let mut leaf_node_ids = Vec::new(); + let mut leaf_min_x = Vec::new(); + let mut leaf_min_y = Vec::new(); + let mut leaf_max_x = Vec::new(); + let mut leaf_max_y = Vec::new(); + let mut leaf_file_ids = Vec::new(); + let mut leaf_row_offsets = Vec::new(); + let mut leaf_num_rows = 
Vec::new(); + + for (idx, node) in &leaf_nodes { + if let BKDNode::Leaf(leaf) = node { + leaf_node_ids.push(*idx); + leaf_min_x.push(leaf.bounds[0]); + leaf_min_y.push(leaf.bounds[1]); + leaf_max_x.push(leaf.bounds[2]); + leaf_max_y.push(leaf.bounds[3]); + leaf_file_ids.push(leaf.file_id); + leaf_row_offsets.push(leaf.row_offset); + leaf_num_rows.push(leaf.num_rows); + } + } + + let leaf_batch = RecordBatch::try_new( + leaf_node_schema(), + vec![ + Arc::new(UInt32Array::from(leaf_node_ids)), + Arc::new(Float64Array::from(leaf_min_x)), + Arc::new(Float64Array::from(leaf_min_y)), + Arc::new(Float64Array::from(leaf_max_x)), + Arc::new(Float64Array::from(leaf_max_y)), + Arc::new(UInt32Array::from(leaf_file_ids)), + Arc::new(UInt64Array::from(leaf_row_offsets)), + Arc::new(UInt64Array::from(leaf_num_rows)), + ], + ) + .unwrap(); + + // Deserialize + let tree = BKDTreeLookup::from_record_batches(inner_batch, leaf_batch).unwrap(); + + // Verify structure + assert_eq!(tree.nodes.len(), 3); + assert_eq!(tree.root_id, 0); + assert_eq!(tree.num_leaves, 2); + + // Verify root (inner node) + match &tree.nodes[0] { + BKDNode::Inner(inner) => { + assert_eq!(inner.bounds, [-10.0, -10.0, 10.0, 10.0]); + assert_eq!(inner.split_dim, 0); + assert_eq!(inner.split_value, 0.0); + assert_eq!(inner.left_child, 1); + assert_eq!(inner.right_child, 2); + } + _ => panic!("Expected inner node at index 0"), + } + + // Verify left leaf + match &tree.nodes[1] { + BKDNode::Leaf(leaf) => { + assert_eq!(leaf.bounds, [-10.0, -10.0, 0.0, 10.0]); + assert_eq!(leaf.file_id, 0); + assert_eq!(leaf.row_offset, 0); + assert_eq!(leaf.num_rows, 100); + } + _ => panic!("Expected leaf node at index 1"), + } + + // Verify right leaf + match &tree.nodes[2] { + BKDNode::Leaf(leaf) => { + assert_eq!(leaf.bounds, [0.0, -10.0, 10.0, 10.0]); + assert_eq!(leaf.file_id, 0); + assert_eq!(leaf.row_offset, 100); + assert_eq!(leaf.num_rows, 100); + } + _ => panic!("Expected leaf node at index 2"), + } + } + + /// Test 
serialization of empty tree + #[test] + fn test_serialize_deserialize_empty_tree() { + let inner_batch = RecordBatch::new_empty(inner_node_schema()); + let leaf_batch = RecordBatch::new_empty(leaf_node_schema()); + + let tree = BKDTreeLookup::from_record_batches(inner_batch, leaf_batch).unwrap(); + + assert_eq!(tree.nodes.len(), 0); + assert_eq!(tree.num_leaves, 0); + } + + /// Test bbox intersection logic + #[test] + fn test_bboxes_intersect() { + // Overlapping boxes + assert!(bboxes_intersect( + &[0.0, 0.0, 10.0, 10.0], + &[5.0, 5.0, 15.0, 15.0] + )); + + // Touching boxes + assert!(bboxes_intersect( + &[0.0, 0.0, 10.0, 10.0], + &[10.0, 0.0, 20.0, 10.0] + )); + + // Fully contained + assert!(bboxes_intersect( + &[0.0, 0.0, 10.0, 10.0], + &[2.0, 2.0, 8.0, 8.0] + )); + + // Non-overlapping (left) + assert!(!bboxes_intersect( + &[0.0, 0.0, 10.0, 10.0], + &[-20.0, 0.0, -10.0, 10.0] + )); + + // Non-overlapping (above) + assert!(!bboxes_intersect( + &[0.0, 0.0, 10.0, 10.0], + &[0.0, 20.0, 10.0, 30.0] + )); + } + + /// Test point in bbox logic + #[test] + fn test_point_in_bbox() { + let bbox = [0.0, 0.0, 10.0, 10.0]; + + // Inside + assert!(point_in_bbox(5.0, 5.0, &bbox)); + + // On boundary + assert!(point_in_bbox(0.0, 0.0, &bbox)); + assert!(point_in_bbox(10.0, 10.0, &bbox)); + + // Outside + assert!(!point_in_bbox(-1.0, 5.0, &bbox)); + assert!(!point_in_bbox(11.0, 5.0, &bbox)); + assert!(!point_in_bbox(5.0, -1.0, &bbox)); + assert!(!point_in_bbox(5.0, 11.0, &bbox)); + } + + /// Test find_intersecting_leaves with simple tree + #[test] + fn test_find_intersecting_leaves() { + // Create a simple tree: 1 inner node with 2 leaf children + let nodes = vec![ + BKDNode::Inner(BKDInnerNode { + bounds: [-10.0, -10.0, 10.0, 10.0], + split_dim: 0, + split_value: 0.0, + left_child: 1, + right_child: 2, + }), + BKDNode::Leaf(BKDLeafNode { + bounds: [-10.0, -10.0, 0.0, 10.0], + file_id: 0, + row_offset: 0, + num_rows: 100, + }), + BKDNode::Leaf(BKDLeafNode { + bounds: [0.0, 
-10.0, 10.0, 10.0], + file_id: 0, + row_offset: 100, + num_rows: 100, + }), + ]; + + let tree = BKDTreeLookup::new(nodes, 0, 2); + + // Query that intersects only left leaf + let leaves = tree + .find_intersecting_leaves([-10.0, -10.0, -5.0, 10.0]) + .unwrap(); + assert_eq!(leaves.len(), 1); + assert_eq!(leaves[0].file_id, 0); + assert_eq!(leaves[0].row_offset, 0); + + // Query that intersects only right leaf + let leaves = tree + .find_intersecting_leaves([5.0, -10.0, 10.0, 10.0]) + .unwrap(); + assert_eq!(leaves.len(), 1); + assert_eq!(leaves[0].file_id, 0); + assert_eq!(leaves[0].row_offset, 100); + + // Query that intersects both leaves + let leaves = tree + .find_intersecting_leaves([-5.0, -10.0, 5.0, 10.0]) + .unwrap(); + assert_eq!(leaves.len(), 2); + + // Query that intersects no leaves + let leaves = tree + .find_intersecting_leaves([20.0, 20.0, 30.0, 30.0]) + .unwrap(); + assert_eq!(leaves.len(), 0); + } + + /// Test tree building with small dataset + #[test] + fn test_build_tree_small() { + let mut points = vec![ + (-5.0, -5.0, 0), + (-4.0, -4.0, 1), + (4.0, 4.0, 2), + (5.0, 5.0, 3), + ]; + + let builder = BKDTreeBuilder::new(2); // leaf_size = 2 + let (nodes, batches) = builder.build(&mut points, 5).unwrap(); + + // Should have: 1 root + 2 leaves = 3 nodes + assert_eq!(nodes.len(), 3); + assert_eq!(batches.len(), 2); + + // Verify root is inner node + assert!(matches!(nodes[0], BKDNode::Inner(_))); + + // Verify we have 2 leaf nodes + let leaf_count = nodes.iter().filter(|n| n.is_leaf()).count(); + assert_eq!(leaf_count, 2); + + // Verify each batch has correct size + assert_eq!(batches[0].num_rows(), 2); + assert_eq!(batches[1].num_rows(), 2); + } +} diff --git a/rust/lance-index/src/scalar/geo/geoindex.rs b/rust/lance-index/src/scalar/geo/geoindex.rs new file mode 100644 index 00000000000..e12d6cda3c1 --- /dev/null +++ b/rust/lance-index/src/scalar/geo/geoindex.rs @@ -0,0 +1,1677 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 
Copyright The Lance Authors + +//! Geo Index +//! +//! Geo indices are spatial database structures for efficient spatial queries. +//! They enable efficient filtering by location-based predicates. +//! +//! ## Requirements +//! +//! Geo indices can only be created on fields with GeoArrow metadata. The field must: +//! - Be a Struct data type +//! - Have `ARROW:extension:name` metadata starting with `geoarrow.` (e.g., `geoarrow.point`, `geoarrow.polygon`) +//! +//! ## Query Support +//! +//! Geo indices are "inexact" filters - they can definitively exclude regions but may include +//! false positives that require rechecking. +//! + +use crate::pbold; +use super::bkd::{self, BKDTreeBuilder, BKDTreeLookup, point_in_bbox, BKDNode, BKDLeafNode}; +use crate::scalar::expression::{GeoQueryParser, ScalarQueryParser}; +use crate::scalar::registry::{ + ScalarIndexPlugin, TrainingCriteria, TrainingOrdering, TrainingRequest, +}; +use crate::scalar::{ + BuiltinIndexType, CreatedIndex, GeoQuery, ScalarIndexParams, UpdateCriteria, +}; +use crate::Any; +use futures::TryStreamExt; +use lance_core::cache::{CacheKey, LanceCache, WeakLanceCache}; +use lance_core::{ROW_ADDR, ROW_ID}; +use serde::{Deserialize, Serialize}; + +use arrow_array::{ArrayRef, Float64Array, RecordBatch, UInt32Array, UInt8Array}; +use arrow_array::cast::AsArray; +use arrow_schema::{DataType, Field, Schema}; +use datafusion::execution::SendableRecordBatchStream; +use std::{collections::HashMap, sync::Arc}; + +use crate::scalar::{AnyQuery, IndexReader, IndexStore, MetricsCollector, ScalarIndex, SearchResult}; +use crate::scalar::FragReuseIndex; +use crate::vector::VectorIndex; +use crate::{Index, IndexType}; +use async_trait::async_trait; +use deepsize::DeepSizeOf; +use lance_core::Result; +use lance_core::{utils::mask::RowIdTreeMap, Error}; +use roaring::RoaringBitmap; +use snafu::location; + +const BKD_TREE_INNER_FILENAME: &str = "bkd_tree_inner.lance"; +const BKD_TREE_LEAF_FILENAME: &str = "bkd_tree_leaf.lance"; 
+const LEAF_GROUP_PREFIX: &str = "leaf_group_"; +const DEFAULT_BATCHES_PER_LEAF_FILE: u32 = 5; // Default number of leaf batches per file +const GEO_INDEX_VERSION: u32 = 0; +const MAX_POINTS_PER_LEAF_META_KEY: &str = "max_points_per_leaf"; +const BATCHES_PER_FILE_META_KEY: &str = "batches_per_file"; +const DEFAULT_MAX_POINTS_PER_LEAF: u32 = 100; // for test +// const DEFAULT_MAX_POINTS_PER_LEAF: u32 = 1024; // for production + +/// Get the file name for a leaf group +fn leaf_group_filename(group_id: u32) -> String { + format!("{}{}.lance", LEAF_GROUP_PREFIX, group_id) +} + +/// Lazy reader for BKD leaf file +#[derive(Clone)] +struct LazyIndexReader { + index_reader: Arc>>>, + store: Arc, + filename: String, +} + +impl LazyIndexReader { + fn new(store: Arc, filename: &str) -> Self { + Self { + index_reader: Arc::new(tokio::sync::Mutex::new(None)), + store, + filename: filename.to_string(), + } + } + + async fn get(&self) -> Result> { + let mut reader = self.index_reader.lock().await; + if reader.is_none() { + let r = self.store.open_index_file(&self.filename).await?; + *reader = Some(r); + } + Ok(reader.as_ref().unwrap().clone()) + } +} + +/// Cache key for BKD leaf nodes +#[derive(Debug, Clone)] +struct BKDLeafKey { + leaf_id: u32, +} + +impl CacheKey for BKDLeafKey { + type ValueType = CachedLeafData; + + fn key(&self) -> std::borrow::Cow<'_, str> { + format!("bkd-leaf-{}", self.leaf_id).into() + } +} + +/// Cached leaf data +#[derive(Debug, Clone)] +struct CachedLeafData(RecordBatch); + +impl DeepSizeOf for CachedLeafData { + fn deep_size_of_children(&self, _context: &mut deepsize::Context) -> usize { + // Approximate size of RecordBatch + self.0.get_array_memory_size() + } +} + +impl CachedLeafData { + fn new(batch: RecordBatch) -> Self { + Self(batch) + } + + fn into_inner(self) -> RecordBatch { + self.0 + } +} + +/// Geo index +pub struct GeoIndex { + data_type: DataType, + store: Arc, + fri: Option>, + index_cache: WeakLanceCache, + bkd_tree: Arc, + 
max_points_per_leaf: u32, +} + +impl std::fmt::Debug for GeoIndex { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("GeoIndex") + .field("data_type", &self.data_type) + .field("store", &self.store) + .field("fri", &self.fri) + .field("index_cache", &self.index_cache) + .field("bkd_tree", &self.bkd_tree) + .field("max_points_per_leaf", &self.max_points_per_leaf) + .finish() + } +} + +impl DeepSizeOf for GeoIndex { + fn deep_size_of_children(&self, context: &mut deepsize::Context) -> usize { + self.bkd_tree.deep_size_of_children(context) + self.store.deep_size_of_children(context) + } +} + +impl GeoIndex { + /// Load the geo index from storage + async fn load( + store: Arc, + fri: Option>, + index_cache: &LanceCache, + ) -> Result> + where + Self: Sized, + { + // Load inner nodes + let inner_file = store.open_index_file(BKD_TREE_INNER_FILENAME).await?; + let inner_data = inner_file + .read_range(0..inner_file.num_rows(), None) + .await?; + + // Load leaf metadata + let leaf_file = store.open_index_file(BKD_TREE_LEAF_FILENAME).await?; + let leaf_data = leaf_file + .read_range(0..leaf_file.num_rows(), None) + .await?; + + // Deserialize tree structure from both files + let bkd_tree = BKDTreeLookup::from_record_batches(inner_data, leaf_data)?; + + // Extract metadata from inner file + let schema = inner_file.schema(); + let max_points_per_leaf = schema + .metadata + .get(MAX_POINTS_PER_LEAF_META_KEY) + .and_then(|s| s.parse().ok()) + .unwrap_or(DEFAULT_MAX_POINTS_PER_LEAF); + + // Get data type from schema + let data_type = schema.fields[0].data_type().clone(); + + Ok(Arc::new(Self { + data_type, + store, + fri, + index_cache: WeakLanceCache::from(index_cache), + bkd_tree: Arc::new(bkd_tree), + max_points_per_leaf, + })) + } + + /// Load a specific leaf from storage + async fn load_leaf( + &self, + leaf: &BKDLeafNode, + metrics: &dyn MetricsCollector, + ) -> Result { + let file_id = leaf.file_id; + let row_offset = 
leaf.row_offset; + let num_rows = leaf.num_rows; + + // Use (file_id, row_offset) as cache key + // Combine file_id and row_offset into a single u32 (file_id should be small) + let cache_key = BKDLeafKey { leaf_id: file_id * 100_000 + (row_offset as u32) }; + let store = self.store.clone(); + + let cached = self + .index_cache + .get_or_insert_with_key(cache_key, move || async move { + metrics.record_part_load(); + + let filename = leaf_group_filename(file_id); + + // Open the leaf group file and read the specific row range + let reader = store.open_index_file(&filename).await?; + let batch = reader.read_range( + row_offset as usize..(row_offset + num_rows) as usize, + None + ).await?; + + Ok(CachedLeafData::new(batch)) + }) + .await?; + + Ok(cached.as_ref().clone().into_inner()) + } + + /// Search a specific leaf for points within the query bbox + async fn search_leaf( + &self, + leaf: &BKDLeafNode, + query_bbox: [f64; 4], + metrics: &dyn MetricsCollector, + ) -> Result { + let leaf_data = self.load_leaf(leaf, metrics).await?; + + let file_id = leaf.file_id; + let row_offset = leaf.row_offset; + let num_rows = leaf.num_rows; + + // Filter points within this leaf + let mut row_ids = RowIdTreeMap::new(); + let x_array = leaf_data + .column(0) + .as_primitive::(); + let y_array = leaf_data + .column(1) + .as_primitive::(); + let row_id_array = leaf_data + .column(2) + .as_primitive::(); + + for i in 0..leaf_data.num_rows() { + let x = x_array.value(i); + let y = y_array.value(i); + let row_id = row_id_array.value(i); + + if point_in_bbox(x, y, &query_bbox) { + row_ids.insert(row_id); + } + } + + Ok(row_ids) + } +} + +#[async_trait] +impl Index for GeoIndex { + fn as_any(&self) -> &dyn Any { + self + } + + fn as_index(self: Arc) -> Arc { + self + } + + fn as_vector_index(self: Arc) -> Result> { + Err(Error::InvalidInput { + source: "GeoIndex is not a vector index".into(), + location: location!(), + }) + } + + async fn prewarm(&self) -> Result<()> { + Ok(()) + } + + fn 
statistics(&self) -> Result { + Ok(serde_json::json!({ + "type": "geo", + })) + } + + fn index_type(&self) -> IndexType { + IndexType::Geo + } + + async fn calculate_included_frags(&self) -> Result { + let frag_ids = RoaringBitmap::new(); + Ok(frag_ids) + } +} + +#[async_trait] +impl ScalarIndex for GeoIndex { + async fn search( + &self, + query: &dyn AnyQuery, + metrics: &dyn MetricsCollector, + ) -> Result { + let geo_query = query.as_any().downcast_ref::() + .ok_or_else(|| Error::InvalidInput { + source: "Geo index only supports GeoQuery".into(), + location: location!(), + })?; + + match geo_query { + GeoQuery::Intersects(min_x, min_y, max_x, max_y) => { + let query_bbox = [*min_x, *min_y, *max_x, *max_y]; + + // Step 1: Find intersecting leaves using in-memory tree traversal + let leaves = self.bkd_tree.find_intersecting_leaves(query_bbox)?; + + // Step 2: Lazy-load and filter each leaf + let mut all_row_ids = RowIdTreeMap::new(); + + for leaf_node in &leaves { + let leaf_row_ids = self + .search_leaf(leaf_node, query_bbox, metrics) + .await?; + // Collect row IDs from the leaf and add them to the result set + let row_ids: Option> = leaf_row_ids.row_ids() + .map(|iter| iter.map(|row_addr| u64::from(row_addr)).collect()); + if let Some(row_ids) = row_ids { + all_row_ids.extend(row_ids); + } + } + + // We return Exact because we already filtered points in search_leaf + Ok(SearchResult::Exact(all_row_ids)) + } + } + } + + fn can_remap(&self) -> bool { + false + } + + /// Remap the row ids, creating a new remapped version of this index in `dest_store` + async fn remap( + &self, + _mapping: &HashMap>, + _dest_store: &dyn IndexStore, + ) -> Result { + Err(Error::InvalidInput { + source: "GeoIndex does not support remap".into(), + location: location!(), + }) + } + + /// Add the new data , creating an updated version of the index in `dest_store` + async fn update( + &self, + _new_data: SendableRecordBatchStream, + _dest_store: &dyn IndexStore, + ) -> Result { + 
Err(Error::InvalidInput { + source: "GeoIndex does not support update".into(), + location: location!(), + }) + } + + fn update_criteria(&self) -> UpdateCriteria { + UpdateCriteria::only_new_data( + TrainingCriteria::new(TrainingOrdering::Addresses).with_row_addr(), + ) + } + + fn derive_index_params(&self) -> Result { + let params = serde_json::to_value(GeoIndexBuilderParams::default())?; + Ok(ScalarIndexParams::for_builtin(BuiltinIndexType::Geo).with_params(¶ms)) + } +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct GeoIndexBuilderParams { + #[serde(default = "default_max_points_per_leaf")] + pub max_points_per_leaf: u32, + #[serde(default = "default_batches_per_file")] + pub batches_per_file: u32, +} + +fn default_max_points_per_leaf() -> u32 { + DEFAULT_MAX_POINTS_PER_LEAF +} + +fn default_batches_per_file() -> u32 { + DEFAULT_BATCHES_PER_LEAF_FILE +} + +impl Default for GeoIndexBuilderParams { + fn default() -> Self { + Self { + max_points_per_leaf: default_max_points_per_leaf(), + batches_per_file: default_batches_per_file(), + } + } +} + +impl GeoIndexBuilderParams { + pub fn new() -> Self { + Self::default() + } + + pub fn with_max_points_per_leaf(mut self, max_points_per_leaf: u32) -> Self { + self.max_points_per_leaf = max_points_per_leaf; + self + } +} + +// A builder for geo index +pub struct GeoIndexBuilder { + options: GeoIndexBuilderParams, + items_type: DataType, + // Accumulated points: (x, y, row_id) + points: Vec<(f64, f64, u64)>, +} + +impl GeoIndexBuilder { + pub fn try_new(options: GeoIndexBuilderParams, items_type: DataType) -> Result { + Ok(Self { + options, + items_type, + points: Vec::new(), + }) + } + + pub async fn train(&mut self, batches_source: SendableRecordBatchStream) -> Result<()> { + assert!(batches_source.schema().field_with_name(ROW_ADDR).is_ok()); + + let mut batches_source = batches_source; + + while let Some(batch) = batches_source.try_next().await? 
{ + // Extract GeoArrow point coordinates + let geom_array = batch.column(0).as_any().downcast_ref::() + .ok_or_else(|| Error::InvalidInput { + source: "Expected Struct array for GeoArrow data".into(), + location: location!(), + })?; + + let x_array = geom_array + .column(0) + .as_primitive::(); + let y_array = geom_array + .column(1) + .as_primitive::(); + let row_ids = batch + .column_by_name(ROW_ADDR) + .unwrap() + .as_primitive::(); + + for i in 0..batch.num_rows() { + self.points.push(( + x_array.value(i), + y_array.value(i), + row_ids.value(i), + )); + } + } + + log::debug!("Accumulated {} points for BKD tree", self.points.len()); + + Ok(()) + } + + pub async fn write_index(mut self, index_store: &dyn IndexStore) -> Result<()> { + if self.points.is_empty() { + return Ok(()); + } + + // Build BKD tree + let (tree_nodes, leaf_batches) = self.build_bkd_tree()?; + + // Write tree structure to separate inner and leaf files + let (inner_batch, leaf_metadata_batch) = self.serialize_tree_nodes(&tree_nodes)?; + + // Write inner nodes + let mut inner_file = index_store + .new_index_file(BKD_TREE_INNER_FILENAME, inner_batch.schema()) + .await?; + inner_file.write_record_batch(inner_batch).await?; + inner_file + .finish_with_metadata(HashMap::from([ + (MAX_POINTS_PER_LEAF_META_KEY.to_string(), self.options.max_points_per_leaf.to_string()), + (BATCHES_PER_FILE_META_KEY.to_string(), self.options.batches_per_file.to_string()), + ])) + .await?; + + // Write leaf metadata + let mut leaf_meta_file = index_store + .new_index_file(BKD_TREE_LEAF_FILENAME, leaf_metadata_batch.schema()) + .await?; + leaf_meta_file.write_record_batch(leaf_metadata_batch).await?; + leaf_meta_file.finish().await?; + + // Write actual leaf data grouped into files (multiple batches per file) + let leaf_schema = Arc::new(Schema::new(vec![ + Field::new("x", DataType::Float64, false), + Field::new("y", DataType::Float64, false), + Field::new(ROW_ID, DataType::UInt64, false), + ])); + + let batches_per_file 
= self.options.batches_per_file; + let num_groups = (leaf_batches.len() as u32 + batches_per_file - 1) / batches_per_file; + + for group_id in 0..num_groups { + let start_idx = (group_id * batches_per_file) as usize; + let end_idx = ((group_id + 1) * batches_per_file).min(leaf_batches.len() as u32) as usize; + let group_batches = &leaf_batches[start_idx..end_idx]; + + let filename = leaf_group_filename(group_id); + + let mut leaf_file = index_store + .new_index_file(&filename, leaf_schema.clone()) + .await?; + + for leaf_batch in group_batches.iter() { + leaf_file.write_record_batch(leaf_batch.clone()).await?; + } + + leaf_file.finish().await?; + } + + log::debug!( + "Wrote BKD tree with {} nodes", + tree_nodes.len() + ); + + Ok(()) + } + + // Serialize tree nodes to separate inner and leaf RecordBatches + fn serialize_tree_nodes(&self, nodes: &[BKDNode]) -> Result<(RecordBatch, RecordBatch)> { + + // Separate inner and leaf nodes with their indices + let mut inner_nodes = Vec::new(); + let mut leaf_nodes = Vec::new(); + + for (idx, node) in nodes.iter().enumerate() { + match node { + BKDNode::Inner(_) => inner_nodes.push((idx as u32, node)), + BKDNode::Leaf(_) => leaf_nodes.push((idx as u32, node)), + } + } + + // Serialize inner nodes + let inner_batch = Self::serialize_inner_nodes(&inner_nodes)?; + + // Serialize leaf nodes + let leaf_batch = Self::serialize_leaf_nodes(&leaf_nodes)?; + + Ok((inner_batch, leaf_batch)) + } + + fn serialize_inner_nodes(nodes: &[(u32, &BKDNode)]) -> Result { + + let mut node_id_vals = Vec::with_capacity(nodes.len()); + let mut min_x_vals = Vec::with_capacity(nodes.len()); + let mut min_y_vals = Vec::with_capacity(nodes.len()); + let mut max_x_vals = Vec::with_capacity(nodes.len()); + let mut max_y_vals = Vec::with_capacity(nodes.len()); + let mut split_dim_vals = Vec::with_capacity(nodes.len()); + let mut split_value_vals = Vec::with_capacity(nodes.len()); + let mut left_child_vals = Vec::with_capacity(nodes.len()); + let mut 
right_child_vals = Vec::with_capacity(nodes.len()); + + for (idx, node) in nodes { + if let BKDNode::Inner(inner) = node { + node_id_vals.push(*idx); + min_x_vals.push(inner.bounds[0]); + min_y_vals.push(inner.bounds[1]); + max_x_vals.push(inner.bounds[2]); + max_y_vals.push(inner.bounds[3]); + split_dim_vals.push(inner.split_dim); + split_value_vals.push(inner.split_value); + left_child_vals.push(inner.left_child); + right_child_vals.push(inner.right_child); + } + } + + let schema = bkd::inner_node_schema(); + + let columns: Vec = vec![ + Arc::new(UInt32Array::from(node_id_vals)), + Arc::new(Float64Array::from(min_x_vals)), + Arc::new(Float64Array::from(min_y_vals)), + Arc::new(Float64Array::from(max_x_vals)), + Arc::new(Float64Array::from(max_y_vals)), + Arc::new(UInt8Array::from(split_dim_vals)), + Arc::new(Float64Array::from(split_value_vals)), + Arc::new(UInt32Array::from(left_child_vals)), + Arc::new(UInt32Array::from(right_child_vals)), + ]; + + Ok(RecordBatch::try_new(schema, columns)?) 
+ } + + fn serialize_leaf_nodes(nodes: &[(u32, &BKDNode)]) -> Result { + use arrow_array::UInt64Array; + + let mut node_id_vals = Vec::with_capacity(nodes.len()); + let mut min_x_vals = Vec::with_capacity(nodes.len()); + let mut min_y_vals = Vec::with_capacity(nodes.len()); + let mut max_x_vals = Vec::with_capacity(nodes.len()); + let mut max_y_vals = Vec::with_capacity(nodes.len()); + let mut file_id_vals = Vec::with_capacity(nodes.len()); + let mut row_offset_vals = Vec::with_capacity(nodes.len()); + let mut num_rows_vals = Vec::with_capacity(nodes.len()); + + for (idx, node) in nodes { + if let BKDNode::Leaf(leaf) = node { + node_id_vals.push(*idx); + min_x_vals.push(leaf.bounds[0]); + min_y_vals.push(leaf.bounds[1]); + max_x_vals.push(leaf.bounds[2]); + max_y_vals.push(leaf.bounds[3]); + file_id_vals.push(leaf.file_id); + row_offset_vals.push(leaf.row_offset); + num_rows_vals.push(leaf.num_rows); + } + } + + let schema = bkd::leaf_node_schema(); + + let columns: Vec = vec![ + Arc::new(UInt32Array::from(node_id_vals)), + Arc::new(Float64Array::from(min_x_vals)), + Arc::new(Float64Array::from(min_y_vals)), + Arc::new(Float64Array::from(max_x_vals)), + Arc::new(Float64Array::from(max_y_vals)), + Arc::new(UInt32Array::from(file_id_vals)), + Arc::new(UInt64Array::from(row_offset_vals)), + Arc::new(UInt64Array::from(num_rows_vals)), + ]; + + Ok(RecordBatch::try_new(schema, columns)?) 
+ } + + // Build BKD tree using the BKDTreeBuilder + fn build_bkd_tree(&mut self) -> Result<(Vec, Vec)> { + let builder = BKDTreeBuilder::new(self.options.max_points_per_leaf as usize); + builder.build(&mut self.points, self.options.batches_per_file) + } +} + +#[derive(Debug, Default)] +pub struct GeoIndexPlugin; + +impl GeoIndexPlugin { + async fn train_geo_index( + batches_source: SendableRecordBatchStream, + index_store: &dyn IndexStore, + options: Option, + ) -> Result<()> { + let value_type = batches_source.schema().field(0).data_type().clone(); + + let mut builder = GeoIndexBuilder::try_new(options.unwrap_or_default(), value_type)?; + + builder.train(batches_source).await?; + + builder.write_index(index_store).await?; + Ok(()) + } +} + +pub struct GeoIndexTrainingRequest { + pub params: GeoIndexBuilderParams, + pub criteria: TrainingCriteria, +} + +impl GeoIndexTrainingRequest { + pub fn new(params: GeoIndexBuilderParams) -> Self { + Self { + params, + criteria: TrainingCriteria::new(TrainingOrdering::Addresses).with_row_addr(), + } + } +} + +impl TrainingRequest for GeoIndexTrainingRequest { + fn as_any(&self) -> &dyn std::any::Any { + self + } + fn criteria(&self) -> &TrainingCriteria { + &self.criteria + } +} + +#[async_trait] +impl ScalarIndexPlugin for GeoIndexPlugin { + fn new_training_request( + &self, + params: &str, + field: &Field, + ) -> Result> { + // Check that the field is a Struct type + if !matches!(field.data_type(), DataType::Struct(_)) { + return Err(Error::InvalidInput { + source: "A geo index can only be created on a Struct field.".into(), + location: location!(), + }); + } + + // Check for GeoArrow metadata + let is_geoarrow = field + .metadata() + .get("ARROW:extension:name") + .map(|name| name.starts_with("geoarrow.")) + .unwrap_or(false); + + if !is_geoarrow { + return Err(Error::InvalidInput { + source: format!( + "Geo index requires GeoArrow metadata on field '{}'. 
\ + The field must have 'ARROW:extension:name' metadata starting with 'geoarrow.'", + field.name() + ) + .into(), + location: location!(), + }); + } + + let params = serde_json::from_str::(params)?; + + Ok(Box::new(GeoIndexTrainingRequest::new(params))) + } + + fn provides_exact_answer(&self) -> bool { + true // We do exact point-in-bbox filtering in search_leaf + } + + fn version(&self) -> u32 { + GEO_INDEX_VERSION + } + + fn new_query_parser( + &self, + index_name: String, + _index_details: &prost_types::Any, + ) -> Option> { + Some(Box::new(GeoQueryParser::new(index_name))) + } + + async fn train_index( + &self, + data: SendableRecordBatchStream, + index_store: &dyn IndexStore, + request: Box, + fragment_ids: Option>, + ) -> Result { + if fragment_ids.is_some() { + return Err(Error::InvalidInput { + source: "Geo index does not support fragment training".into(), + location: location!(), + }); + } + + let request = (request as Box) + .downcast::() + .map_err(|_| Error::InvalidInput { + source: "must provide training request created by new_training_request".into(), + location: location!(), + })?; + Self::train_geo_index(data, index_store, Some(request.params)).await?; + Ok(CreatedIndex { + index_details: prost_types::Any::from_msg(&pbold::GeoIndexDetails::default()) + .unwrap(), + index_version: GEO_INDEX_VERSION, + }) + } + + async fn load_index( + &self, + index_store: Arc, + _index_details: &prost_types::Any, + frag_reuse_index: Option>, + cache: &LanceCache, + ) -> Result> { + Ok(GeoIndex::load(index_store, frag_reuse_index, cache).await? 
as Arc) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use std::sync::Arc; + + use arrow_array::RecordBatch; + use arrow_schema::{DataType, Field, Fields, Schema}; + use datafusion::execution::SendableRecordBatchStream; + use datafusion::physical_plan::stream::RecordBatchStreamAdapter; + use futures::stream; + use lance_core::cache::LanceCache; + use lance_core::utils::tempfile::TempObjDir; + use lance_core::ROW_ADDR; + use lance_io::object_store::ObjectStore; + + use crate::scalar::lance_format::LanceIndexStore; + + fn create_test_store() -> Arc { + let tmpdir = TempObjDir::default(); + + let test_store = Arc::new(LanceIndexStore::new( + Arc::new(ObjectStore::local()), + tmpdir.clone(), + Arc::new(LanceCache::no_cache()), + )); + return test_store; + } + + /// Validates the BKD tree structure recursively + async fn validate_bkd_tree(index: &GeoIndex) -> Result<()> { + use crate::metrics::NoOpMetricsCollector; + + let tree = &index.bkd_tree; + + // Verify root exists + if tree.nodes.is_empty() { + return Err(Error::InvalidInput { + source: "BKD tree has no nodes".into(), + location: location!(), + }); + } + + let root_id = tree.root_id as usize; + if root_id >= tree.nodes.len() { + return Err(Error::InvalidInput { + source: format!("Root node id {} out of bounds (tree has {} nodes)", root_id, tree.nodes.len()).into(), + location: location!(), + }); + } + + // Count leaves and validate structure + let mut visited = vec![false; tree.nodes.len()]; + let mut leaf_count = 0u32; + + // Start with None for expected_split_dim - root can use any dimension + // (typically starts with dimension 0, but we don't enforce it) + let metrics = NoOpMetricsCollector {}; + validate_node_recursive(index, tree, root_id as u32, &mut visited, &mut leaf_count, None, None, &metrics).await?; + + // Verify all leaves were counted + if leaf_count != tree.num_leaves { + return Err(Error::InvalidInput { + source: format!("Leaf count mismatch: found {} leaves but tree claims {}", 
leaf_count, tree.num_leaves).into(), + location: location!(), + }); + } + + // Check for orphaned nodes (nodes not reachable from root) + for (idx, was_visited) in visited.iter().enumerate() { + if !was_visited { + return Err(Error::InvalidInput { + source: format!("Node {} is orphaned (not reachable from root)", idx).into(), + location: location!(), + }); + } + } + + Ok(()) + } + + /// Helper function to recursively validate a node and its descendants + async fn validate_node_recursive( + index: &GeoIndex, + tree: &BKDTreeLookup, + node_id: u32, + visited: &mut Vec, + leaf_count: &mut u32, + parent_bounds: Option<[f64; 4]>, + parent_split: Option<(u8, f64, bool)>, // (split_dim, split_value, is_left_child) from parent + metrics: &dyn MetricsCollector, + ) -> Result<()> { + let node_idx = node_id as usize; + + // Check node exists + if node_idx >= tree.nodes.len() { + return Err(Error::InvalidInput { + source: format!("Node id {} out of bounds (tree has {} nodes)", node_id, tree.nodes.len()).into(), + location: location!(), + }); + } + + // Mark as visited + if visited[node_idx] { + return Err(Error::InvalidInput { + source: format!("Node {} visited multiple times (cycle detected)", node_id).into(), + location: location!(), + }); + } + visited[node_idx] = true; + + let node = &tree.nodes[node_idx]; + let bounds = node.bounds(); + + // Validate bounds are well-formed + if bounds[0] > bounds[2] || bounds[1] > bounds[3] { + return Err(Error::InvalidInput { + source: format!("Node {} has invalid bounds: [{}, {}, {}, {}]", + node_id, bounds[0], bounds[1], bounds[2], bounds[3]).into(), + location: location!(), + }); + } + + // Verify child bounds are within parent bounds + if let Some(parent_bounds) = parent_bounds { + if bounds[0] < parent_bounds[0] || bounds[1] < parent_bounds[1] || + bounds[2] > parent_bounds[2] || bounds[3] > parent_bounds[3] { + return Err(Error::InvalidInput { + source: format!( + "Node {} bounds [{}, {}, {}, {}] exceed parent bounds [{}, {}, {}, 
{}]", + node_id, bounds[0], bounds[1], bounds[2], bounds[3], + parent_bounds[0], parent_bounds[1], parent_bounds[2], parent_bounds[3] + ).into(), + location: location!(), + }); + } + } + + match node { + BKDNode::Inner(inner) => { + // Validate split dimension + if inner.split_dim > 1 { + return Err(Error::InvalidInput { + source: format!("Node {} has invalid split_dim: {} (must be 0 or 1)", node_id, inner.split_dim).into(), + location: location!(), + }); + } + + // Validate split value is within bounds + let min_val = if inner.split_dim == 0 { bounds[0] } else { bounds[1] }; + let max_val = if inner.split_dim == 0 { bounds[2] } else { bounds[3] }; + if inner.split_value < min_val || inner.split_value > max_val { + return Err(Error::InvalidInput { + source: format!( + "Node {} split_value {} is outside dimension bounds [{}, {}]", + node_id, inner.split_value, min_val, max_val + ).into(), + location: location!(), + }); + } + + // Validate that children are properly sorted by split dimension + // If points are sorted, left_max <= split_value <= right_min + // Which means left_max <= right_min (no overlap) + let left_node = &tree.nodes[inner.left_child as usize]; + let right_node = &tree.nodes[inner.right_child as usize]; + + let left_bounds = left_node.bounds(); + let right_bounds = right_node.bounds(); + + let (left_max, right_min) = if inner.split_dim == 0 { + (left_bounds[2], right_bounds[0]) // max_x of left, min_x of right + } else { + (left_bounds[3], right_bounds[1]) // max_y of left, min_y of right + }; + + // Simple check: left_max should not be greater than right_min + // If it is, points were not properly sorted before splitting! + if left_max > right_min { + return Err(Error::InvalidInput { + source: format!( + "Node {} (split_dim={}, split_value={}): left child max {}={} > right child min {}={}. 
\ + Points were not properly sorted before splitting!", + node_id, inner.split_dim, inner.split_value, + if inner.split_dim == 0 { "x" } else { "y" }, + left_max, + if inner.split_dim == 0 { "x" } else { "y" }, + right_min + ).into(), + location: location!(), + }); + } + + // Recursively validate children + Box::pin(validate_node_recursive( + index, + tree, + inner.left_child, + visited, + leaf_count, + Some(bounds), + Some((inner.split_dim, inner.split_value, true)), + metrics, + )).await?; + Box::pin(validate_node_recursive( + index, + tree, + inner.right_child, + visited, + leaf_count, + Some(bounds), + Some((inner.split_dim, inner.split_value, false)), + metrics, + )).await?; + } + bkd::BKDNode::Leaf(leaf) => { + // Validate leaf data + if leaf.num_rows == 0 { + return Err(Error::InvalidInput { + source: format!("Leaf node {} has zero rows", node_id).into(), + location: location!(), + }); + } + + // No extra validation needed for leaves - will be checked from parent + + *leaf_count += 1; + } + } + + Ok(()) + } + + #[tokio::test] + async fn test_geo_index_with_custom_max_points_per_leaf() { + use rand::{Rng, SeedableRng}; + use rand::rngs::StdRng; + + // Test with different max points per leaf + for max_points_per_leaf in [10, 50, 100, 200] { + let test_store = create_test_store(); + + let params = GeoIndexBuilderParams { + max_points_per_leaf, + batches_per_file: 5, + }; + + let mut builder = GeoIndexBuilder::try_new(params, DataType::Struct(Vec::::new().into())) + .unwrap(); + + // Create 500 RANDOM points (not sequential!) 
to catch sorting bugs + let mut rng = StdRng::seed_from_u64(42); + for i in 0..500 { + let x = rng.random_range(0.0..100.0); + let y = rng.random_range(0.0..100.0); + builder.points.push((x, y, i as u64)); + } + + // Write index + builder.write_index(test_store.as_ref()).await.unwrap(); + + // Load index and verify + let index = GeoIndex::load(test_store.clone(), None, &LanceCache::no_cache()) + .await + .expect("Failed to load GeoIndex"); + + assert_eq!(index.max_points_per_leaf, max_points_per_leaf); + + + + // Validate tree structure + validate_bkd_tree(&index).await + .expect(&format!("BKD tree validation failed for max_points_per_leaf={}", max_points_per_leaf)); + } + } + + #[tokio::test] + async fn test_geo_index_with_custom_batches_per_file() { + // Test with different batches_per_file configurations + for batches_per_file in [1, 3, 5, 10, 20] { + let test_store = create_test_store(); + + let params = GeoIndexBuilderParams { + max_points_per_leaf: 50, + batches_per_file, + }; + + let mut builder = GeoIndexBuilder::try_new(params, DataType::Struct(Vec::::new().into())) + .unwrap(); + + // Create 500 points (BKD spatial partitioning determines actual leaf count) + for i in 0..500 { + let x = (i % 100) as f64; + let y = (i / 100) as f64; + builder.points.push((x, y, i as u64)); + } + + // Write index + builder.write_index(test_store.as_ref()).await.unwrap(); + + // Load index and verify + let index = GeoIndex::load(test_store.clone(), None, &LanceCache::no_cache()) + .await + .expect("Failed to load GeoIndex"); + + // Validate tree structure + validate_bkd_tree(&index).await + .expect(&format!("BKD tree validation failed for batches_per_file={}", batches_per_file)); + + } + } + + #[tokio::test] + async fn test_geo_index_query_correctness_various_configs() { + use crate::metrics::NoOpMetricsCollector; + + // Test query correctness with different configurations + let configs = vec![ + (10, 1), // Small leaves, one per file + (50, 5), // Default-ish + (100, 10), 
// Larger leaves, many per file + ]; + + for (max_points_per_leaf, batches_per_file) in configs { + let test_store = create_test_store(); + + let params = GeoIndexBuilderParams { + max_points_per_leaf, + batches_per_file, + }; + + let mut builder = GeoIndexBuilder::try_new(params, DataType::Struct(Vec::::new().into())) + .unwrap(); + + // Create a grid of points + for x in 0..20 { + for y in 0..20 { + let row_id = (x * 20 + y) as u64; + builder.points.push((x as f64, y as f64, row_id)); + } + } + + // Write and load index + builder.write_index(test_store.as_ref()).await.unwrap(); + let index = GeoIndex::load(test_store.clone(), None, &LanceCache::no_cache()) + .await + .unwrap(); + + // Validate tree structure + validate_bkd_tree(&index).await + .expect(&format!("BKD tree validation failed for max_points_per_leaf={}, batches_per_file={}", max_points_per_leaf, batches_per_file)); + + // Query: bbox [5, 5, 10, 10] should return points in that region + let query = GeoQuery::Intersects(5.0, 5.0, 10.0, 10.0); + + let metrics = NoOpMetricsCollector {}; + let result = index.search(&query, &metrics).await.unwrap(); + + // Should find points (5,5) to (10,10) = 6x6 = 36 points + match result { + crate::scalar::SearchResult::Exact(row_ids) => { + assert_eq!(row_ids.len().unwrap_or(0), 36, + "Expected 36 points for config max_points_per_leaf={}, batches_per_file={}, got {}", + max_points_per_leaf, batches_per_file, row_ids.len().unwrap_or(0)); + + // Verify correct row IDs + for x in 5..=10 { + for y in 5..=10 { + let expected_row_id = (x * 20 + y) as u64; + assert!(row_ids.contains(expected_row_id), + "Missing row_id {} for point ({}, {})", + expected_row_id, x, y); + } + } + } + _ => panic!("Expected Exact search result"), + } + } + } + + #[tokio::test] + async fn test_geo_index_single_leaf() { + // Edge case: all points fit in single leaf + let test_store = create_test_store(); + + let params = GeoIndexBuilderParams { + max_points_per_leaf: 1000, + batches_per_file: 5, + }; 
+ + let mut builder = GeoIndexBuilder::try_new(params, DataType::Struct(Vec::::new().into())) + .unwrap(); + + // Create only 50 points + for i in 0..50 { + builder.points.push((i as f64, i as f64, i as u64)); + } + + builder.write_index(test_store.as_ref()).await.unwrap(); + let index = GeoIndex::load(test_store.clone(), None, &LanceCache::no_cache()) + .await + .unwrap(); + + // Should have only 1 leaf + assert_eq!(index.bkd_tree.num_leaves, 1); + assert_eq!(index.bkd_tree.nodes.len(), 1); + + // Validate tree structure + validate_bkd_tree(&index).await + .expect("BKD tree validation failed for single leaf test"); + + // Single leaf should be in file 0 at offset 0 + match &index.bkd_tree.nodes[0] { + BKDNode::Leaf(leaf) => { + assert_eq!(leaf.file_id, 0); + assert_eq!(leaf.row_offset, 0); + assert_eq!(leaf.num_rows, 50); + } + _ => panic!("Expected leaf node"), + } + } + + #[tokio::test] + async fn test_geo_index_many_small_leaves() { + // Stress test: many small leaves, test file grouping + let test_store = create_test_store(); + + let params = GeoIndexBuilderParams { + max_points_per_leaf: 5, // Very small leaves + batches_per_file: 3, // Few batches per file + }; + + let mut builder = GeoIndexBuilder::try_new(params, DataType::Struct(Vec::::new().into())) + .unwrap(); + + // Create 100 points (BKD spatial partitioning determines actual leaf count) + for i in 0..100 { + builder.points.push((i as f64, i as f64, i as u64)); + } + + builder.write_index(test_store.as_ref()).await.unwrap(); + let index = GeoIndex::load(test_store.clone(), None, &LanceCache::no_cache()) + .await + .unwrap(); + + // BKD tree creates more leaves due to spatial partitioning + assert_eq!(index.bkd_tree.num_leaves, 32); + + // Validate tree structure + validate_bkd_tree(&index).await + .expect("BKD tree validation failed for many small leaves test"); + + // 32 leaves / 3 batches_per_file = 11 files (ceil) + let file_ids: std::collections::HashSet = index + .bkd_tree + .nodes + .iter() + 
.filter_map(|n| n.as_leaf()) + .map(|l| l.file_id) + .collect(); + + assert_eq!(file_ids.len(), 11); + + // Verify row offsets are cumulative within each file + let leaves: Vec<_> = index + .bkd_tree + .nodes + .iter() + .filter_map(|n| n.as_leaf()) + .collect(); + + for file_id in 0..11 { + let leaves_in_file: Vec<_> = leaves + .iter() + .filter(|l| l.file_id == file_id) + .collect(); + + let mut expected_offset = 0u64; + for leaf in leaves_in_file { + assert_eq!(leaf.row_offset, expected_offset, + "Incorrect offset in file {}", file_id); + expected_offset += leaf.num_rows; + } + } + } + + #[tokio::test] + async fn test_geo_index_data_integrity_after_serialization() { + // Verify every point written can be read back exactly + use crate::metrics::NoOpMetricsCollector; + + let configs = vec![ + (10, 1), // Small leaves, one per file + (50, 3), // Medium leaves, few per file + (100, 10), // Large leaves, many per file + ]; + + for (max_points_per_leaf, batches_per_file) in configs { + let test_store = create_test_store(); + + let params = GeoIndexBuilderParams { + max_points_per_leaf, + batches_per_file, + }; + + let mut builder = GeoIndexBuilder::try_new(params, DataType::Struct(Vec::::new().into())) + .unwrap(); + + // Create test points with known values + let mut original_points = Vec::new(); + for i in 0..200 { + let x = (i % 50) as f64 + 0.123; // Non-integer to test precision + let y = (i / 50) as f64 + 0.456; + let row_id = i as u64; + original_points.push((x, y, row_id)); + builder.points.push((x, y, row_id)); + } + + // Write index + builder.write_index(test_store.as_ref()).await.unwrap(); + + // Load index + let index = GeoIndex::load(test_store.clone(), None, &LanceCache::no_cache()) + .await + .unwrap(); + + // Read back all points by querying each leaf + let metrics = NoOpMetricsCollector {}; + let mut recovered_points = std::collections::HashMap::new(); + + for leaf in index.bkd_tree.nodes.iter().filter_map(|n| n.as_leaf()) { + // Load the leaf data + 
let leaf_data = index.load_leaf(leaf, &metrics).await.unwrap(); + + // Extract points from this leaf + let x_array = leaf_data + .column(0) + .as_primitive::(); + let y_array = leaf_data + .column(1) + .as_primitive::(); + let row_id_array = leaf_data + .column(2) + .as_primitive::(); + + for i in 0..leaf_data.num_rows() { + let x = x_array.value(i); + let y = y_array.value(i); + let row_id = row_id_array.value(i); + recovered_points.insert(row_id, (x, y)); + } + } + + // Verify all points were recovered + assert_eq!(recovered_points.len(), original_points.len(), + "Lost points with config max_points_per_leaf={}, batches_per_file={}", + max_points_per_leaf, batches_per_file); + + // Verify each point matches exactly + for (original_x, original_y, row_id) in &original_points { + let (recovered_x, recovered_y) = recovered_points.get(row_id) + .expect(&format!("Missing row_id {} in recovered data", row_id)); + + assert_eq!(*recovered_x, *original_x, + "X coordinate mismatch for row_id {}: expected {}, got {}", + row_id, original_x, recovered_x); + assert_eq!(*recovered_y, *original_y, + "Y coordinate mismatch for row_id {}: expected {}, got {}", + row_id, original_y, recovered_y); + } + } + } + + #[tokio::test] + async fn test_geo_index_no_duplicate_points() { + // Ensure no points are duplicated during write/read + let test_store = create_test_store(); + + let params = GeoIndexBuilderParams { + max_points_per_leaf: 7, // Odd number to test edge cases + batches_per_file: 3, + }; + + let mut builder = GeoIndexBuilder::try_new(params, DataType::Struct(Vec::::new().into())) + .unwrap(); + + // Create 100 unique points + for i in 0..100 { + builder.points.push((i as f64, i as f64, i as u64)); + } + + builder.write_index(test_store.as_ref()).await.unwrap(); + let index = GeoIndex::load(test_store.clone(), None, &LanceCache::no_cache()) + .await + .unwrap(); + + // Read all row_ids from all leaves + use crate::metrics::NoOpMetricsCollector; + let metrics = 
NoOpMetricsCollector {}; + let mut all_row_ids = Vec::new(); + + for leaf in index.bkd_tree.nodes.iter().filter_map(|n| n.as_leaf()) { + let leaf_data = index.load_leaf(leaf, &metrics).await.unwrap(); + let row_id_array = leaf_data + .column(2) + .as_primitive::(); + + for i in 0..leaf_data.num_rows() { + all_row_ids.push(row_id_array.value(i)); + } + } + + // Should have exactly 100 row_ids + assert_eq!(all_row_ids.len(), 100, "Wrong number of points recovered"); + + // All should be unique + let unique_row_ids: std::collections::HashSet = all_row_ids.iter().copied().collect(); + assert_eq!(unique_row_ids.len(), 100, "Found duplicate points!"); + + // Should be exactly 0..99 + for i in 0..100 { + assert!(unique_row_ids.contains(&(i as u64)), + "Missing row_id {}", i); + } + } + + #[tokio::test] + async fn test_geo_index_lazy_loading() { + use crate::metrics::NoOpMetricsCollector; + use std::sync::atomic::{AtomicUsize, Ordering}; + use std::sync::Arc; + + // Test that leaves are loaded lazily (not all at once) + let test_store = create_test_store(); + + let params = GeoIndexBuilderParams { + max_points_per_leaf: 10, + batches_per_file: 5, + }; + + let mut builder = GeoIndexBuilder::try_new(params, DataType::Struct(Vec::::new().into())) + .unwrap(); + + // Create 100 points in a grid + for x in 0..10 { + for y in 0..10 { + builder.points.push((x as f64 * 10.0, y as f64 * 10.0, (x * 10 + y) as u64)); + } + } + + builder.write_index(test_store.as_ref()).await.unwrap(); + let index = GeoIndex::load(test_store.clone(), None, &LanceCache::no_cache()) + .await + .unwrap(); + + // Query a small region - should only load relevant leaves, not all of them + let query = GeoQuery::Intersects(0.0, 0.0, 20.0, 20.0); + let metrics = NoOpMetricsCollector {}; + + // Find which leaves would be touched + let query_bbox = [0.0, 0.0, 20.0, 20.0]; + let intersecting_leaves = index.bkd_tree.find_intersecting_leaves(query_bbox).unwrap(); + + // Verify we're not loading ALL leaves for a 
small query + assert!( + intersecting_leaves.len() < index.bkd_tree.num_leaves as usize, + "Lazy loading test: small query should not touch all leaves! \ + Touched {}/{} leaves", + intersecting_leaves.len(), + index.bkd_tree.num_leaves + ); + + // Execute the query + let result = index.search(&query, &metrics).await.unwrap(); + + // Manually verify correctness: which points SHOULD be in bbox [0, 0, 20, 20]? + let mut expected_row_ids = std::collections::HashSet::new(); + for x in 0..10 { + for y in 0..10 { + let point_x = x as f64 * 10.0; + let point_y = y as f64 * 10.0; + // st_intersects: point is inside bbox if min_x <= x <= max_x && min_y <= y <= max_y + if point_x >= 0.0 && point_x <= 20.0 && point_y >= 0.0 && point_y <= 20.0 { + expected_row_ids.insert((x * 10 + y) as u64); + } + } + } + + // Verify results match our manual calculation + match result { + crate::scalar::SearchResult::Exact(row_ids) => { + let actual_count = row_ids.len().unwrap_or(0) as usize; + + assert_eq!(actual_count, expected_row_ids.len(), + "Expected {} points, got {}", + expected_row_ids.len(), actual_count); + + // Verify each returned row_id is in our expected set + if let Some(iter) = row_ids.row_ids() { + for row_addr in iter { + let row_id = u64::from(row_addr); + assert!(expected_row_ids.contains(&row_id), + "Unexpected row_id {} in results", row_id); + } + } + + // Verify we didn't miss any expected points + if let Some(iter) = row_ids.row_ids() { + let found_ids: std::collections::HashSet = + iter.map(|addr| u64::from(addr)).collect(); + for expected_id in &expected_row_ids { + assert!(found_ids.contains(expected_id), + "Missing expected row_id {} in results", expected_id); + } + } + } + _ => panic!("Expected Exact search result"), + } + } + + #[tokio::test] + #[ignore] // Expensive test - run with: cargo test -- --ignored + async fn test_geo_index_large_scale_lazy_loading() { + use rand::{Rng, SeedableRng}; + use rand::rngs::StdRng; + use std::sync::atomic::{AtomicUsize, 
Ordering}; + + // Custom metrics collector to track I/O operations + struct LoadTracker { + part_loads: AtomicUsize, + } + impl crate::metrics::MetricsCollector for LoadTracker { + fn record_part_load(&self) { + self.part_loads.fetch_add(1, Ordering::Relaxed); + } + fn record_parts_loaded(&self, _num_parts: usize) {} + fn record_index_loads(&self, _num_indices: usize) {} + fn record_comparisons(&self, _num_comparisons: usize) {} + } + + let test_store = create_test_store(); + + let params = GeoIndexBuilderParams { + max_points_per_leaf: 1000, // Realistic leaf size + batches_per_file: 10, + }; + + let mut builder = GeoIndexBuilder::try_new(params, DataType::Struct(Vec::::new().into())) + .unwrap(); + + // Create 1 million random points across a 1000x1000 grid + let mut rng = StdRng::seed_from_u64(42); + for i in 0..1_000_000 { + let x = rng.random_range(0.0..1000.0); + let y = rng.random_range(0.0..1000.0); + builder.points.push((x, y, i as u64)); + } + + builder.write_index(test_store.as_ref()).await.unwrap(); + + let index = GeoIndex::load(test_store.clone(), None, &LanceCache::no_cache()) + .await + .unwrap(); + + // Run 100 random queries with load tracking + let metrics = LoadTracker { part_loads: AtomicUsize::new(0) }; + let mut total_results = 0; + let mut total_leaves_touched = 0; + + for _i in 0..100 { + // Random query bbox with varying sizes (1x1 to 50x50 regions in 1000x1000 space) + let width = rng.random_range(1.0..50.0); + let height = rng.random_range(1.0..50.0); + let min_x = rng.random_range(0.0..(1000.0 - width)); + let min_y = rng.random_range(0.0..(1000.0 - height)); + let max_x = min_x + width; + let max_y = min_y + height; + + let query = GeoQuery::Intersects(min_x, min_y, max_x, max_y); + + // Count leaves touched + let query_bbox = [min_x, min_y, max_x, max_y]; + let intersecting_leaves = index.bkd_tree.find_intersecting_leaves(query_bbox).unwrap(); + total_leaves_touched += intersecting_leaves.len(); + + // Execute query + let result = 
index.search(&query, &metrics).await.unwrap(); + + match result { + crate::scalar::SearchResult::Exact(row_ids) => { + let count = row_ids.len().unwrap_or(0); + total_results += count; + } + _ => panic!("Expected Exact search result"), + } + } + + let total_io_ops = metrics.part_loads.load(Ordering::Relaxed); + + // CRITICAL: Verify lazy loading is working! + // We should NOT load all leaves - only the ones intersecting query bboxes + assert!( + total_io_ops < index.bkd_tree.num_leaves as usize, + "āŒ LAZY LOADING FAILED: Loaded {} leaves but index only has {} leaves! \ + Should load much fewer than total.", + total_io_ops, index.bkd_tree.num_leaves + ); + + // Verify lazy loading is effective (< 10% of leaves touched on average per query) + let avg_leaves_touched = total_leaves_touched as f64 / 100.0; + let total_leaves = index.bkd_tree.num_leaves as f64; + assert!( + avg_leaves_touched < total_leaves * 0.1, + "āŒ Lazy loading ineffective: touching {:.1}% of leaves on average", + (avg_leaves_touched / total_leaves) * 100.0 + ); + } + + #[tokio::test] + async fn test_geo_index_points_in_correct_leaves() { + // Verify points are in leaves with correct bounding boxes + let test_store = create_test_store(); + + let params = GeoIndexBuilderParams { + max_points_per_leaf: 10, + batches_per_file: 5, + }; + + let mut builder = GeoIndexBuilder::try_new(params, DataType::Struct(Vec::::new().into())) + .unwrap(); + + // Create a grid of points + for x in 0..20 { + for y in 0..20 { + builder.points.push((x as f64, y as f64, (x * 20 + y) as u64)); + } + } + + builder.write_index(test_store.as_ref()).await.unwrap(); + let index = GeoIndex::load(test_store.clone(), None, &LanceCache::no_cache()) + .await + .unwrap(); + + // Verify each point is within its leaf's bounding box + use crate::metrics::NoOpMetricsCollector; + let metrics = NoOpMetricsCollector {}; + + for leaf in index.bkd_tree.nodes.iter().filter_map(|n| n.as_leaf()) { + let leaf_data = index.load_leaf(leaf, 
&metrics).await.unwrap(); + + let x_array = leaf_data + .column(0) + .as_primitive::(); + let y_array = leaf_data + .column(1) + .as_primitive::(); + + for i in 0..leaf_data.num_rows() { + let x = x_array.value(i); + let y = y_array.value(i); + + // Point must be within leaf's bounding box + assert!(x >= leaf.bounds[0] && x <= leaf.bounds[2], + "Point x={} outside leaf bounds [{}, {}]", + x, leaf.bounds[0], leaf.bounds[2]); + assert!(y >= leaf.bounds[1] && y <= leaf.bounds[3], + "Point y={} outside leaf bounds [{}, {}]", + y, leaf.bounds[1], leaf.bounds[3]); + } + } + } +} + diff --git a/rust/lance-index/src/scalar/geo/mod.rs b/rust/lance-index/src/scalar/geo/mod.rs new file mode 100644 index 00000000000..ddf26b02724 --- /dev/null +++ b/rust/lance-index/src/scalar/geo/mod.rs @@ -0,0 +1,15 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright The Lance Authors + +//! Geographic indexing module +//! +//! This module contains implementations for spatial/geographic indexing: +//! - BKD Tree: Block K-Dimensional tree for efficient spatial partitioning +//! - Geo Index: Geographic index built on top of BKD trees for GeoArrow data + +pub mod bkd; +pub mod geoindex; + +pub use bkd::*; +pub use geoindex::*; + diff --git a/rust/lance-index/src/scalar/geoindex.rs b/rust/lance-index/src/scalar/geoindex.rs deleted file mode 100644 index e3119565676..00000000000 --- a/rust/lance-index/src/scalar/geoindex.rs +++ /dev/null @@ -1,947 +0,0 @@ -// SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: Copyright The Lance Authors - -//! Geo Index -//! -//! Geo indices are spatial database structures for efficient spatial queries. -//! They enable efficient filtering by location-based predicates. -//! -//! ## Requirements -//! -//! Geo indices can only be created on fields with GeoArrow metadata. The field must: -//! - Be a Struct data type -//! 
- Have `ARROW:extension:name` metadata starting with `geoarrow.` (e.g., `geoarrow.point`, `geoarrow.polygon`) -//! -//! ## Query Support -//! -//! Geo indices are "inexact" filters - they can definitively exclude regions but may include -//! false positives that require rechecking. -//! - -use crate::pbold; -use crate::scalar::bkd::{BKDTreeBuilder, BKDTreeLookup, point_in_bbox}; -use crate::scalar::expression::{GeoQueryParser, ScalarQueryParser}; -use crate::scalar::registry::{ - ScalarIndexPlugin, TrainingCriteria, TrainingOrdering, TrainingRequest, -}; -use crate::scalar::{ - BuiltinIndexType, CreatedIndex, GeoQuery, ScalarIndexParams, UpdateCriteria, -}; -use crate::Any; -use futures::TryStreamExt; -use lance_core::cache::{CacheKey, LanceCache, WeakLanceCache}; -use lance_core::{ROW_ADDR, ROW_ID}; -use serde::{Deserialize, Serialize}; - -use arrow_array::{ArrayRef, Float64Array, RecordBatch, UInt32Array, UInt64Array, UInt8Array}; -use arrow_array::cast::AsArray; -use arrow_schema::{DataType, Field, Schema}; -use datafusion::execution::SendableRecordBatchStream; -use std::{collections::HashMap, sync::Arc}; - -use super::{AnyQuery, IndexReader, IndexStore, MetricsCollector, ScalarIndex, SearchResult}; -use crate::scalar::FragReuseIndex; -use crate::vector::VectorIndex; -use crate::{Index, IndexType}; -use async_trait::async_trait; -use deepsize::DeepSizeOf; -use lance_core::Result; -use lance_core::{utils::mask::RowIdTreeMap, Error}; -use roaring::RoaringBitmap; -use snafu::location; - -const BKD_TREE_INNER_FILENAME: &str = "bkd_tree_inner.lance"; -const BKD_TREE_LEAF_FILENAME: &str = "bkd_tree_leaf.lance"; -const LEAF_GROUP_PREFIX: &str = "leaf_group_"; -const DEFAULT_BATCHES_PER_LEAF_FILE: u32 = 5; // Default number of leaf batches per file -const GEO_INDEX_VERSION: u32 = 0; -const LEAF_SIZE_META_KEY: &str = "leaf_size"; -const BATCHES_PER_FILE_META_KEY: &str = "batches_per_file"; -const DEFAULT_LEAF_SIZE: u32 = 100; // for test -// const DEFAULT_LEAF_SIZE: u32 
= 1024; // for production - -/// Get the file name for a leaf group -fn leaf_group_filename(group_id: u32) -> String { - format!("{}{}.lance", LEAF_GROUP_PREFIX, group_id) -} - -/// Lazy reader for BKD leaf file -#[derive(Clone)] -struct LazyIndexReader { - index_reader: Arc>>>, - store: Arc, - filename: String, -} - -impl LazyIndexReader { - fn new(store: Arc, filename: &str) -> Self { - Self { - index_reader: Arc::new(tokio::sync::Mutex::new(None)), - store, - filename: filename.to_string(), - } - } - - async fn get(&self) -> Result> { - let mut reader = self.index_reader.lock().await; - if reader.is_none() { - let r = self.store.open_index_file(&self.filename).await?; - *reader = Some(r); - } - Ok(reader.as_ref().unwrap().clone()) - } -} - -/// Cache key for BKD leaf nodes -#[derive(Debug, Clone)] -struct BKDLeafKey { - leaf_id: u32, -} - -impl CacheKey for BKDLeafKey { - type ValueType = CachedLeafData; - - fn key(&self) -> std::borrow::Cow<'_, str> { - format!("bkd-leaf-{}", self.leaf_id).into() - } -} - -/// Cached leaf data -#[derive(Debug, Clone)] -struct CachedLeafData(RecordBatch); - -impl DeepSizeOf for CachedLeafData { - fn deep_size_of_children(&self, _context: &mut deepsize::Context) -> usize { - // Approximate size of RecordBatch - self.0.get_array_memory_size() - } -} - -impl CachedLeafData { - fn new(batch: RecordBatch) -> Self { - Self(batch) - } - - fn into_inner(self) -> RecordBatch { - self.0 - } -} - -/// Geo index -pub struct GeoIndex { - data_type: DataType, - store: Arc, - fri: Option>, - index_cache: WeakLanceCache, - bkd_tree: Arc, - leaf_size: u32, -} - -impl std::fmt::Debug for GeoIndex { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.debug_struct("GeoIndex") - .field("data_type", &self.data_type) - .field("store", &self.store) - .field("fri", &self.fri) - .field("index_cache", &self.index_cache) - .field("bkd_tree", &self.bkd_tree) - .field("leaf_size", &self.leaf_size) - .finish() - } -} - -impl DeepSizeOf 
for GeoIndex { - fn deep_size_of_children(&self, context: &mut deepsize::Context) -> usize { - self.bkd_tree.deep_size_of_children(context) + self.store.deep_size_of_children(context) - } -} - -impl GeoIndex { - /// Load the geo index from storage - async fn load( - store: Arc, - fri: Option>, - index_cache: &LanceCache, - ) -> Result> - where - Self: Sized, - { - // Load inner nodes - let inner_file = store.open_index_file(BKD_TREE_INNER_FILENAME).await?; - let inner_data = inner_file - .read_range(0..inner_file.num_rows(), None) - .await?; - - // Load leaf metadata - let leaf_file = store.open_index_file(BKD_TREE_LEAF_FILENAME).await?; - let leaf_data = leaf_file - .read_range(0..leaf_file.num_rows(), None) - .await?; - - // Deserialize tree structure from both files - let bkd_tree = BKDTreeLookup::from_record_batches(inner_data, leaf_data)?; - - // Extract metadata from inner file - let schema = inner_file.schema(); - let leaf_size = schema - .metadata - .get(LEAF_SIZE_META_KEY) - .and_then(|s| s.parse().ok()) - .unwrap_or(DEFAULT_LEAF_SIZE); - - // Get data type from schema - let data_type = schema.fields[0].data_type().clone(); - - Ok(Arc::new(Self { - data_type, - store, - fri, - index_cache: WeakLanceCache::from(index_cache), - bkd_tree: Arc::new(bkd_tree), - leaf_size, - })) - } - - /// Load a specific leaf from storage - async fn load_leaf( - &self, - leaf: &crate::scalar::bkd::BKDLeafNode, - metrics: &dyn MetricsCollector, - ) -> Result { - let file_id = leaf.file_id; - let row_offset = leaf.row_offset; - let num_rows = leaf.num_rows; - - // Use (file_id, row_offset) as cache key - // Combine file_id and row_offset into a single u32 (file_id should be small) - let cache_key = BKDLeafKey { leaf_id: file_id * 100_000 + (row_offset as u32) }; - let store = self.store.clone(); - - let cached = self - .index_cache - .get_or_insert_with_key(cache_key, move || async move { - metrics.record_part_load(); - - let filename = leaf_group_filename(file_id); - - // Open 
the leaf group file and read the specific row range - let reader = store.open_index_file(&filename).await?; - let batch = reader.read_range( - row_offset as usize..(row_offset + num_rows) as usize, - None - ).await?; - - Ok(CachedLeafData::new(batch)) - }) - .await?; - - Ok(cached.as_ref().clone().into_inner()) - } - - /// Search a specific leaf for points within the query bbox - async fn search_leaf( - &self, - leaf: &crate::scalar::bkd::BKDLeafNode, - query_bbox: [f64; 4], - metrics: &dyn MetricsCollector, - ) -> Result { - let leaf_data = self.load_leaf(leaf, metrics).await?; - - let file_id = leaf.file_id; - let row_offset = leaf.row_offset; - let num_rows = leaf.num_rows; - - println!( - "šŸ” Searching leaf (file={}, offset={}, rows={}) with {} points, query_bbox: {:?}", - file_id, - row_offset, - num_rows, - leaf_data.num_rows(), - query_bbox - ); - - // Filter points within this leaf - let mut row_ids = RowIdTreeMap::new(); - let x_array = leaf_data - .column(0) - .as_primitive::(); - let y_array = leaf_data - .column(1) - .as_primitive::(); - let row_id_array = leaf_data - .column(2) - .as_primitive::(); - - let mut matched_count = 0; - - // Debug: Check if SF (row_id=0) is in this leaf's actual data - let contains_sf = (0..leaf_data.num_rows()).any(|i| row_id_array.value(i) == 0); - if contains_sf { - println!(" šŸ”Ž Leaf (file={}, offset={}) contains SF (row_id=0)!", file_id, row_offset); - for i in 0..leaf_data.num_rows() { - if row_id_array.value(i) == 0 { - let x = x_array.value(i); - let y = y_array.value(i); - println!(" šŸŽÆ Found SF at index {}: ({}, {}) row_id=0", i, x, y); - println!(" šŸŽÆ point_in_bbox result: {}", point_in_bbox(x, y, &query_bbox)); - break; - } - } - } - - for i in 0..leaf_data.num_rows() { - let x = x_array.value(i); - let y = y_array.value(i); - let row_id = row_id_array.value(i); - - if point_in_bbox(x, y, &query_bbox) { - row_ids.insert(row_id); - matched_count += 1; - - // Log first few matches for debugging - if 
matched_count <= 3 { - println!( - " āœ… Match {}: point({}, {}) -> row_id {}", - matched_count, x, y, row_id - ); - } - } - } - - println!( - "šŸ“Š Leaf (file={}, offset={}) matched {} out of {} points", - file_id, - row_offset, - matched_count, - leaf_data.num_rows() - ); - - Ok(row_ids) - } -} - -#[async_trait] -impl Index for GeoIndex { - fn as_any(&self) -> &dyn Any { - self - } - - fn as_index(self: Arc) -> Arc { - self - } - - fn as_vector_index(self: Arc) -> Result> { - Err(Error::InvalidInput { - source: "GeoIndex is not a vector index".into(), - location: location!(), - }) - } - - async fn prewarm(&self) -> Result<()> { - Ok(()) - } - - fn statistics(&self) -> Result { - Ok(serde_json::json!({ - "type": "geo", - })) - } - - fn index_type(&self) -> IndexType { - IndexType::Geo - } - - async fn calculate_included_frags(&self) -> Result { - let frag_ids = RoaringBitmap::new(); - Ok(frag_ids) - } -} - -#[async_trait] -impl ScalarIndex for GeoIndex { - async fn search( - &self, - query: &dyn AnyQuery, - metrics: &dyn MetricsCollector, - ) -> Result { - let geo_query = query.as_any().downcast_ref::() - .ok_or_else(|| Error::InvalidInput { - source: "Geo index only supports GeoQuery".into(), - location: location!(), - })?; - - match geo_query { - GeoQuery::Intersects(min_x, min_y, max_x, max_y) => { - let query_bbox = [*min_x, *min_y, *max_x, *max_y]; - - println!( - "\nšŸ” Geo index search: st_intersects with bbox({}, {}, {}, {})", - min_x, min_y, max_x, max_y - ); - - // Step 1: Find intersecting leaves using in-memory tree traversal - let leaves = self.bkd_tree.find_intersecting_leaves(query_bbox)?; - - println!( - "šŸ“Š BKD tree traversal found {} intersecting leaves out of {} total leaves", - leaves.len(), - self.bkd_tree.num_leaves - ); - - // Step 2: Lazy-load and filter each leaf - let mut all_row_ids = RowIdTreeMap::new(); - - for leaf_node in &leaves { - let leaf_row_ids = self - .search_leaf(leaf_node, query_bbox, metrics) - .await?; - // Collect row 
IDs from the leaf and add them to the result set - let row_ids: Option> = leaf_row_ids.row_ids() - .map(|iter| iter.map(|row_addr| u64::from(row_addr)).collect()); - if let Some(row_ids) = row_ids { - all_row_ids.extend(row_ids); - } - } - - println!( - "āœ… Geo index searched {} leaves and returning {} row IDs\n", - leaves.len(), - all_row_ids.len().unwrap_or(0) - ); - - // We return Exact because we already filtered points in search_leaf - Ok(SearchResult::Exact(all_row_ids)) - } - } - } - - fn can_remap(&self) -> bool { - false - } - - /// Remap the row ids, creating a new remapped version of this index in `dest_store` - async fn remap( - &self, - _mapping: &HashMap>, - _dest_store: &dyn IndexStore, - ) -> Result { - Err(Error::InvalidInput { - source: "GeoIndex does not support remap".into(), - location: location!(), - }) - } - - /// Add the new data , creating an updated version of the index in `dest_store` - async fn update( - &self, - _new_data: SendableRecordBatchStream, - _dest_store: &dyn IndexStore, - ) -> Result { - Err(Error::InvalidInput { - source: "GeoIndex does not support update".into(), - location: location!(), - }) - } - - fn update_criteria(&self) -> UpdateCriteria { - UpdateCriteria::only_new_data( - TrainingCriteria::new(TrainingOrdering::Addresses).with_row_addr(), - ) - } - - fn derive_index_params(&self) -> Result { - let params = serde_json::to_value(GeoIndexBuilderParams::default())?; - Ok(ScalarIndexParams::for_builtin(BuiltinIndexType::Geo).with_params(¶ms)) - } -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct GeoIndexBuilderParams { - #[serde(default = "default_leaf_size")] - pub leaf_size: u32, -} - -fn default_leaf_size() -> u32 { - DEFAULT_LEAF_SIZE -} - -impl Default for GeoIndexBuilderParams { - fn default() -> Self { - Self { - leaf_size: default_leaf_size(), - } - } -} - -impl GeoIndexBuilderParams { - pub fn new() -> Self { - Self::default() - } - - pub fn with_leaf_size(mut self, leaf_size: u32) -> Self { - 
self.leaf_size = leaf_size; - self - } -} - -// A builder for geo index -pub struct GeoIndexBuilder { - options: GeoIndexBuilderParams, - items_type: DataType, - // Accumulated points: (x, y, row_id) - points: Vec<(f64, f64, u64)>, -} - -impl GeoIndexBuilder { - pub fn try_new(options: GeoIndexBuilderParams, items_type: DataType) -> Result { - Ok(Self { - options, - items_type, - points: Vec::new(), - }) - } - - pub async fn train(&mut self, batches_source: SendableRecordBatchStream) -> Result<()> { - assert!(batches_source.schema().field_with_name(ROW_ADDR).is_ok()); - - let mut batches_source = batches_source; - - while let Some(batch) = batches_source.try_next().await? { - // Extract GeoArrow point coordinates - let geom_array = batch.column(0).as_any().downcast_ref::() - .ok_or_else(|| Error::InvalidInput { - source: "Expected Struct array for GeoArrow data".into(), - location: location!(), - })?; - - let x_array = geom_array - .column(0) - .as_primitive::(); - let y_array = geom_array - .column(1) - .as_primitive::(); - let row_ids = batch - .column_by_name(ROW_ADDR) - .unwrap() - .as_primitive::(); - - for i in 0..batch.num_rows() { - self.points.push(( - x_array.value(i), - y_array.value(i), - row_ids.value(i), - )); - } - } - - log::debug!("Accumulated {} points for BKD tree", self.points.len()); - - Ok(()) - } - - pub async fn write_index(mut self, index_store: &dyn IndexStore) -> Result<()> { - if self.points.is_empty() { - // Write empty index files - self.write_empty_index(index_store).await?; - return Ok(()); - } - - // Build BKD tree - let (tree_nodes, leaf_batches) = self.build_bkd_tree()?; - - // Write tree structure to separate inner and leaf files - let (inner_batch, leaf_metadata_batch) = self.serialize_tree_nodes(&tree_nodes)?; - - // Write inner nodes - let mut inner_file = index_store - .new_index_file(BKD_TREE_INNER_FILENAME, inner_batch.schema()) - .await?; - inner_file.write_record_batch(inner_batch).await?; - inner_file - 
.finish_with_metadata(HashMap::from([( - LEAF_SIZE_META_KEY.to_string(), - self.options.leaf_size.to_string(), - )])) - .await?; - - // Write leaf metadata - let mut leaf_meta_file = index_store - .new_index_file(BKD_TREE_LEAF_FILENAME, leaf_metadata_batch.schema()) - .await?; - leaf_meta_file.write_record_batch(leaf_metadata_batch).await?; - leaf_meta_file.finish().await?; - - // Write actual leaf data grouped into files (multiple batches per file) - let leaf_schema = Arc::new(Schema::new(vec![ - Field::new("x", DataType::Float64, false), - Field::new("y", DataType::Float64, false), - Field::new(ROW_ID, DataType::UInt64, false), - ])); - - let num_groups = (leaf_batches.len() as u32 + DEFAULT_BATCHES_PER_LEAF_FILE - 1) / DEFAULT_BATCHES_PER_LEAF_FILE; - println!("šŸ“ Writing {} leaf batches into {} group files ({} batches per file)", - leaf_batches.len(), num_groups, DEFAULT_BATCHES_PER_LEAF_FILE); - - for group_id in 0..num_groups { - let start_idx = (group_id * DEFAULT_BATCHES_PER_LEAF_FILE) as usize; - let end_idx = ((group_id + 1) * DEFAULT_BATCHES_PER_LEAF_FILE).min(leaf_batches.len() as u32) as usize; - let group_batches = &leaf_batches[start_idx..end_idx]; - - let filename = leaf_group_filename(group_id); - println!(" Writing {}: {} batches ({} rows total)", - filename, - group_batches.len(), - group_batches.iter().map(|b| b.num_rows()).sum::()); - - let mut leaf_file = index_store - .new_index_file(&filename, leaf_schema.clone()) - .await?; - - for (batch_idx, leaf_batch) in group_batches.iter().enumerate() { - let batch_id = leaf_file.write_record_batch(leaf_batch.clone()).await?; - println!(" Batch {}: {} rows (batch_id={})", - start_idx + batch_idx, leaf_batch.num_rows(), batch_id); - } - - leaf_file.finish().await?; - } - println!("āœ… Finished writing {} group files\n", num_groups); - - log::debug!( - "Wrote BKD tree with {} nodes", - tree_nodes.len() - ); - - Ok(()) - } - - async fn write_empty_index(&self, index_store: &dyn IndexStore) -> 
Result<()> { - // Write empty inner node file - let inner_schema = crate::scalar::bkd::inner_node_schema(); - let empty_inner = RecordBatch::new_empty(inner_schema); - let mut inner_file = index_store - .new_index_file(BKD_TREE_INNER_FILENAME, empty_inner.schema()) - .await?; - inner_file.write_record_batch(empty_inner).await?; - inner_file.finish().await?; - - // Write empty leaf metadata file - let leaf_schema = crate::scalar::bkd::leaf_node_schema(); - let empty_leaf = RecordBatch::new_empty(leaf_schema); - let mut leaf_file = index_store - .new_index_file(BKD_TREE_LEAF_FILENAME, empty_leaf.schema()) - .await?; - leaf_file.write_record_batch(empty_leaf).await?; - leaf_file.finish().await?; - - // No actual leaf data files needed for empty index - - Ok(()) - } - - // Serialize tree nodes to separate inner and leaf RecordBatches - fn serialize_tree_nodes(&self, nodes: &[crate::scalar::bkd::BKDNode]) -> Result<(RecordBatch, RecordBatch)> { - use crate::scalar::bkd::BKDNode; - - // Separate inner and leaf nodes with their indices - let mut inner_nodes = Vec::new(); - let mut leaf_nodes = Vec::new(); - - for (idx, node) in nodes.iter().enumerate() { - match node { - BKDNode::Inner(_) => inner_nodes.push((idx as u32, node)), - BKDNode::Leaf(_) => leaf_nodes.push((idx as u32, node)), - } - } - - // Serialize inner nodes - let inner_batch = Self::serialize_inner_nodes(&inner_nodes)?; - - // Serialize leaf nodes - let leaf_batch = Self::serialize_leaf_nodes(&leaf_nodes)?; - - Ok((inner_batch, leaf_batch)) - } - - fn serialize_inner_nodes(nodes: &[(u32, &crate::scalar::bkd::BKDNode)]) -> Result { - use crate::scalar::bkd::BKDNode; - - let mut node_id_vals = Vec::with_capacity(nodes.len()); - let mut min_x_vals = Vec::with_capacity(nodes.len()); - let mut min_y_vals = Vec::with_capacity(nodes.len()); - let mut max_x_vals = Vec::with_capacity(nodes.len()); - let mut max_y_vals = Vec::with_capacity(nodes.len()); - let mut split_dim_vals = Vec::with_capacity(nodes.len()); - 
let mut split_value_vals = Vec::with_capacity(nodes.len()); - let mut left_child_vals = Vec::with_capacity(nodes.len()); - let mut right_child_vals = Vec::with_capacity(nodes.len()); - - for (idx, node) in nodes { - if let BKDNode::Inner(inner) = node { - node_id_vals.push(*idx); - min_x_vals.push(inner.bounds[0]); - min_y_vals.push(inner.bounds[1]); - max_x_vals.push(inner.bounds[2]); - max_y_vals.push(inner.bounds[3]); - split_dim_vals.push(inner.split_dim); - split_value_vals.push(inner.split_value); - left_child_vals.push(inner.left_child); - right_child_vals.push(inner.right_child); - } - } - - let schema = crate::scalar::bkd::inner_node_schema(); - - let columns: Vec = vec![ - Arc::new(UInt32Array::from(node_id_vals)), - Arc::new(Float64Array::from(min_x_vals)), - Arc::new(Float64Array::from(min_y_vals)), - Arc::new(Float64Array::from(max_x_vals)), - Arc::new(Float64Array::from(max_y_vals)), - Arc::new(UInt8Array::from(split_dim_vals)), - Arc::new(Float64Array::from(split_value_vals)), - Arc::new(UInt32Array::from(left_child_vals)), - Arc::new(UInt32Array::from(right_child_vals)), - ]; - - Ok(RecordBatch::try_new(schema, columns)?) 
- } - - fn serialize_leaf_nodes(nodes: &[(u32, &crate::scalar::bkd::BKDNode)]) -> Result { - use crate::scalar::bkd::BKDNode; - use arrow_array::UInt64Array; - - let mut node_id_vals = Vec::with_capacity(nodes.len()); - let mut min_x_vals = Vec::with_capacity(nodes.len()); - let mut min_y_vals = Vec::with_capacity(nodes.len()); - let mut max_x_vals = Vec::with_capacity(nodes.len()); - let mut max_y_vals = Vec::with_capacity(nodes.len()); - let mut file_id_vals = Vec::with_capacity(nodes.len()); - let mut row_offset_vals = Vec::with_capacity(nodes.len()); - let mut num_rows_vals = Vec::with_capacity(nodes.len()); - - for (idx, node) in nodes { - if let BKDNode::Leaf(leaf) = node { - node_id_vals.push(*idx); - min_x_vals.push(leaf.bounds[0]); - min_y_vals.push(leaf.bounds[1]); - max_x_vals.push(leaf.bounds[2]); - max_y_vals.push(leaf.bounds[3]); - file_id_vals.push(leaf.file_id); - row_offset_vals.push(leaf.row_offset); - num_rows_vals.push(leaf.num_rows); - } - } - - let schema = crate::scalar::bkd::leaf_node_schema(); - - let columns: Vec = vec![ - Arc::new(UInt32Array::from(node_id_vals)), - Arc::new(Float64Array::from(min_x_vals)), - Arc::new(Float64Array::from(min_y_vals)), - Arc::new(Float64Array::from(max_x_vals)), - Arc::new(Float64Array::from(max_y_vals)), - Arc::new(UInt32Array::from(file_id_vals)), - Arc::new(UInt64Array::from(row_offset_vals)), - Arc::new(UInt64Array::from(num_rows_vals)), - ]; - - Ok(RecordBatch::try_new(schema, columns)?) 
- } - - // Build BKD tree using the BKDTreeBuilder - fn build_bkd_tree(&mut self) -> Result<(Vec, Vec)> { - let builder = BKDTreeBuilder::new(self.options.leaf_size as usize); - builder.build(&mut self.points, DEFAULT_BATCHES_PER_LEAF_FILE) - } -} - -#[derive(Debug, Default)] -pub struct GeoIndexPlugin; - -impl GeoIndexPlugin { - async fn train_geo_index( - batches_source: SendableRecordBatchStream, - index_store: &dyn IndexStore, - options: Option, - ) -> Result<()> { - let value_type = batches_source.schema().field(0).data_type().clone(); - - let mut builder = GeoIndexBuilder::try_new(options.unwrap_or_default(), value_type)?; - - builder.train(batches_source).await?; - - builder.write_index(index_store).await?; - Ok(()) - } -} - -pub struct GeoIndexTrainingRequest { - pub params: GeoIndexBuilderParams, - pub criteria: TrainingCriteria, -} - -impl GeoIndexTrainingRequest { - pub fn new(params: GeoIndexBuilderParams) -> Self { - Self { - params, - criteria: TrainingCriteria::new(TrainingOrdering::Addresses).with_row_addr(), - } - } -} - -impl TrainingRequest for GeoIndexTrainingRequest { - fn as_any(&self) -> &dyn std::any::Any { - self - } - fn criteria(&self) -> &TrainingCriteria { - &self.criteria - } -} - -#[async_trait] -impl ScalarIndexPlugin for GeoIndexPlugin { - fn new_training_request( - &self, - params: &str, - field: &Field, - ) -> Result> { - // Check that the field is a Struct type - if !matches!(field.data_type(), DataType::Struct(_)) { - return Err(Error::InvalidInput { - source: "A geo index can only be created on a Struct field.".into(), - location: location!(), - }); - } - - // Check for GeoArrow metadata - let is_geoarrow = field - .metadata() - .get("ARROW:extension:name") - .map(|name| name.starts_with("geoarrow.")) - .unwrap_or(false); - - if !is_geoarrow { - return Err(Error::InvalidInput { - source: format!( - "Geo index requires GeoArrow metadata on field '{}'. 
\ - The field must have 'ARROW:extension:name' metadata starting with 'geoarrow.'", - field.name() - ) - .into(), - location: location!(), - }); - } - - let params = serde_json::from_str::(params)?; - - Ok(Box::new(GeoIndexTrainingRequest::new(params))) - } - - fn provides_exact_answer(&self) -> bool { - true // We do exact point-in-bbox filtering in search_leaf - } - - fn version(&self) -> u32 { - GEO_INDEX_VERSION - } - - fn new_query_parser( - &self, - index_name: String, - _index_details: &prost_types::Any, - ) -> Option> { - Some(Box::new(GeoQueryParser::new(index_name))) - } - - async fn train_index( - &self, - data: SendableRecordBatchStream, - index_store: &dyn IndexStore, - request: Box, - fragment_ids: Option>, - ) -> Result { - if fragment_ids.is_some() { - return Err(Error::InvalidInput { - source: "Geo index does not support fragment training".into(), - location: location!(), - }); - } - - let request = (request as Box) - .downcast::() - .map_err(|_| Error::InvalidInput { - source: "must provide training request created by new_training_request".into(), - location: location!(), - })?; - Self::train_geo_index(data, index_store, Some(request.params)).await?; - Ok(CreatedIndex { - index_details: prost_types::Any::from_msg(&pbold::GeoIndexDetails::default()) - .unwrap(), - index_version: GEO_INDEX_VERSION, - }) - } - - async fn load_index( - &self, - index_store: Arc, - _index_details: &prost_types::Any, - frag_reuse_index: Option>, - cache: &LanceCache, - ) -> Result> { - Ok(GeoIndex::load(index_store, frag_reuse_index, cache).await? 
as Arc) - } -} - -#[cfg(test)] -mod tests { - use super::*; - use std::sync::Arc; - - use arrow_array::RecordBatch; - use arrow_schema::{DataType, Field, Fields, Schema}; - use datafusion::execution::SendableRecordBatchStream; - use datafusion::physical_plan::stream::RecordBatchStreamAdapter; - use futures::stream; - use lance_core::cache::LanceCache; - use lance_core::utils::tempfile::TempObjDir; - use lance_core::ROW_ADDR; - use lance_io::object_store::ObjectStore; - - use crate::scalar::lance_format::LanceIndexStore; - - #[tokio::test] - async fn test_empty_geo_index() { - let tmpdir = TempObjDir::default(); - let test_store = Arc::new(LanceIndexStore::new( - Arc::new(ObjectStore::local()), - tmpdir.clone(), - Arc::new(LanceCache::no_cache()), - )); - - let data = arrow_array::StructArray::from(vec![]); - let row_ids = arrow_array::UInt64Array::from(Vec::::new()); - let fields: Fields = Vec::::new().into(); - let schema = Arc::new(Schema::new(vec![ - Field::new("value", DataType::Struct(fields), false), - Field::new(ROW_ADDR, DataType::UInt64, false), - ])); - let data = - RecordBatch::try_new(schema.clone(), vec![Arc::new(data), Arc::new(row_ids)]).unwrap(); - - let data_stream: SendableRecordBatchStream = Box::pin(RecordBatchStreamAdapter::new( - schema, - stream::once(std::future::ready(Ok(data))), - )); - - GeoIndexPlugin::train_geo_index(data_stream, test_store.as_ref(), None) - .await - .unwrap(); - - // Read the index file back and check its contents - let _index = GeoIndex::load(test_store.clone(), None, &LanceCache::no_cache()) - .await - .expect("Failed to load GeoIndex"); - } -} - diff --git a/rust/lance-index/src/scalar/registry.rs b/rust/lance-index/src/scalar/registry.rs index 612eaf471dd..016e92ee438 100644 --- a/rust/lance-index/src/scalar/registry.rs +++ b/rust/lance-index/src/scalar/registry.rs @@ -15,7 +15,7 @@ use crate::{ frag_reuse::FragReuseIndex, scalar::{ bitmap::BitmapIndexPlugin, bloomfilter::BloomFilterIndexPlugin, 
btree::BTreeIndexPlugin, - expression::ScalarQueryParser, geoindex::GeoIndexPlugin, inverted::InvertedIndexPlugin, + expression::ScalarQueryParser, geo::GeoIndexPlugin, inverted::InvertedIndexPlugin, json::JsonIndexPlugin, label_list::LabelListIndexPlugin, ngram::NGramIndexPlugin, zonemap::ZoneMapIndexPlugin, CreatedIndex, IndexStore, ScalarIndex, }, From feede2942990ef883c238ce69f12d9554a944614 Mon Sep 17 00:00:00 2001 From: jaystarshot Date: Thu, 16 Oct 2025 11:18:37 -0700 Subject: [PATCH 5/7] add functional benchmarl --- rust/lance-index/Cargo.toml | 7 +- rust/lance-index/benches/geoindex.rs | 251 ++++++++++++ rust/lance-index/src/scalar/geo/bkd.rs | 41 +- rust/lance-index/src/scalar/geo/geoindex.rs | 425 ++++++++++++++++++-- 4 files changed, 650 insertions(+), 74 deletions(-) create mode 100644 rust/lance-index/benches/geoindex.rs diff --git a/rust/lance-index/Cargo.toml b/rust/lance-index/Cargo.toml index 605e40aa792..6a2220fdbef 100644 --- a/rust/lance-index/Cargo.toml +++ b/rust/lance-index/Cargo.toml @@ -98,9 +98,6 @@ pprof.workspace = true # docs.rs uses an older version of Ubuntu that does not have the necessary protoc version features = ["protoc"] -[profile.bench] -debug = true - [[bench]] name = "find_partitions" harness = false @@ -149,5 +146,9 @@ harness = false name = "rq" harness = false +[[bench]] +name = "geoindex" +harness = false + [lints] workspace = true diff --git a/rust/lance-index/benches/geoindex.rs b/rust/lance-index/benches/geoindex.rs new file mode 100644 index 00000000000..a5a85a9b6c5 --- /dev/null +++ b/rust/lance-index/benches/geoindex.rs @@ -0,0 +1,251 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright The Lance Authors + +//! 
Benchmarks for geo index vs brute force scanning + +use arrow_schema::{DataType, Field}; +use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion}; +use lance_core::cache::LanceCache; +use lance_core::utils::tempfile::TempObjDir; +use lance_io::object_store::ObjectStore; +use rand::{Rng, SeedableRng}; +use rand::rngs::StdRng; +use std::sync::Arc; + +use lance_index::scalar::geo::geoindex::{GeoIndex, GeoIndexBuilder, GeoIndexBuilderParams}; +use lance_index::scalar::lance_format::LanceIndexStore; +use lance_index::scalar::{GeoQuery, ScalarIndex}; +use lance_index::metrics::{MetricsCollector, NoOpMetricsCollector}; +use std::sync::atomic::{AtomicUsize, Ordering}; + +struct LeafCounter { + leaves_visited: AtomicUsize, +} + +impl LeafCounter { + fn new() -> Self { + Self { + leaves_visited: AtomicUsize::new(0), + } + } + + fn get_count(&self) -> usize { + self.leaves_visited.load(Ordering::Relaxed) + } + + fn reset(&self) { + self.leaves_visited.store(0, Ordering::Relaxed); + } +} + +impl MetricsCollector for LeafCounter { + fn record_parts_loaded(&self, num_parts: usize) { + self.leaves_visited.fetch_add(num_parts, Ordering::Relaxed); + } + + fn record_index_loads(&self, _num_indexes: usize) {} + fn record_comparisons(&self, _num_comparisons: usize) {} +} + +fn create_test_points(num_points: usize) -> Vec<(f64, f64, u64)> { + let mut rng = StdRng::seed_from_u64(42); + let mut points = Vec::with_capacity(num_points); + + for i in 0..num_points { + let x = rng.random_range(0.0..1000.0); + let y = rng.random_range(0.0..1000.0); + points.push((x, y, i as u64)); + } + + points +} + +async fn create_geo_index(points: Vec<(f64, f64, u64)>) -> (Arc, Arc, TempObjDir) { + let tmpdir = TempObjDir::default(); + let store = Arc::new(LanceIndexStore::new( + Arc::new(ObjectStore::local()), + tmpdir.clone(), + Arc::new(LanceCache::no_cache()), + )); + + let params = GeoIndexBuilderParams { + max_points_per_leaf: 1000, + batches_per_file: 10, + }; + + let mut builder = 
GeoIndexBuilder::try_new( + params, + DataType::Struct(Vec::::new().into()), + ) + .unwrap(); + + builder.points = points; + builder.write_index(store.as_ref()).await.unwrap(); + + let index = GeoIndex::load(store.clone(), None, &LanceCache::no_cache()) + .await + .unwrap(); + + (index, store, tmpdir) +} + +fn bench_geo_index_intersects_pruning(c: &mut Criterion) { + let mut group = c.benchmark_group("geo_intersects_pruning_vs_scan_all"); + group.sample_size(10); // Reduce samples for faster benchmarks + group.measurement_time(std::time::Duration::from_secs(60)); + + // Compares BKD tree spatial pruning vs scanning all leaves for intersects queries + // Both approaches do lazy loading - we're measuring pruning efficiency + + // Test with different dataset sizes + for num_points in [10_000, 100_000, 1_000_000] { + let points = create_test_points(num_points); + + // Create index (do this once outside the benchmark) + let rt = tokio::runtime::Runtime::new().unwrap(); + let (index, _store, _tmpdir) = rt.block_on(create_geo_index(points.clone())); + + // Generate random queries + let mut rng = StdRng::seed_from_u64(42); + let queries: Vec<[f64; 4]> = (0..10) + .map(|_| { + let width = rng.random_range(10.0..50.0); + let height = rng.random_range(10.0..50.0); + let min_x = rng.random_range(0.0..(1000.0 - width)); + let min_y = rng.random_range(0.0..(1000.0 - height)); + [min_x, min_y, min_x + width, min_y + height] + }) + .collect(); + + // Benchmark: scan all leaves (no spatial pruning) + let scan_all_counter = Arc::new(LeafCounter::new()); + group.bench_with_input( + BenchmarkId::new("scan_all_leaves", num_points), + &(&index, &queries, scan_all_counter.clone()), + |b, (index, queries, counter)| { + let rt = tokio::runtime::Runtime::new().unwrap(); + b.iter(|| { + counter.reset(); + rt.block_on(async { + for query_bbox in queries.iter() { + let _result = index.search_all_leaves(*query_bbox, counter.as_ref()).await.unwrap(); + } + }); + }); + }, + ); + + // Benchmark: 
BKD tree with spatial pruning + let pruned_counter = Arc::new(LeafCounter::new()); + group.bench_with_input( + BenchmarkId::new("with_pruning", num_points), + &(&index, &queries, pruned_counter.clone()), + |b, (index, queries, counter)| { + let rt = tokio::runtime::Runtime::new().unwrap(); + b.iter(|| { + counter.reset(); + rt.block_on(async { + for query_bbox in queries.iter() { + let query = GeoQuery::Intersects(query_bbox[0], query_bbox[1], query_bbox[2], query_bbox[3]); + let _result = index.search(&query, counter.as_ref()).await.unwrap(); + } + }); + }); + }, + ); + + // Print statistics with speedup estimate + let scan_avg = scan_all_counter.get_count() as f64 / queries.len() as f64; + let pruned_avg = pruned_counter.get_count() as f64 / queries.len() as f64; + + // Quick timing for speedup calculation + let rt = tokio::runtime::Runtime::new().unwrap(); + let scan_start = std::time::Instant::now(); + for query_bbox in queries.iter() { + rt.block_on(async { + let _ = index.search_all_leaves(*query_bbox, &NoOpMetricsCollector).await; + }); + } + let scan_time = scan_start.elapsed().as_secs_f64() / queries.len() as f64; + + let prune_start = std::time::Instant::now(); + for query_bbox in queries.iter() { + rt.block_on(async { + let query = GeoQuery::Intersects(query_bbox[0], query_bbox[1], query_bbox[2], query_bbox[3]); + let _ = index.search(&query, &NoOpMetricsCollector).await; + }); + } + let prune_time = prune_start.elapsed().as_secs_f64() / queries.len() as f64; + + println!( + "\n {} points: {} leaves | scan_all: {:.1} leaves/query ({:.1}ms) | with_pruning: {:.1} leaves/query ({:.1}ms) | {:.1}x faster, {:.1}% reduction", + num_points, + index.num_leaves(), + scan_avg, + scan_time * 1000.0, + pruned_avg, + prune_time * 1000.0, + scan_time / prune_time, + 100.0 * (1.0 - pruned_avg / scan_avg) + ); + } + + group.finish(); +} + +fn bench_geo_intersects_query_size(c: &mut Criterion) { + let mut group = c.benchmark_group("geo_intersects_query_size"); + 
group.sample_size(10); + + // Fixed dataset size, varying intersects query box sizes + let points = create_test_points(1_000_000); + let rt = tokio::runtime::Runtime::new().unwrap(); + let (index, _store, _tmpdir) = rt.block_on(create_geo_index(points)); + + let mut rng = StdRng::seed_from_u64(42); + + // Test different query sizes + for query_size in [1.0, 10.0, 50.0, 100.0, 200.0] { + let queries: Vec<[f64; 4]> = (0..10) + .map(|_| { + let min_x = rng.random_range(0.0..(1000.0 - query_size)); + let min_y = rng.random_range(0.0..(1000.0 - query_size)); + [min_x, min_y, min_x + query_size, min_y + query_size] + }) + .collect(); + + let leaf_counter = Arc::new(LeafCounter::new()); + group.bench_with_input( + BenchmarkId::new("query_size", format!("{}x{}", query_size, query_size)), + &(&index, &queries, leaf_counter.clone()), + |b, (index, queries, counter)| { + let rt = tokio::runtime::Runtime::new().unwrap(); + b.iter(|| { + counter.reset(); + rt.block_on(async { + for query_bbox in queries.iter() { + let query = GeoQuery::Intersects(query_bbox[0], query_bbox[1], query_bbox[2], query_bbox[3]); + let _result = index.search(&query, counter.as_ref()).await.unwrap(); + } + }); + }); + }, + ); + + let avg_leaves = leaf_counter.get_count() as f64 / queries.len() as f64; + println!( + " Query {}x{}: avg {:.1} leaves visited (out of {} total, {:.1}% selectivity)", + query_size, + query_size, + avg_leaves, + index.num_leaves(), + 100.0 * avg_leaves / index.num_leaves() as f64 + ); + } + + group.finish(); +} + +criterion_group!(benches, bench_geo_index_intersects_pruning, bench_geo_intersects_query_size); +criterion_main!(benches); + diff --git a/rust/lance-index/src/scalar/geo/bkd.rs b/rust/lance-index/src/scalar/geo/bkd.rs index 948c27afc70..ea5807c1806 100644 --- a/rust/lance-index/src/scalar/geo/bkd.rs +++ b/rust/lance-index/src/scalar/geo/bkd.rs @@ -15,7 +15,7 @@ //! - Groups points into leaves of configurable size //! 
- Stores tree structure separately from leaf data for lazy loading -use arrow_array::{Array, ArrayRef, Float64Array, RecordBatch, UInt64Array, UInt32Array, UInt8Array}; +use arrow_array::{ArrayRef, Float64Array, RecordBatch, UInt64Array}; use arrow_array::cast::AsArray; use arrow_schema::{DataType, Field, Schema}; use deepsize::DeepSizeOf; @@ -180,12 +180,6 @@ impl BKDTreeLookup { } } - println!( - "🌲 Tree traversal: visited {} nodes, found {} intersecting leaves", - nodes_visited, - leaves.len() - ); - Ok(leaves) } @@ -332,18 +326,6 @@ impl BKDTreeBuilder { return Ok((vec![], vec![])); } - println!( - "\nšŸ—ļø Building BKD tree for {} points with leaf size {}", - points.len(), - self.leaf_size - ); - - // Log first few points for debugging - println!("šŸ“ First 5 points:"); - for i in 0..std::cmp::min(5, points.len()) { - println!(" Point {}: x={}, y={}, row_id={}", i, points[i].0, points[i].1, points[i].2); - } - let mut leaf_counter = 0u32; let mut all_nodes = Vec::new(); let mut all_leaf_batches = Vec::new(); @@ -390,12 +372,6 @@ impl BKDTreeBuilder { } } - println!( - "āœ… Built BKD tree: {} nodes ({} leaves)\n", - all_nodes.len(), - leaf_counter - ); - Ok((all_nodes, all_leaf_batches)) } @@ -431,12 +407,6 @@ impl BKDTreeBuilder { num_rows, })); - // Debug: Check if SF (row_id=0) is in this leaf - if points.iter().any(|(_, _, rid)| *rid == 0) { - println!("šŸŽÆ SF (row_id=0) in leaf node_id={}, leaf_id={}, num_rows={}, bounds=[{}, {}, {}, {}]", - node_id, leaf_id, num_rows, min_x, min_y, max_x, max_y); - } - return Ok(node_id); } @@ -466,15 +436,6 @@ impl BKDTreeBuilder { // Calculate bounds for this node (before splitting the slice) let (min_x, min_y, max_x, max_y) = calculate_bounds(points); - // Debug: Log first inner node to verify bounds - if all_nodes.is_empty() { - println!("šŸ” Root node bounds: [{}, {}, {}, {}]", min_x, min_y, max_x, max_y); - println!(" Split on dim {} at value {}", split_dim, split_value); - println!(" Contains SF (-122.4194, 
37.7749)? x_ok={}, y_ok={}", - min_x <= -122.4194 && -122.4194 <= max_x, - min_y <= 37.7749 && 37.7749 <= max_y); - } - // Reserve space for this inner node (placeholder - we'll update it after building children) let node_id = all_nodes.len() as u32; all_nodes.push(BKDNode::Inner(BKDInnerNode { diff --git a/rust/lance-index/src/scalar/geo/geoindex.rs b/rust/lance-index/src/scalar/geo/geoindex.rs index e12d6cda3c1..f432f37529c 100644 --- a/rust/lance-index/src/scalar/geo/geoindex.rs +++ b/rust/lance-index/src/scalar/geo/geoindex.rs @@ -158,7 +158,7 @@ impl DeepSizeOf for GeoIndex { impl GeoIndex { /// Load the geo index from storage - async fn load( + pub async fn load( store: Arc, fri: Option>, index_cache: &LanceCache, @@ -238,6 +238,34 @@ impl GeoIndex { Ok(cached.as_ref().clone().into_inner()) } + /// Get the number of leaves in the BKD tree (useful for benchmarking) + pub fn num_leaves(&self) -> usize { + self.bkd_tree.num_leaves as usize + } + + /// Search all leaves without using BKD tree pruning (useful for benchmarking) + pub async fn search_all_leaves( + &self, + query_bbox: [f64; 4], + metrics: &dyn MetricsCollector, + ) -> Result { + let mut all_row_ids = RowIdTreeMap::new(); + + // Iterate through all nodes and search every leaf + for node in self.bkd_tree.nodes.iter() { + if let BKDNode::Leaf(leaf) = node { + let leaf_row_ids = self.search_leaf(leaf, query_bbox, metrics).await?; + let row_ids: Option> = leaf_row_ids.row_ids() + .map(|iter| iter.map(|row_addr| u64::from(row_addr)).collect()); + if let Some(row_ids) = row_ids { + all_row_ids.extend(row_ids); + } + } + } + + Ok(all_row_ids) + } + /// Search a specific leaf for points within the query bbox async fn search_leaf( &self, @@ -247,10 +275,6 @@ impl GeoIndex { ) -> Result { let leaf_data = self.load_leaf(leaf, metrics).await?; - let file_id = leaf.file_id; - let row_offset = leaf.row_offset; - let num_rows = leaf.num_rows; - // Filter points within this leaf let mut row_ids = 
RowIdTreeMap::new(); let x_array = leaf_data @@ -295,6 +319,7 @@ impl Index for GeoIndex { } async fn prewarm(&self) -> Result<()> { + // No-op: geo index uses lazy loading Ok(()) } @@ -436,7 +461,7 @@ pub struct GeoIndexBuilder { options: GeoIndexBuilderParams, items_type: DataType, // Accumulated points: (x, y, row_id) - points: Vec<(f64, f64, u64)>, + pub points: Vec<(f64, f64, u64)>, } impl GeoIndexBuilder { @@ -808,14 +833,9 @@ mod tests { use super::*; use std::sync::Arc; - use arrow_array::RecordBatch; - use arrow_schema::{DataType, Field, Fields, Schema}; - use datafusion::execution::SendableRecordBatchStream; - use datafusion::physical_plan::stream::RecordBatchStreamAdapter; - use futures::stream; + use arrow_schema::{DataType, Field}; use lance_core::cache::LanceCache; use lance_core::utils::tempfile::TempObjDir; - use lance_core::ROW_ADDR; use lance_io::object_store::ObjectStore; use crate::scalar::lance_format::LanceIndexStore; @@ -891,7 +911,7 @@ mod tests { visited: &mut Vec, leaf_count: &mut u32, parent_bounds: Option<[f64; 4]>, - parent_split: Option<(u8, f64, bool)>, // (split_dim, split_value, is_left_child) from parent + _parent_split: Option<(u8, f64, bool)>, // (split_dim, split_value, is_left_child) from parent metrics: &dyn MetricsCollector, ) -> Result<()> { let node_idx = node_id as usize; @@ -1036,7 +1056,7 @@ mod tests { } #[tokio::test] - async fn test_geo_index_with_custom_max_points_per_leaf() { + async fn test_geo_intersects_with_custom_max_points_per_leaf() { use rand::{Rng, SeedableRng}; use rand::rngs::StdRng; @@ -1115,7 +1135,7 @@ mod tests { } #[tokio::test] - async fn test_geo_index_query_correctness_various_configs() { + async fn test_geo_intersects_query_correctness_various_configs() { use crate::metrics::NoOpMetricsCollector; // Test query correctness with different configurations @@ -1183,7 +1203,7 @@ mod tests { } #[tokio::test] - async fn test_geo_index_single_leaf() { + async fn test_geo_intersects_single_leaf() { // Edge 
case: all points fit in single leaf let test_store = create_test_store(); @@ -1225,7 +1245,7 @@ mod tests { } #[tokio::test] - async fn test_geo_index_many_small_leaves() { + async fn test_geo_intersects_many_small_leaves() { // Stress test: many small leaves, test file grouping let test_store = create_test_store(); @@ -1429,10 +1449,8 @@ mod tests { } #[tokio::test] - async fn test_geo_index_lazy_loading() { + async fn test_geo_intersects_lazy_loading() { use crate::metrics::NoOpMetricsCollector; - use std::sync::atomic::{AtomicUsize, Ordering}; - use std::sync::Arc; // Test that leaves are loaded lazily (not all at once) let test_store = create_test_store(); @@ -1524,7 +1542,7 @@ mod tests { #[tokio::test] #[ignore] // Expensive test - run with: cargo test -- --ignored - async fn test_geo_index_large_scale_lazy_loading() { + async fn test_geo_intersects_large_scale_lazy_loading() { use rand::{Rng, SeedableRng}; use rand::rngs::StdRng; use std::sync::atomic::{AtomicUsize, Ordering}; @@ -1568,7 +1586,6 @@ mod tests { // Run 100 random queries with load tracking let metrics = LoadTracker { part_loads: AtomicUsize::new(0) }; - let mut total_results = 0; let mut total_leaves_touched = 0; for _i in 0..100 { @@ -1588,15 +1605,7 @@ mod tests { total_leaves_touched += intersecting_leaves.len(); // Execute query - let result = index.search(&query, &metrics).await.unwrap(); - - match result { - crate::scalar::SearchResult::Exact(row_ids) => { - let count = row_ids.len().unwrap_or(0); - total_results += count; - } - _ => panic!("Expected Exact search result"), - } + let _result = index.search(&query, &metrics).await.unwrap(); } let total_io_ops = metrics.part_loads.load(Ordering::Relaxed); @@ -1673,5 +1682,359 @@ mod tests { } } } + + #[tokio::test] + async fn test_geo_index_duplicate_coordinates_different_row_ids() { + // Test that duplicate coordinates with different row_ids are all stored correctly + use crate::metrics::NoOpMetricsCollector; + + let test_store = 
create_test_store(); + + let params = GeoIndexBuilderParams { + max_points_per_leaf: 10, + batches_per_file: 3, + }; + + let mut builder = GeoIndexBuilder::try_new(params, DataType::Struct(Vec::::new().into())) + .unwrap(); + + // Create multiple points at the same coordinates with different row_ids + // This simulates real-world scenarios where multiple records have the same location + for i in 0..20 { + // 5 points at location (10.0, 10.0) with different row_ids + builder.points.push((10.0, 10.0, i as u64)); + } + + // Add some other points at different locations + for i in 20..50 { + builder.points.push((i as f64, i as f64, i as u64)); + } + + builder.write_index(test_store.as_ref()).await.unwrap(); + let index = GeoIndex::load(test_store.clone(), None, &LanceCache::no_cache()) + .await + .unwrap(); + + // Query the region containing the duplicates + let query = GeoQuery::Intersects(9.0, 9.0, 11.0, 11.0); + let metrics = NoOpMetricsCollector {}; + let result = index.search(&query, &metrics).await.unwrap(); + + match result { + crate::scalar::SearchResult::Exact(row_ids) => { + // Should find all 20 points at (10.0, 10.0) + assert_eq!(row_ids.len().unwrap_or(0), 20, + "Expected 20 duplicate coordinate entries"); + + // Verify all row_ids 0..19 are present + for i in 0..20 { + assert!(row_ids.contains(i as u64), + "Missing row_id {} for duplicate coordinate", i); + } + } + _ => panic!("Expected Exact search result"), + } + + // Verify data integrity: all 50 points should be stored + let mut all_row_ids = Vec::new(); + for leaf in index.bkd_tree.nodes.iter().filter_map(|n| n.as_leaf()) { + let leaf_data = index.load_leaf(leaf, &metrics).await.unwrap(); + let row_id_array = leaf_data + .column(2) + .as_primitive::(); + + for i in 0..leaf_data.num_rows() { + all_row_ids.push(row_id_array.value(i)); + } + } + + assert_eq!(all_row_ids.len(), 50, "Should store all 50 points including duplicates"); + + let unique_row_ids: std::collections::HashSet = 
all_row_ids.iter().copied().collect(); + assert_eq!(unique_row_ids.len(), 50, "All 50 row_ids should be unique"); + } + + #[tokio::test] + async fn test_geo_index_many_duplicates_at_same_location() { + // Test with many duplicate coordinates to ensure they don't cause issues + use crate::metrics::NoOpMetricsCollector; + + let test_store = create_test_store(); + + let params = GeoIndexBuilderParams { + max_points_per_leaf: 20, + batches_per_file: 5, + }; + + let mut builder = GeoIndexBuilder::try_new(params, DataType::Struct(Vec::::new().into())) + .unwrap(); + + // Create 100 points all at the exact same location + for i in 0..100 { + builder.points.push((50.0, 50.0, i as u64)); + } + + builder.write_index(test_store.as_ref()).await.unwrap(); + let index = GeoIndex::load(test_store.clone(), None, &LanceCache::no_cache()) + .await + .unwrap(); + + // Validate tree structure + validate_bkd_tree(&index).await + .expect("BKD tree validation failed with many duplicates"); + + // Query should return all 100 points + let query = GeoQuery::Intersects(49.0, 49.0, 51.0, 51.0); + let metrics = NoOpMetricsCollector {}; + let result = index.search(&query, &metrics).await.unwrap(); + + match result { + crate::scalar::SearchResult::Exact(row_ids) => { + assert_eq!(row_ids.len().unwrap_or(0), 100, + "Expected all 100 duplicate coordinate entries"); + + // Verify all row_ids are present + for i in 0..100 { + assert!(row_ids.contains(i as u64), + "Missing row_id {} in duplicate set", i); + } + } + _ => panic!("Expected Exact search result"), + } + } + + #[tokio::test] + async fn test_geo_index_duplicates_across_multiple_locations() { + // Test duplicates at multiple different locations + use crate::metrics::NoOpMetricsCollector; + + let test_store = create_test_store(); + + let params = GeoIndexBuilderParams { + max_points_per_leaf: 15, + batches_per_file: 3, + }; + + let mut builder = GeoIndexBuilder::try_new(params, DataType::Struct(Vec::::new().into())) + .unwrap(); + + // Create 
clusters of duplicates at different locations + let locations = vec![ + (10.0, 10.0), + (20.0, 20.0), + (30.0, 30.0), + (40.0, 40.0), + ]; + + let mut row_id = 0u64; + for (x, y) in &locations { + // 10 duplicates at each location + for _ in 0..10 { + builder.points.push((*x, *y, row_id)); + row_id += 1; + } + } + + // Add some unique points + for i in 0..20 { + builder.points.push((i as f64, i as f64 + 50.0, row_id)); + row_id += 1; + } + + let total_points = row_id; + + builder.write_index(test_store.as_ref()).await.unwrap(); + let index = GeoIndex::load(test_store.clone(), None, &LanceCache::no_cache()) + .await + .unwrap(); + + // Validate tree structure + validate_bkd_tree(&index).await + .expect("BKD tree validation failed with duplicates at multiple locations"); + + let metrics = NoOpMetricsCollector {}; + + // Query each location cluster + for (i, (x, y)) in locations.iter().enumerate() { + let query = GeoQuery::Intersects(x - 1.0, y - 1.0, x + 1.0, y + 1.0); + let result = index.search(&query, &metrics).await.unwrap(); + + match result { + crate::scalar::SearchResult::Exact(row_ids) => { + assert_eq!(row_ids.len().unwrap_or(0), 10, + "Expected 10 duplicates at location {}", i); + + // Verify the correct row_ids for this cluster + let expected_start = (i * 10) as u64; + let expected_end = expected_start + 10; + for expected_id in expected_start..expected_end { + assert!(row_ids.contains(expected_id), + "Missing row_id {} in cluster at ({}, {})", expected_id, x, y); + } + } + _ => panic!("Expected Exact search result"), + } + } + + // Verify total count + let mut all_row_ids = Vec::new(); + for leaf in index.bkd_tree.nodes.iter().filter_map(|n| n.as_leaf()) { + let leaf_data = index.load_leaf(leaf, &metrics).await.unwrap(); + let row_id_array = leaf_data + .column(2) + .as_primitive::(); + + for i in 0..leaf_data.num_rows() { + all_row_ids.push(row_id_array.value(i)); + } + } + + assert_eq!(all_row_ids.len(), total_points as usize, + "Should store all {} 
points including duplicates", total_points); + } + + #[tokio::test] + async fn test_geo_index_duplicate_handling_with_leaf_splits() { + // Test that duplicates are handled correctly when they cause leaf splits + use crate::metrics::NoOpMetricsCollector; + + let test_store = create_test_store(); + + let params = GeoIndexBuilderParams { + max_points_per_leaf: 10, // Small leaf size to force splits + batches_per_file: 3, + }; + + let mut builder = GeoIndexBuilder::try_new(params, DataType::Struct(Vec::::new().into())) + .unwrap(); + + // Create 50 points at the same location - should span multiple leaves + for i in 0..50 { + builder.points.push((25.0, 25.0, i as u64)); + } + + builder.write_index(test_store.as_ref()).await.unwrap(); + let index = GeoIndex::load(test_store.clone(), None, &LanceCache::no_cache()) + .await + .unwrap(); + + // Validate tree structure + validate_bkd_tree(&index).await + .expect("BKD tree validation failed with duplicate-induced splits"); + + // Query should return all 50 points + let query = GeoQuery::Intersects(24.0, 24.0, 26.0, 26.0); + let metrics = NoOpMetricsCollector {}; + let result = index.search(&query, &metrics).await.unwrap(); + + match result { + crate::scalar::SearchResult::Exact(row_ids) => { + assert_eq!(row_ids.len().unwrap_or(0), 50, + "Expected all 50 duplicate entries after leaf splits"); + + // Verify all row_ids are present + for i in 0..50 { + assert!(row_ids.contains(i as u64), + "Missing row_id {} after leaf split", i); + } + } + _ => panic!("Expected Exact search result"), + } + } + + #[tokio::test] + async fn test_geo_index_mixed_duplicates_and_unique_points() { + // Realistic test: mix of unique points and duplicates + use crate::metrics::NoOpMetricsCollector; + use rand::{Rng, SeedableRng}; + use rand::rngs::StdRng; + + let test_store = create_test_store(); + + let params = GeoIndexBuilderParams { + max_points_per_leaf: 20, + batches_per_file: 5, + }; + + let mut builder = GeoIndexBuilder::try_new(params, 
DataType::Struct(Vec::::new().into())) + .unwrap(); + + let mut rng = StdRng::seed_from_u64(123); + let mut row_id = 0u64; + + // Add 100 random unique points + for _ in 0..100 { + let x = rng.random_range(0.0..100.0); + let y = rng.random_range(0.0..100.0); + builder.points.push((x, y, row_id)); + row_id += 1; + } + + // Add 20 duplicates at a specific location + for _ in 0..20 { + builder.points.push((50.0, 50.0, row_id)); + row_id += 1; + } + + // Add 50 more random unique points + for _ in 0..50 { + let x = rng.random_range(0.0..100.0); + let y = rng.random_range(0.0..100.0); + builder.points.push((x, y, row_id)); + row_id += 1; + } + + let total_points = 170; + + builder.write_index(test_store.as_ref()).await.unwrap(); + let index = GeoIndex::load(test_store.clone(), None, &LanceCache::no_cache()) + .await + .unwrap(); + + // Validate tree structure + validate_bkd_tree(&index).await + .expect("BKD tree validation failed with mixed duplicates and unique points"); + + let metrics = NoOpMetricsCollector {}; + + // Query the duplicate cluster + let query = GeoQuery::Intersects(49.0, 49.0, 51.0, 51.0); + let result = index.search(&query, &metrics).await.unwrap(); + + match result { + crate::scalar::SearchResult::Exact(row_ids) => { + assert_eq!(row_ids.len().unwrap_or(0), 20, + "Expected 20 duplicate entries at (50, 50)"); + + // Verify the duplicate row_ids (100..119) + for expected_id in 100..120 { + assert!(row_ids.contains(expected_id as u64), + "Missing row_id {} in duplicate cluster", expected_id); + } + } + _ => panic!("Expected Exact search result"), + } + + // Verify total count + let mut all_row_ids = Vec::new(); + for leaf in index.bkd_tree.nodes.iter().filter_map(|n| n.as_leaf()) { + let leaf_data = index.load_leaf(leaf, &metrics).await.unwrap(); + let row_id_array = leaf_data + .column(2) + .as_primitive::(); + + for i in 0..leaf_data.num_rows() { + all_row_ids.push(row_id_array.value(i)); + } + } + + assert_eq!(all_row_ids.len(), total_points, + 
"Should store all {} points", total_points); + + // Verify all row_ids are unique (no double-counting) + let unique_row_ids: std::collections::HashSet = all_row_ids.iter().copied().collect(); + assert_eq!(unique_row_ids.len(), total_points, + "All row_ids should be unique"); + } } From e16b53a7de2c659101a180817862fb08a767929d Mon Sep 17 00:00:00 2001 From: jaystarshot Date: Thu, 16 Oct 2025 12:26:36 -0700 Subject: [PATCH 6/7] Add some edge case testing --- rust/lance-index/src/scalar/geo/bkd.rs | 6 +- rust/lance-index/src/scalar/geo/geoindex.rs | 391 +++++++++++++++++--- 2 files changed, 340 insertions(+), 57 deletions(-) diff --git a/rust/lance-index/src/scalar/geo/bkd.rs b/rust/lance-index/src/scalar/geo/bkd.rs index ea5807c1806..096324732ef 100644 --- a/rust/lance-index/src/scalar/geo/bkd.rs +++ b/rust/lance-index/src/scalar/geo/bkd.rs @@ -151,7 +151,7 @@ impl BKDTreeLookup { pub fn find_intersecting_leaves(&self, query_bbox: [f64; 4]) -> Result> { let mut leaves = Vec::new(); let mut stack = vec![self.root_id]; - let mut nodes_visited = 0; + let mut _nodes_visited = 0; while let Some(node_id) = stack.pop() { if node_id as usize >= self.nodes.len() { @@ -159,7 +159,7 @@ impl BKDTreeLookup { } let node = &self.nodes[node_id as usize]; - nodes_visited += 1; + _nodes_visited += 1; // Check if node's bounding box intersects with query bbox let intersects = bboxes_intersect(&node.bounds(), &query_bbox); @@ -388,7 +388,7 @@ impl BKDTreeBuilder { // Base case: create leaf node if points.len() <= self.leaf_size { let node_id = all_nodes.len() as u32; - let leaf_id = *leaf_counter; + let _leaf_id = *leaf_counter; *leaf_counter += 1; // Calculate bounding box for this leaf diff --git a/rust/lance-index/src/scalar/geo/geoindex.rs b/rust/lance-index/src/scalar/geo/geoindex.rs index f432f37529c..d860ebc609a 100644 --- a/rust/lance-index/src/scalar/geo/geoindex.rs +++ b/rust/lance-index/src/scalar/geo/geoindex.rs @@ -39,7 +39,7 @@ use arrow_schema::{DataType, Field, Schema}; 
use datafusion::execution::SendableRecordBatchStream; use std::{collections::HashMap, sync::Arc}; -use crate::scalar::{AnyQuery, IndexReader, IndexStore, MetricsCollector, ScalarIndex, SearchResult}; +use crate::scalar::{AnyQuery, IndexStore, MetricsCollector, ScalarIndex, SearchResult}; use crate::scalar::FragReuseIndex; use crate::vector::VectorIndex; use crate::{Index, IndexType}; @@ -65,33 +65,6 @@ fn leaf_group_filename(group_id: u32) -> String { format!("{}{}.lance", LEAF_GROUP_PREFIX, group_id) } -/// Lazy reader for BKD leaf file -#[derive(Clone)] -struct LazyIndexReader { - index_reader: Arc>>>, - store: Arc, - filename: String, -} - -impl LazyIndexReader { - fn new(store: Arc, filename: &str) -> Self { - Self { - index_reader: Arc::new(tokio::sync::Mutex::new(None)), - store, - filename: filename.to_string(), - } - } - - async fn get(&self) -> Result> { - let mut reader = self.index_reader.lock().await; - if reader.is_none() { - let r = self.store.open_index_file(&self.filename).await?; - *reader = Some(r); - } - Ok(reader.as_ref().unwrap().clone()) - } -} - /// Cache key for BKD leaf nodes #[derive(Debug, Clone)] struct BKDLeafKey { @@ -243,7 +216,7 @@ impl GeoIndex { self.bkd_tree.num_leaves as usize } - /// Search all leaves without using BKD tree pruning (useful for benchmarking) + /// Search all leaves without using BKD tree pruning (only useful for benchmarking) pub async fn search_all_leaves( &self, query_bbox: [f64; 4], @@ -275,29 +248,28 @@ impl GeoIndex { ) -> Result { let leaf_data = self.load_leaf(leaf, metrics).await?; - // Filter points within this leaf - let mut row_ids = RowIdTreeMap::new(); - let x_array = leaf_data - .column(0) - .as_primitive::(); - let y_array = leaf_data - .column(1) - .as_primitive::(); - let row_id_array = leaf_data - .column(2) - .as_primitive::(); - - for i in 0..leaf_data.num_rows() { - let x = x_array.value(i); - let y = y_array.value(i); - let row_id = row_id_array.value(i); - - if point_in_bbox(x, y, 
&query_bbox) { - row_ids.insert(row_id); - } - } + // Filter points within this leaf using iterators + let x_array = leaf_data.column(0).as_primitive::(); + let y_array = leaf_data.column(1).as_primitive::(); + let row_id_array = leaf_data.column(2).as_primitive::(); + + let row_ids: Vec = x_array + .iter() + .zip(y_array.iter()) + .zip(row_id_array.iter()) + .filter_map(|((x_opt, y_opt), row_id_opt)| { + match (x_opt, y_opt, row_id_opt) { + (Some(x), Some(y), Some(row_id)) if point_in_bbox(x, y, &query_bbox) => { + Some(row_id) + } + _ => None, + } + }) + .collect(); - Ok(row_ids) + let mut row_id_map = RowIdTreeMap::new(); + row_id_map.extend(row_ids); + Ok(row_id_map) } } @@ -459,7 +431,7 @@ impl GeoIndexBuilderParams { // A builder for geo index pub struct GeoIndexBuilder { options: GeoIndexBuilderParams, - items_type: DataType, + _items_type: DataType, // Accumulated points: (x, y, row_id) pub points: Vec<(f64, f64, u64)>, } @@ -468,7 +440,7 @@ impl GeoIndexBuilder { pub fn try_new(options: GeoIndexBuilderParams, items_type: DataType) -> Result { Ok(Self { options, - items_type, + _items_type: items_type, points: Vec::new(), }) } @@ -516,6 +488,22 @@ impl GeoIndexBuilder { return Ok(()); } + // Validate coordinates (reject NaN and Infinity like Lucene does) + for (x, y, row_id) in &self.points { + if x.is_nan() || y.is_nan() { + return Err(Error::InvalidInput { + source: format!("Cannot index NaN coordinates (row_id={})", row_id).into(), + location: location!(), + }); + } + if x.is_infinite() || y.is_infinite() { + return Err(Error::InvalidInput { + source: format!("Cannot index Infinite coordinates (row_id={})", row_id).into(), + location: location!(), + }); + } + } + // Build BKD tree let (tree_nodes, leaf_batches) = self.build_bkd_tree()?; @@ -1541,7 +1529,302 @@ mod tests { } #[tokio::test] - #[ignore] // Expensive test - run with: cargo test -- --ignored + async fn test_geo_intersects_invalid_coordinates() { + // Test that NaN and infinity coordinates are 
rejected (like Lucene does) + let test_store = create_test_store(); + + let params = GeoIndexBuilderParams { + max_points_per_leaf: 10, + batches_per_file: 5, + }; + + // Test NaN in X coordinate + let mut builder = GeoIndexBuilder::try_new( + params.clone(), + DataType::Struct(Vec::::new().into()), + ) + .unwrap(); + builder.points.push((10.0, 10.0, 1)); + builder.points.push((f64::NAN, 20.0, 2)); + let result = builder.write_index(test_store.as_ref()).await; + assert!(result.is_err()); + assert!(result.unwrap_err().to_string().contains("NaN")); + + // Test NaN in Y coordinate + let mut builder = GeoIndexBuilder::try_new( + params.clone(), + DataType::Struct(Vec::::new().into()), + ) + .unwrap(); + builder.points.push((10.0, 10.0, 1)); + builder.points.push((20.0, f64::NAN, 2)); + let result = builder.write_index(test_store.as_ref()).await; + assert!(result.is_err()); + assert!(result.unwrap_err().to_string().contains("NaN")); + + // Test Infinity in X coordinate + let mut builder = GeoIndexBuilder::try_new( + params.clone(), + DataType::Struct(Vec::::new().into()), + ) + .unwrap(); + builder.points.push((10.0, 10.0, 1)); + builder.points.push((f64::INFINITY, 20.0, 2)); + let result = builder.write_index(test_store.as_ref()).await; + assert!(result.is_err()); + assert!(result.unwrap_err().to_string().contains("Infinite")); + + // Test Negative Infinity in Y coordinate + let mut builder = GeoIndexBuilder::try_new( + params.clone(), + DataType::Struct(Vec::::new().into()), + ) + .unwrap(); + builder.points.push((10.0, 10.0, 1)); + builder.points.push((20.0, f64::NEG_INFINITY, 2)); + let result = builder.write_index(test_store.as_ref()).await; + assert!(result.is_err()); + assert!(result.unwrap_err().to_string().contains("Infinite")); + } + + #[tokio::test] + async fn test_geo_intersects_out_of_bounds_queries() { + use crate::metrics::NoOpMetricsCollector; + + // Test queries completely outside the data bounds + let test_store = create_test_store(); + + let params = 
GeoIndexBuilderParams { + max_points_per_leaf: 10, + batches_per_file: 5, + }; + + let mut builder = GeoIndexBuilder::try_new( + params, + DataType::Struct(Vec::::new().into()), + ) + .unwrap(); + + // Add points in region 0-100, 0-100 + for i in 0..50 { + builder.points.push((i as f64, i as f64, i as u64)); + } + + builder.write_index(test_store.as_ref()).await.unwrap(); + + let index = GeoIndex::load(test_store.clone(), None, &LanceCache::no_cache()) + .await + .unwrap(); + + let metrics = NoOpMetricsCollector {}; + + // Query completely out of bounds (far away) + let query = GeoQuery::Intersects(1000.0, 1000.0, 2000.0, 2000.0); + let result = index.search(&query, &metrics).await.unwrap(); + match result { + crate::scalar::SearchResult::Exact(row_ids) => { + let count = row_ids.len().unwrap_or(0) as usize; + assert_eq!(count, 0); + } + _ => panic!("Expected Exact search result"), + } + + // Query in negative space (if data is all positive) + let query = GeoQuery::Intersects(-500.0, -500.0, -100.0, -100.0); + let result = index.search(&query, &metrics).await.unwrap(); + match result { + crate::scalar::SearchResult::Exact(row_ids) => { + let count = row_ids.len().unwrap_or(0) as usize; + assert_eq!(count, 0); + } + _ => panic!("Expected Exact search result"), + } + } + + #[tokio::test] + async fn test_geo_intersects_extreme_coordinates() { + use crate::metrics::NoOpMetricsCollector; + + // Test with extremely large and small (but valid) coordinates + let test_store = create_test_store(); + + let params = GeoIndexBuilderParams { + max_points_per_leaf: 10, + batches_per_file: 5, + }; + + let mut builder = GeoIndexBuilder::try_new( + params, + DataType::Struct(Vec::::new().into()), + ) + .unwrap(); + + // Add points with extreme coordinates + builder.points.push((1e10, 1e10, 1)); + builder.points.push((-1e10, -1e10, 2)); + builder.points.push((1e-10, 1e-10, 3)); + builder.points.push((-1e-10, -1e-10, 4)); + builder.points.push((f64::MAX / 2.0, f64::MAX / 2.0, 5)); + 
builder.points.push((f64::MIN / 2.0, f64::MIN / 2.0, 6)); + + builder.write_index(test_store.as_ref()).await.unwrap(); + + let index = GeoIndex::load(test_store.clone(), None, &LanceCache::no_cache()) + .await + .unwrap(); + + let metrics = NoOpMetricsCollector {}; + + // Query for large positive values + let query = GeoQuery::Intersects(1e10 - 1.0, 1e10 - 1.0, 1e10 + 1.0, 1e10 + 1.0); + let result = index.search(&query, &metrics).await.unwrap(); + match result { + crate::scalar::SearchResult::Exact(row_ids) => { + let actual_ids: Vec = row_ids.row_ids() + .map(|iter| iter.map(|addr| u64::from(addr)).collect()) + .unwrap_or_default(); + assert!(actual_ids.contains(&1)); + } + _ => panic!("Expected Exact search result"), + } + + // Query for large negative values + let query = GeoQuery::Intersects(-1e10 - 1.0, -1e10 - 1.0, -1e10 + 1.0, -1e10 + 1.0); + let result = index.search(&query, &metrics).await.unwrap(); + match result { + crate::scalar::SearchResult::Exact(row_ids) => { + let actual_ids: Vec = row_ids.row_ids() + .map(|iter| iter.map(|addr| u64::from(addr)).collect()) + .unwrap_or_default(); + assert!(actual_ids.contains(&2)); + } + _ => panic!("Expected Exact search result"), + } + } + + #[tokio::test] + async fn test_geo_intersects_huge_query_bbox() { + use crate::metrics::NoOpMetricsCollector; + + // Test query larger than all data bounds + let test_store = create_test_store(); + + let params = GeoIndexBuilderParams { + max_points_per_leaf: 10, + batches_per_file: 5, + }; + + let mut builder = GeoIndexBuilder::try_new( + params, + DataType::Struct(Vec::::new().into()), + ) + .unwrap(); + + // Add points in small region + for i in 0..20 { + builder.points.push((i as f64, i as f64, i as u64)); + } + + builder.write_index(test_store.as_ref()).await.unwrap(); + + let index = GeoIndex::load(test_store.clone(), None, &LanceCache::no_cache()) + .await + .unwrap(); + + let metrics = NoOpMetricsCollector {}; + + // Query that encompasses ALL data + let query = 
GeoQuery::Intersects(-1000.0, -1000.0, 1000.0, 1000.0); + let result = index.search(&query, &metrics).await.unwrap(); + match result { + crate::scalar::SearchResult::Exact(row_ids) => { + let count = row_ids.len().unwrap_or(0) as usize; + // Should find all 20 points + assert_eq!(count, 20); + } + _ => panic!("Expected Exact search result"), + } + } + + #[tokio::test] + async fn test_geo_intersects_zero_size_query() { + use crate::metrics::NoOpMetricsCollector; + + // Test query box with zero width/height (point query) + let test_store = create_test_store(); + + let params = GeoIndexBuilderParams { + max_points_per_leaf: 10, + batches_per_file: 5, + }; + + let mut builder = GeoIndexBuilder::try_new( + params, + DataType::Struct(Vec::::new().into()), + ) + .unwrap(); + + // Add some points + builder.points.push((10.0, 10.0, 1)); + builder.points.push((10.0, 10.0, 2)); // Duplicate at same location + builder.points.push((20.0, 20.0, 3)); + builder.points.push((30.0, 30.0, 4)); + + builder.write_index(test_store.as_ref()).await.unwrap(); + + let index = GeoIndex::load(test_store.clone(), None, &LanceCache::no_cache()) + .await + .unwrap(); + + let metrics = NoOpMetricsCollector {}; + + // Zero-width query (line) + let query = GeoQuery::Intersects(10.0, 10.0, 10.0, 20.0); + let result = index.search(&query, &metrics).await.unwrap(); + match result { + crate::scalar::SearchResult::Exact(row_ids) => { + let actual_ids: Vec = row_ids.row_ids() + .map(|iter| iter.map(|addr| u64::from(addr)).collect()) + .unwrap_or_default(); + // Should find points at (10, 10) + assert!(actual_ids.contains(&1)); + assert!(actual_ids.contains(&2)); + } + _ => panic!("Expected Exact search result"), + } + + // Zero-height query (line) + let query = GeoQuery::Intersects(10.0, 10.0, 20.0, 10.0); + let result = index.search(&query, &metrics).await.unwrap(); + match result { + crate::scalar::SearchResult::Exact(row_ids) => { + let actual_ids: Vec = row_ids.row_ids() + .map(|iter| iter.map(|addr| 
u64::from(addr)).collect()) + .unwrap_or_default(); + // Should find points at (10, 10) + assert!(actual_ids.contains(&1)); + assert!(actual_ids.contains(&2)); + } + _ => panic!("Expected Exact search result"), + } + + // Point query (zero width and height) + let query = GeoQuery::Intersects(20.0, 20.0, 20.0, 20.0); + let result = index.search(&query, &metrics).await.unwrap(); + match result { + crate::scalar::SearchResult::Exact(row_ids) => { + let actual_ids: Vec = row_ids.row_ids() + .map(|iter| iter.map(|addr| u64::from(addr)).collect()) + .unwrap_or_default(); + // Should find point at (20, 20) + assert_eq!(actual_ids.len(), 1); + assert!(actual_ids.contains(&3)); + } + _ => panic!("Expected Exact search result"), + } + } + + #[tokio::test] async fn test_geo_intersects_large_scale_lazy_loading() { use rand::{Rng, SeedableRng}; use rand::rngs::StdRng; From 2ac6dca33eb94cd7ec84484c8f302750aa4515d6 Mon Sep 17 00:00:00 2001 From: Jay Narale Date: Mon, 20 Oct 2025 10:22:02 -0700 Subject: [PATCH 7/7] change naming --- protos/index.proto | 2 +- protos/index_old.proto | 1 - python/python/lance/dataset.py | 8 +- python/python/tests/test_bkdtree_index.py | 288 +++++ python/python/tests/test_optimize.py | 2 +- python/src/dataset.rs | 6 +- rust/lance-datafusion/src/udf.rs | 77 +- rust/lance-index/benches/geoindex.rs | 106 +- rust/lance-index/src/lib.rs | 10 +- rust/lance-index/src/scalar.rs | 11 +- rust/lance-index/src/scalar/geo/bkd.rs | 132 +- .../scalar/geo/{geoindex.rs => bkdtree.rs} | 1066 ++++++++++------- rust/lance-index/src/scalar/geo/mod.rs | 9 +- rust/lance-index/src/scalar/registry.rs | 4 +- rust/lance/src/index/create.rs | 2 +- test_geoarrow_geo_index.py | 301 ----- 16 files changed, 1116 insertions(+), 909 deletions(-) create mode 100644 python/python/tests/test_bkdtree_index.py rename rust/lance-index/src/scalar/geo/{geoindex.rs => bkdtree.rs} (76%) delete mode 100644 test_geoarrow_geo_index.py diff --git a/protos/index.proto b/protos/index.proto index 
837e54abdcc..9e8f6ca9b7d 100644 --- a/protos/index.proto +++ b/protos/index.proto @@ -190,4 +190,4 @@ message JsonIndexDetails { } message BloomFilterIndexDetails {} -message GeoIndexDetails {} \ No newline at end of file +message BkdTreeIndexDetails {} \ No newline at end of file diff --git a/protos/index_old.proto b/protos/index_old.proto index 5931f911380..601aa2681da 100644 --- a/protos/index_old.proto +++ b/protos/index_old.proto @@ -25,7 +25,6 @@ message BitmapIndexDetails {} message LabelListIndexDetails {} message NGramIndexDetails {} message ZoneMapIndexDetails {} -message GeoIndexDetails {} message InvertedIndexDetails { // Marking this field as optional as old versions of the index store blank details and we // need to make sure we have a proper optional field to detect this. diff --git a/python/python/lance/dataset.py b/python/python/lance/dataset.py index b966d19fbfb..f39b48ac7e0 100644 --- a/python/python/lance/dataset.py +++ b/python/python/lance/dataset.py @@ -2345,7 +2345,7 @@ def create_scalar_index( "LABEL_LIST", "INVERTED", "BLOOMFILTER", - "GEO" + "BKDTREE", ]: raise NotImplementedError( ( @@ -2393,7 +2393,7 @@ def create_scalar_index( f"INVERTED index column {column} must be string, large string" " or list of strings, but got {value_type}" ) - elif index_type == "GEO": + elif index_type == "BKDTREE": # Accept struct for GeoArrow point data if pa.types.is_struct(field_type): field_names = [field.name for field in field_type] @@ -2402,12 +2402,12 @@ def create_scalar_index( pass else: raise TypeError( - f"GEO index column {column} must be a struct with x,y fields for point data. " + f"BKDTREE index column {column} must be a struct with x,y fields for point data. " f"Got struct with fields: {field_names}" ) else: raise TypeError( - f"GEO index column {column} must be a struct type. " + f"BKDTREE index column {column} must be a struct type. 
" f"Got field type: {field_type}" ) if pa.types.is_duration(field_type): diff --git a/python/python/tests/test_bkdtree_index.py b/python/python/tests/test_bkdtree_index.py new file mode 100644 index 00000000000..65e336bf820 --- /dev/null +++ b/python/python/tests/test_bkdtree_index.py @@ -0,0 +1,288 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright The Lance Authors + +""" +Tests for BKD Tree spatial index on GeoArrow Point data. + +This module tests: +1. Creating GeoArrow Point data +2. Writing to Lance dataset +3. Creating a BKD tree index on GeoArrow Point column +4. Querying with spatial filters +5. Verifying the index is used in query execution +""" + +import os + +import numpy as np +import pyarrow as pa +import pytest + +import lance + +geoarrow = pytest.importorskip("geoarrow.pyarrow") + + +@pytest.fixture +def geoarrow_data(): + """Create GeoArrow Point test data with known cities and random points.""" + np.random.seed(42) + num_points = 5000 + + # Generate random points across the US + # US bounding box: lng [-125, -65], lat [25, 50] + lng_vals = np.random.uniform(-125, -65, num_points) + lat_vals = np.random.uniform(25, 50, num_points) + + # Add known cities at the beginning for testing + known_cities = [ + { + "id": 1, + "city": "San Francisco", + "lng": -122.4194, + "lat": 37.7749, + "population": 883305, + }, + { + "id": 2, + "city": "Los Angeles", + "lng": -118.2437, + "lat": 34.0522, + "population": 3898747, + }, + { + "id": 3, + "city": "New York", + "lng": -74.0060, + "lat": 40.7128, + "population": 8336817, + }, + { + "id": 4, + "city": "Chicago", + "lng": -87.6298, + "lat": 41.8781, + "population": 2746388, + }, + { + "id": 5, + "city": "Houston", + "lng": -95.3698, + "lat": 29.7604, + "population": 2304580, + }, + ] + + # Replace first 5 points with known cities + for i, city in enumerate(known_cities): + lng_vals[i] = city["lng"] + lat_vals[i] = city["lat"] + + start_location = 
geoarrow.point().from_geobuffers(None, lng_vals, lat_vals) + + # Create IDs and city names + ids = list(range(1, num_points + 1)) + cities = [ + known_cities[i]["city"] if i < len(known_cities) else f"Point_{i + 1}" + for i in range(num_points) + ] + populations = [ + known_cities[i]["population"] + if i < len(known_cities) + else np.random.randint(10000, 1000000) + for i in range(num_points) + ] + + table = pa.table( + { + "id": ids, + "city": cities, + "start_location": start_location, + "population": populations, + } + ) + + return table, known_cities + + +def test_write_geoarrow_to_lance(tmp_path, geoarrow_data): + """Test writing GeoArrow Point data to Lance dataset.""" + table, _ = geoarrow_data + dataset_path = tmp_path / "geo_dataset" + + ds = lance.write_dataset(table, dataset_path) + + # Verify data was written correctly + loaded_table = ds.to_table() + assert len(loaded_table) == len(table) + assert loaded_table.schema.equals(table.schema) + + +def test_create_bkdtree_index(tmp_path, geoarrow_data): + """Test creating a BKD tree index on GeoArrow Point column.""" + table, _ = geoarrow_data + dataset_path = tmp_path / "geo_dataset" + + ds = lance.write_dataset(table, dataset_path) + + # Create BKD tree index + ds.create_scalar_index(column="start_location", index_type="BKDTREE") + + # Verify index was created + indexes = ds.list_indices() + assert len(indexes) > 0 + + # Check that index files exist + index_dir = dataset_path / "_indices" + assert index_dir.exists() + index_files = list(index_dir.rglob("*")) + assert len(index_files) > 0 + + +def test_spatial_query_broad_bbox(tmp_path, geoarrow_data): + """Test spatial query with broad bounding box covering multiple cities.""" + table, known_cities = geoarrow_data + dataset_path = tmp_path / "geo_dataset" + + ds = lance.write_dataset(table, dataset_path) + ds.create_scalar_index(column="start_location", index_type="BKDTREE") + + # Query with broad bbox covering western US + # Should include San Francisco 
and Los Angeles + sql = """ + SELECT id, city, population + FROM dataset + WHERE st_intersects(start_location, bbox(-125, 30, -115, 45)) + """ + + query = ds.sql(sql).build() + result = query.to_batch_records() + + assert result is not None + result_table = pa.Table.from_batches(result) + + # Should get many results with random points + assert len(result_table) > 100 + + # Should include known cities in the bbox + cities = result_table.column("city").to_pylist() + assert "San Francisco" in cities + assert "Los Angeles" in cities + + +def test_spatial_query_tight_bbox(tmp_path, geoarrow_data): + """Test spatial query with tight bounding box around single city.""" + table, known_cities = geoarrow_data + dataset_path = tmp_path / "geo_dataset" + + ds = lance.write_dataset(table, dataset_path) + ds.create_scalar_index(column="start_location", index_type="BKDTREE") + + # Query with tight bbox around San Francisco only + # SF is at (-122.4194, 37.7749) + sql = """ + SELECT id, city, population + FROM dataset + WHERE st_intersects(start_location, bbox(-123, 37, -122, 38)) + """ + + query = ds.sql(sql).build() + result = query.to_batch_records() + + assert result is not None + result_table = pa.Table.from_batches(result) + + # Should include San Francisco + cities = result_table.column("city").to_pylist() + assert "San Francisco" in cities + + # From known cities, should only include San Francisco + known_cities_in_result = [ + c + for c in cities + if c in ["San Francisco", "Los Angeles", "New York", "Chicago", "Houston"] + ] + assert known_cities_in_result == ["San Francisco"] + + +def test_spatial_query_uses_index(tmp_path, geoarrow_data): + """Test that spatial queries use the BKD tree index via EXPLAIN ANALYZE.""" + table, _ = geoarrow_data + dataset_path = tmp_path / "geo_dataset" + + ds = lance.write_dataset(table, dataset_path) + ds.create_scalar_index(column="start_location", index_type="BKDTREE") + + # Run EXPLAIN ANALYZE to verify index usage + explain_sql = """ 
+ EXPLAIN ANALYZE SELECT id, city, population + FROM dataset + WHERE st_intersects(start_location, bbox(-125, 30, -115, 45)) + """ + + query = ds.sql(explain_sql).build() + result = query.to_batch_records() + + assert result is not None + explain_table = pa.Table.from_batches(result) + assert len(explain_table) > 0 + + # Check if index was used in the execution plan + # The plan is in the second column + plan_text = str(explain_table.column(1).to_pylist()[0]) + + # Look for evidence of index usage + # (The exact string may vary, adjust based on actual output) + assert "ScalarIndexQuery" in plan_text or "start_location_idx" in plan_text, ( + f"Index not detected in execution plan: {plan_text}" + ) + + +def test_spatial_query_empty_result(tmp_path, geoarrow_data): + """Test spatial query with bbox that doesn't intersect any points.""" + table, _ = geoarrow_data + dataset_path = tmp_path / "geo_dataset" + + ds = lance.write_dataset(table, dataset_path) + ds.create_scalar_index(column="start_location", index_type="BKDTREE") + + # Query with bbox outside the US (e.g., over the Pacific Ocean) + sql = """ + SELECT id, city, population + FROM dataset + WHERE st_intersects(start_location, bbox(-180, -10, -175, -5)) + """ + + query = ds.sql(sql).build() + result = query.to_batch_records() + + # Should return empty result or very small result + if result: + result_table = pa.Table.from_batches(result) + assert len(result_table) == 0 + + +def test_index_file_structure(tmp_path, geoarrow_data): + """Test that BKD tree index creates expected file structure.""" + table, _ = geoarrow_data + dataset_path = tmp_path / "geo_dataset" + + ds = lance.write_dataset(table, dataset_path) + ds.create_scalar_index(column="start_location", index_type="BKDTREE") + + index_dir = dataset_path / "_indices" + assert index_dir.exists() + + # Check for index subdirectories + index_subdirs = [d for d in index_dir.iterdir() if d.is_dir()] + assert len(index_subdirs) > 0 + + # Check that index files 
exist and have content + for subdir in index_subdirs: + files = list(subdir.glob("*")) + assert len(files) > 0 + + # Verify files have content + for f in files: + if f.is_file(): + assert f.stat().st_size > 0 diff --git a/python/python/tests/test_optimize.py b/python/python/tests/test_optimize.py index 8bf12db91ae..e334111280e 100644 --- a/python/python/tests/test_optimize.py +++ b/python/python/tests/test_optimize.py @@ -1,5 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright The Lance Authors +# SPDX-FileCopyrightText: Copyright The Lance Authors import pickle import random import re diff --git a/python/src/dataset.rs b/python/src/dataset.rs index 64c027dc984..b4e3aa9816b 100644 --- a/python/src/dataset.rs +++ b/python/src/dataset.rs @@ -1601,7 +1601,7 @@ impl Dataset { "BLOOMFILTER" => IndexType::BloomFilter, "LABEL_LIST" => IndexType::LabelList, "INVERTED" | "FTS" => IndexType::Inverted, - "GEO" => IndexType::Geo, + "BKDTREE" => IndexType::BkdTree, "IVF_FLAT" | "IVF_PQ" | "IVF_SQ" | "IVF_RQ" | "IVF_HNSW_FLAT" | "IVF_HNSW_PQ" | "IVF_HNSW_SQ" => IndexType::Vector, _ => { @@ -1703,8 +1703,8 @@ impl Dataset { } Box::new(params) } - "GEO" => Box::new(ScalarIndexParams { - index_type: "geo".to_string(), + "BKDTREE" => Box::new(ScalarIndexParams { + index_type: "bkdtree".to_string(), params: None, }), _ => { diff --git a/rust/lance-datafusion/src/udf.rs b/rust/lance-datafusion/src/udf.rs index 6af1a81f362..1449cf32ef7 100644 --- a/rust/lance-datafusion/src/udf.rs +++ b/rust/lance-datafusion/src/udf.rs @@ -51,16 +51,22 @@ fn st_intersects() -> ScalarUDF { create_udf( "st_intersects", vec![ - DataType::Struct(vec![ - Arc::new(arrow_schema::Field::new("x", DataType::Float64, false)), - Arc::new(arrow_schema::Field::new("y", DataType::Float64, false)), - ].into()), - DataType::Struct(vec![ - Arc::new(arrow_schema::Field::new("xmin", DataType::Float64, false)), - Arc::new(arrow_schema::Field::new("ymin", DataType::Float64, false)), - 
Arc::new(arrow_schema::Field::new("xmax", DataType::Float64, false)), - Arc::new(arrow_schema::Field::new("ymax", DataType::Float64, false)), - ].into()) + DataType::Struct( + vec![ + Arc::new(arrow_schema::Field::new("x", DataType::Float64, false)), + Arc::new(arrow_schema::Field::new("y", DataType::Float64, false)), + ] + .into(), + ), + DataType::Struct( + vec![ + Arc::new(arrow_schema::Field::new("xmin", DataType::Float64, false)), + Arc::new(arrow_schema::Field::new("ymin", DataType::Float64, false)), + Arc::new(arrow_schema::Field::new("xmax", DataType::Float64, false)), + Arc::new(arrow_schema::Field::new("ymax", DataType::Float64, false)), + ] + .into(), + ), ], // GeoArrow Point struct, GeoArrow Box struct DataType::Boolean, Volatility::Immutable, @@ -68,7 +74,6 @@ fn st_intersects() -> ScalarUDF { ) } - fn st_within() -> ScalarUDF { let function = Arc::new(make_scalar_function( |_args: &[ArrayRef]| { @@ -83,16 +88,22 @@ fn st_within() -> ScalarUDF { create_udf( "st_within", vec![ - DataType::Struct(vec![ - Arc::new(arrow_schema::Field::new("x", DataType::Float64, false)), - Arc::new(arrow_schema::Field::new("y", DataType::Float64, false)), - ].into()), - DataType::Struct(vec![ - Arc::new(arrow_schema::Field::new("xmin", DataType::Float64, false)), - Arc::new(arrow_schema::Field::new("ymin", DataType::Float64, false)), - Arc::new(arrow_schema::Field::new("xmax", DataType::Float64, false)), - Arc::new(arrow_schema::Field::new("ymax", DataType::Float64, false)), - ].into()) + DataType::Struct( + vec![ + Arc::new(arrow_schema::Field::new("x", DataType::Float64, false)), + Arc::new(arrow_schema::Field::new("y", DataType::Float64, false)), + ] + .into(), + ), + DataType::Struct( + vec![ + Arc::new(arrow_schema::Field::new("xmin", DataType::Float64, false)), + Arc::new(arrow_schema::Field::new("ymin", DataType::Float64, false)), + Arc::new(arrow_schema::Field::new("xmax", DataType::Float64, false)), + Arc::new(arrow_schema::Field::new("ymax", DataType::Float64, 
false)), + ] + .into(), + ), ], // GeoArrow Point struct, GeoArrow Box struct DataType::Boolean, Volatility::Immutable, @@ -100,7 +111,6 @@ fn st_within() -> ScalarUDF { ) } - /// BBOX function that creates a bounding box from four numeric arguments. /// This function is used internally by spatial queries and doesn't perform actual computation. /// It's intercepted by Lance's geo query parser for index optimization. @@ -123,19 +133,26 @@ fn bbox() -> ScalarUDF { create_udf( "bbox", - vec![DataType::Float64, DataType::Float64, DataType::Float64, DataType::Float64], // min_x, min_y, max_x, max_y - DataType::Struct(vec![ - Arc::new(arrow_schema::Field::new("xmin", DataType::Float64, false)), - Arc::new(arrow_schema::Field::new("ymin", DataType::Float64, false)), - Arc::new(arrow_schema::Field::new("xmax", DataType::Float64, false)), - Arc::new(arrow_schema::Field::new("ymax", DataType::Float64, false)), - ].into()), // Returns a GeoArrow Box struct + vec![ + DataType::Float64, + DataType::Float64, + DataType::Float64, + DataType::Float64, + ], // min_x, min_y, max_x, max_y + DataType::Struct( + vec![ + Arc::new(arrow_schema::Field::new("xmin", DataType::Float64, false)), + Arc::new(arrow_schema::Field::new("ymin", DataType::Float64, false)), + Arc::new(arrow_schema::Field::new("xmax", DataType::Float64, false)), + Arc::new(arrow_schema::Field::new("ymax", DataType::Float64, false)), + ] + .into(), + ), // Returns a GeoArrow Box struct Volatility::Immutable, function, ) } - /// This method checks whether a string contains all specified tokens. The tokens are separated by /// punctuations and white spaces. 
/// diff --git a/rust/lance-index/benches/geoindex.rs b/rust/lance-index/benches/geoindex.rs index a5a85a9b6c5..a311fab4736 100644 --- a/rust/lance-index/benches/geoindex.rs +++ b/rust/lance-index/benches/geoindex.rs @@ -8,14 +8,14 @@ use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion}; use lance_core::cache::LanceCache; use lance_core::utils::tempfile::TempObjDir; use lance_io::object_store::ObjectStore; -use rand::{Rng, SeedableRng}; use rand::rngs::StdRng; +use rand::{Rng, SeedableRng}; use std::sync::Arc; +use lance_index::metrics::{MetricsCollector, NoOpMetricsCollector}; use lance_index::scalar::geo::geoindex::{GeoIndex, GeoIndexBuilder, GeoIndexBuilderParams}; use lance_index::scalar::lance_format::LanceIndexStore; use lance_index::scalar::{GeoQuery, ScalarIndex}; -use lance_index::metrics::{MetricsCollector, NoOpMetricsCollector}; use std::sync::atomic::{AtomicUsize, Ordering}; struct LeafCounter { @@ -28,11 +28,11 @@ impl LeafCounter { leaves_visited: AtomicUsize::new(0), } } - + fn get_count(&self) -> usize { self.leaves_visited.load(Ordering::Relaxed) } - + fn reset(&self) { self.leaves_visited.store(0, Ordering::Relaxed); } @@ -42,7 +42,7 @@ impl MetricsCollector for LeafCounter { fn record_parts_loaded(&self, num_parts: usize) { self.leaves_visited.fetch_add(num_parts, Ordering::Relaxed); } - + fn record_index_loads(&self, _num_indexes: usize) {} fn record_comparisons(&self, _num_comparisons: usize) {} } @@ -50,42 +50,41 @@ impl MetricsCollector for LeafCounter { fn create_test_points(num_points: usize) -> Vec<(f64, f64, u64)> { let mut rng = StdRng::seed_from_u64(42); let mut points = Vec::with_capacity(num_points); - + for i in 0..num_points { let x = rng.random_range(0.0..1000.0); let y = rng.random_range(0.0..1000.0); points.push((x, y, i as u64)); } - + points } -async fn create_geo_index(points: Vec<(f64, f64, u64)>) -> (Arc, Arc, TempObjDir) { +async fn create_geo_index( + points: Vec<(f64, f64, u64)>, +) -> (Arc, Arc, 
TempObjDir) { let tmpdir = TempObjDir::default(); let store = Arc::new(LanceIndexStore::new( Arc::new(ObjectStore::local()), tmpdir.clone(), Arc::new(LanceCache::no_cache()), )); - + let params = GeoIndexBuilderParams { max_points_per_leaf: 1000, batches_per_file: 10, }; - - let mut builder = GeoIndexBuilder::try_new( - params, - DataType::Struct(Vec::::new().into()), - ) - .unwrap(); - + + let mut builder = + GeoIndexBuilder::try_new(params, DataType::Struct(Vec::::new().into())).unwrap(); + builder.points = points; builder.write_index(store.as_ref()).await.unwrap(); - + let index = GeoIndex::load(store.clone(), None, &LanceCache::no_cache()) .await .unwrap(); - + (index, store, tmpdir) } @@ -93,18 +92,18 @@ fn bench_geo_index_intersects_pruning(c: &mut Criterion) { let mut group = c.benchmark_group("geo_intersects_pruning_vs_scan_all"); group.sample_size(10); // Reduce samples for faster benchmarks group.measurement_time(std::time::Duration::from_secs(60)); - + // Compares BKD tree spatial pruning vs scanning all leaves for intersects queries // Both approaches do lazy loading - we're measuring pruning efficiency - + // Test with different dataset sizes for num_points in [10_000, 100_000, 1_000_000] { let points = create_test_points(num_points); - + // Create index (do this once outside the benchmark) let rt = tokio::runtime::Runtime::new().unwrap(); let (index, _store, _tmpdir) = rt.block_on(create_geo_index(points.clone())); - + // Generate random queries let mut rng = StdRng::seed_from_u64(42); let queries: Vec<[f64; 4]> = (0..10) @@ -116,7 +115,7 @@ fn bench_geo_index_intersects_pruning(c: &mut Criterion) { [min_x, min_y, min_x + width, min_y + height] }) .collect(); - + // Benchmark: scan all leaves (no spatial pruning) let scan_all_counter = Arc::new(LeafCounter::new()); group.bench_with_input( @@ -128,13 +127,16 @@ fn bench_geo_index_intersects_pruning(c: &mut Criterion) { counter.reset(); rt.block_on(async { for query_bbox in queries.iter() { - let 
_result = index.search_all_leaves(*query_bbox, counter.as_ref()).await.unwrap(); + let _result = index + .search_all_leaves(*query_bbox, counter.as_ref()) + .await + .unwrap(); } }); }); }, ); - + // Benchmark: BKD tree with spatial pruning let pruned_counter = Arc::new(LeafCounter::new()); group.bench_with_input( @@ -146,37 +148,49 @@ fn bench_geo_index_intersects_pruning(c: &mut Criterion) { counter.reset(); rt.block_on(async { for query_bbox in queries.iter() { - let query = GeoQuery::Intersects(query_bbox[0], query_bbox[1], query_bbox[2], query_bbox[3]); + let query = GeoQuery::Intersects( + query_bbox[0], + query_bbox[1], + query_bbox[2], + query_bbox[3], + ); let _result = index.search(&query, counter.as_ref()).await.unwrap(); } }); }); }, ); - + // Print statistics with speedup estimate let scan_avg = scan_all_counter.get_count() as f64 / queries.len() as f64; let pruned_avg = pruned_counter.get_count() as f64 / queries.len() as f64; - + // Quick timing for speedup calculation let rt = tokio::runtime::Runtime::new().unwrap(); let scan_start = std::time::Instant::now(); for query_bbox in queries.iter() { rt.block_on(async { - let _ = index.search_all_leaves(*query_bbox, &NoOpMetricsCollector).await; + let _ = index + .search_all_leaves(*query_bbox, &NoOpMetricsCollector) + .await; }); } let scan_time = scan_start.elapsed().as_secs_f64() / queries.len() as f64; - + let prune_start = std::time::Instant::now(); for query_bbox in queries.iter() { rt.block_on(async { - let query = GeoQuery::Intersects(query_bbox[0], query_bbox[1], query_bbox[2], query_bbox[3]); + let query = GeoQuery::Intersects( + query_bbox[0], + query_bbox[1], + query_bbox[2], + query_bbox[3], + ); let _ = index.search(&query, &NoOpMetricsCollector).await; }); } let prune_time = prune_start.elapsed().as_secs_f64() / queries.len() as f64; - + println!( "\n {} points: {} leaves | scan_all: {:.1} leaves/query ({:.1}ms) | with_pruning: {:.1} leaves/query ({:.1}ms) | {:.1}x faster, {:.1}% 
reduction", num_points, @@ -189,21 +203,21 @@ fn bench_geo_index_intersects_pruning(c: &mut Criterion) { 100.0 * (1.0 - pruned_avg / scan_avg) ); } - + group.finish(); } fn bench_geo_intersects_query_size(c: &mut Criterion) { let mut group = c.benchmark_group("geo_intersects_query_size"); group.sample_size(10); - + // Fixed dataset size, varying intersects query box sizes let points = create_test_points(1_000_000); let rt = tokio::runtime::Runtime::new().unwrap(); let (index, _store, _tmpdir) = rt.block_on(create_geo_index(points)); - + let mut rng = StdRng::seed_from_u64(42); - + // Test different query sizes for query_size in [1.0, 10.0, 50.0, 100.0, 200.0] { let queries: Vec<[f64; 4]> = (0..10) @@ -213,7 +227,7 @@ fn bench_geo_intersects_query_size(c: &mut Criterion) { [min_x, min_y, min_x + query_size, min_y + query_size] }) .collect(); - + let leaf_counter = Arc::new(LeafCounter::new()); group.bench_with_input( BenchmarkId::new("query_size", format!("{}x{}", query_size, query_size)), @@ -224,14 +238,19 @@ fn bench_geo_intersects_query_size(c: &mut Criterion) { counter.reset(); rt.block_on(async { for query_bbox in queries.iter() { - let query = GeoQuery::Intersects(query_bbox[0], query_bbox[1], query_bbox[2], query_bbox[3]); + let query = GeoQuery::Intersects( + query_bbox[0], + query_bbox[1], + query_bbox[2], + query_bbox[3], + ); let _result = index.search(&query, counter.as_ref()).await.unwrap(); } }); }); }, ); - + let avg_leaves = leaf_counter.get_count() as f64 / queries.len() as f64; println!( " Query {}x{}: avg {:.1} leaves visited (out of {} total, {:.1}% selectivity)", @@ -242,10 +261,13 @@ fn bench_geo_intersects_query_size(c: &mut Criterion) { 100.0 * avg_leaves / index.num_leaves() as f64 ); } - + group.finish(); } -criterion_group!(benches, bench_geo_index_intersects_pruning, bench_geo_intersects_query_size); +criterion_group!( + benches, + bench_geo_index_intersects_pruning, + bench_geo_intersects_query_size +); criterion_main!(benches); - diff 
--git a/rust/lance-index/src/lib.rs b/rust/lance-index/src/lib.rs index 7e2e2aae9ec..cec4cafd244 100644 --- a/rust/lance-index/src/lib.rs +++ b/rust/lance-index/src/lib.rs @@ -108,7 +108,7 @@ pub enum IndexType { BloomFilter = 9, // Bloom filter - Geo = 10, // Geo + BkdTree = 10, // BKD Tree // 100+ and up for vector index. /// Flat vector index. @@ -132,7 +132,7 @@ impl std::fmt::Display for IndexType { Self::NGram => write!(f, "NGram"), Self::FragmentReuse => write!(f, "FragmentReuse"), Self::MemWal => write!(f, "MemWal"), - Self::Geo => write!(f, "Geo"), + Self::BkdTree => write!(f, "BkdTree"), Self::ZoneMap => write!(f, "ZoneMap"), Self::BloomFilter => write!(f, "BloomFilter"), Self::Vector | Self::IvfPq => write!(f, "IVF_PQ"), @@ -159,7 +159,7 @@ impl TryFrom for IndexType { v if v == Self::Inverted as i32 => Ok(Self::Inverted), v if v == Self::FragmentReuse as i32 => Ok(Self::FragmentReuse), v if v == Self::MemWal as i32 => Ok(Self::MemWal), - v if v == Self::Geo as i32 => Ok(Self::Geo), + v if v == Self::BkdTree as i32 => Ok(Self::BkdTree), v if v == Self::ZoneMap as i32 => Ok(Self::ZoneMap), v if v == Self::BloomFilter as i32 => Ok(Self::BloomFilter), v if v == Self::Vector as i32 => Ok(Self::Vector), @@ -218,7 +218,7 @@ impl IndexType { | Self::NGram | Self::ZoneMap | Self::BloomFilter - | Self::Geo + | Self::BkdTree ) } @@ -257,7 +257,7 @@ impl IndexType { Self::MemWal => 0, Self::ZoneMap => 0, Self::BloomFilter => 0, - Self::Geo => 0, + Self::BkdTree => 0, // for now all vector indices are built by the same builder, // so they share the same version. 
Self::Vector diff --git a/rust/lance-index/src/scalar.rs b/rust/lance-index/src/scalar.rs index 0af533411af..9700635cfa4 100644 --- a/rust/lance-index/src/scalar.rs +++ b/rust/lance-index/src/scalar.rs @@ -62,7 +62,7 @@ pub enum BuiltinIndexType { ZoneMap, BloomFilter, Inverted, - Geo, + BkdTree, } impl BuiltinIndexType { @@ -75,7 +75,7 @@ impl BuiltinIndexType { Self::ZoneMap => "zonemap", Self::Inverted => "inverted", Self::BloomFilter => "bloomfilter", - Self::Geo => "geo", + Self::BkdTree => "bkdtree", } } } @@ -92,7 +92,7 @@ impl TryFrom for BuiltinIndexType { IndexType::ZoneMap => Ok(Self::ZoneMap), IndexType::Inverted => Ok(Self::Inverted), IndexType::BloomFilter => Ok(Self::BloomFilter), - IndexType::Geo => Ok(Self::Geo), + IndexType::BkdTree => Ok(Self::BkdTree), _ => Err(Error::Index { message: "Invalid index type".to_string(), location: location!(), @@ -607,7 +607,10 @@ impl AnyQuery for GeoQuery { fn format(&self, col: &str) -> String { match self { Self::Intersects(min_x, min_y, max_x, max_y) => { - format!("st_intersects({}, bbox({}, {}, {}, {}))", col, min_x, min_y, max_x, max_y) + format!( + "st_intersects({}, bbox({}, {}, {}, {}))", + col, min_x, min_y, max_x, max_y + ) } } } diff --git a/rust/lance-index/src/scalar/geo/bkd.rs b/rust/lance-index/src/scalar/geo/bkd.rs index 096324732ef..7b0ebcd026a 100644 --- a/rust/lance-index/src/scalar/geo/bkd.rs +++ b/rust/lance-index/src/scalar/geo/bkd.rs @@ -15,13 +15,13 @@ //! - Groups points into leaves of configurable size //! 
- Stores tree structure separately from leaf data for lazy loading -use arrow_array::{ArrayRef, Float64Array, RecordBatch, UInt64Array}; use arrow_array::cast::AsArray; +use arrow_array::{ArrayRef, Float64Array, RecordBatch, UInt64Array}; use arrow_schema::{DataType, Field, Schema}; use deepsize::DeepSizeOf; use lance_core::{Result, ROW_ID}; -use std::sync::Arc; use snafu::location; +use std::sync::Arc; // Schema field names const NODE_ID: &str = "node_id"; @@ -192,37 +192,43 @@ impl BKDTreeLookup { // Helper to get column by name let get_col = |batch: &RecordBatch, name: &str| -> Result { - batch.schema().column_with_name(name) + batch + .schema() + .column_with_name(name) .map(|(idx, _)| idx) .ok_or_else(|| lance_core::Error::Internal { message: format!("Missing column '{}' in BKD tree batch", name), location: location!(), }) }; - + // Determine total number of nodes (max node_id + 1) let max_node_id = { let mut max_id = 0u32; - + if inner_batch.num_rows() > 0 { let col_idx = get_col(&inner_batch, NODE_ID)?; - let node_ids = inner_batch.column(col_idx).as_primitive::(); + let node_ids = inner_batch + .column(col_idx) + .as_primitive::(); for i in 0..inner_batch.num_rows() { max_id = max_id.max(node_ids.value(i)); } } - + if leaf_batch.num_rows() > 0 { let col_idx = get_col(&leaf_batch, NODE_ID)?; - let node_ids = leaf_batch.column(col_idx).as_primitive::(); + let node_ids = leaf_batch + .column(col_idx) + .as_primitive::(); for i in 0..leaf_batch.num_rows() { max_id = max_id.max(node_ids.value(i)); } } - + max_id }; - + // Create sparse array of nodes (filled with dummy data initially) let mut nodes = vec![ BKDNode::Leaf(BKDLeafNode { @@ -233,21 +239,39 @@ impl BKDTreeLookup { }); (max_node_id + 1) as usize ]; - + let mut num_leaves = 0; - + // Fill in inner nodes if inner_batch.num_rows() > 0 { - let node_ids = inner_batch.column(get_col(&inner_batch, NODE_ID)?).as_primitive::(); - let min_x = inner_batch.column(get_col(&inner_batch, MIN_X)?).as_primitive::(); - 
let min_y = inner_batch.column(get_col(&inner_batch, MIN_Y)?).as_primitive::(); - let max_x = inner_batch.column(get_col(&inner_batch, MAX_X)?).as_primitive::(); - let max_y = inner_batch.column(get_col(&inner_batch, MAX_Y)?).as_primitive::(); - let split_dim = inner_batch.column(get_col(&inner_batch, SPLIT_DIM)?).as_primitive::(); - let split_value = inner_batch.column(get_col(&inner_batch, SPLIT_VALUE)?).as_primitive::(); - let left_child = inner_batch.column(get_col(&inner_batch, LEFT_CHILD)?).as_primitive::(); - let right_child = inner_batch.column(get_col(&inner_batch, RIGHT_CHILD)?).as_primitive::(); - + let node_ids = inner_batch + .column(get_col(&inner_batch, NODE_ID)?) + .as_primitive::(); + let min_x = inner_batch + .column(get_col(&inner_batch, MIN_X)?) + .as_primitive::(); + let min_y = inner_batch + .column(get_col(&inner_batch, MIN_Y)?) + .as_primitive::(); + let max_x = inner_batch + .column(get_col(&inner_batch, MAX_X)?) + .as_primitive::(); + let max_y = inner_batch + .column(get_col(&inner_batch, MAX_Y)?) + .as_primitive::(); + let split_dim = inner_batch + .column(get_col(&inner_batch, SPLIT_DIM)?) + .as_primitive::(); + let split_value = inner_batch + .column(get_col(&inner_batch, SPLIT_VALUE)?) + .as_primitive::(); + let left_child = inner_batch + .column(get_col(&inner_batch, LEFT_CHILD)?) + .as_primitive::(); + let right_child = inner_batch + .column(get_col(&inner_batch, RIGHT_CHILD)?) 
+ .as_primitive::(); + for i in 0..inner_batch.num_rows() { let node_id = node_ids.value(i) as usize; nodes[node_id] = BKDNode::Inner(BKDInnerNode { @@ -264,18 +288,34 @@ impl BKDTreeLookup { }); } } - + // Fill in leaf nodes if leaf_batch.num_rows() > 0 { - let node_ids = leaf_batch.column(get_col(&leaf_batch, NODE_ID)?).as_primitive::(); - let min_x = leaf_batch.column(get_col(&leaf_batch, MIN_X)?).as_primitive::(); - let min_y = leaf_batch.column(get_col(&leaf_batch, MIN_Y)?).as_primitive::(); - let max_x = leaf_batch.column(get_col(&leaf_batch, MAX_X)?).as_primitive::(); - let max_y = leaf_batch.column(get_col(&leaf_batch, MAX_Y)?).as_primitive::(); - let file_id = leaf_batch.column(get_col(&leaf_batch, FILE_ID)?).as_primitive::(); - let row_offset = leaf_batch.column(get_col(&leaf_batch, ROW_OFFSET)?).as_primitive::(); - let num_rows = leaf_batch.column(get_col(&leaf_batch, NUM_ROWS)?).as_primitive::(); - + let node_ids = leaf_batch + .column(get_col(&leaf_batch, NODE_ID)?) + .as_primitive::(); + let min_x = leaf_batch + .column(get_col(&leaf_batch, MIN_X)?) + .as_primitive::(); + let min_y = leaf_batch + .column(get_col(&leaf_batch, MIN_Y)?) + .as_primitive::(); + let max_x = leaf_batch + .column(get_col(&leaf_batch, MAX_X)?) + .as_primitive::(); + let max_y = leaf_batch + .column(get_col(&leaf_batch, MAX_Y)?) + .as_primitive::(); + let file_id = leaf_batch + .column(get_col(&leaf_batch, FILE_ID)?) + .as_primitive::(); + let row_offset = leaf_batch + .column(get_col(&leaf_batch, ROW_OFFSET)?) + .as_primitive::(); + let num_rows = leaf_batch + .column(get_col(&leaf_batch, NUM_ROWS)?) 
+ .as_primitive::(); + for i in 0..leaf_batch.num_rows() { let node_id = node_ids.value(i) as usize; nodes[node_id] = BKDNode::Leaf(BKDLeafNode { @@ -292,10 +332,9 @@ impl BKDTreeLookup { num_leaves += 1; } } - + Ok(Self::new(nodes, 0, num_leaves)) } - } /// Check if two bounding boxes intersect @@ -321,7 +360,11 @@ impl BKDTreeBuilder { /// Build a BKD tree from points /// Returns (tree_nodes, leaf_batches) - pub fn build(&self, points: &mut [(f64, f64, u64)], batches_per_file: u32) -> Result<(Vec, Vec)> { + pub fn build( + &self, + points: &mut [(f64, f64, u64)], + batches_per_file: u32, + ) -> Result<(Vec, Vec)> { if points.is_empty() { return Ok((vec![], vec![])); } @@ -346,24 +389,24 @@ impl BKDTreeBuilder { let mut row_offset_in_file = 0u64; let mut batches_in_current_file = 0u32; let mut leaf_idx = 0; - + for node in all_nodes.iter_mut() { if let BKDNode::Leaf(leaf) = node { if leaf_idx < all_leaf_batches.len() { let batch_num_rows = all_leaf_batches[leaf_idx].num_rows() as u64; - + // Check if we need to move to next file if batches_in_current_file >= batches_per_file && batches_per_file > 0 { current_file_id += 1; row_offset_in_file = 0; batches_in_current_file = 0; } - + // Update leaf with correct metadata leaf.file_id = current_file_id; leaf.row_offset = row_offset_in_file; leaf.num_rows = batch_num_rows; - + // Advance for next leaf row_offset_in_file += batch_num_rows; batches_in_current_file += 1; @@ -402,8 +445,8 @@ impl BKDTreeBuilder { // Create leaf node (file_id, row_offset will be set in post-processing) all_nodes.push(BKDNode::Leaf(BKDLeafNode { bounds: [min_x, min_y, max_x, max_y], - file_id: 0, // Will be updated in post-processing - row_offset: 0, // Will be updated in post-processing + file_id: 0, // Will be updated in post-processing + row_offset: 0, // Will be updated in post-processing num_rows, })); @@ -417,7 +460,7 @@ impl BKDTreeBuilder { // Current: O(n log n) sorting at each level = O(n log² n) total // Target: O(n) radix select at 
each level = O(n log n) total // See: https://github.com/apache/lucene/blob/main/lucene/core/src/java/org/apache/lucene/util/bkd/BKDRadixSelector.java - + // Sort points by the split dimension if split_dim == 0 { points.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap_or(std::cmp::Ordering::Equal)); @@ -435,7 +478,7 @@ impl BKDTreeBuilder { // Calculate bounds for this node (before splitting the slice) let (min_x, min_y, max_x, max_y) = calculate_bounds(points); - + // Reserve space for this inner node (placeholder - we'll update it after building children) let node_id = all_nodes.len() as u32; all_nodes.push(BKDNode::Inner(BKDInnerNode { @@ -518,8 +561,6 @@ fn create_leaf_batch(points: &[(f64, f64, u64)]) -> Result { Ok(batch) } - - #[cfg(test)] mod tests { use super::*; @@ -839,4 +880,3 @@ mod tests { assert_eq!(batches[1].num_rows(), 2); } } - diff --git a/rust/lance-index/src/scalar/geo/geoindex.rs b/rust/lance-index/src/scalar/geo/bkdtree.rs similarity index 76% rename from rust/lance-index/src/scalar/geo/geoindex.rs rename to rust/lance-index/src/scalar/geo/bkdtree.rs index d860ebc609a..6396a5128b5 100644 --- a/rust/lance-index/src/scalar/geo/geoindex.rs +++ b/rust/lance-index/src/scalar/geo/bkdtree.rs @@ -1,46 +1,44 @@ // SPDX-License-Identifier: Apache-2.0 // SPDX-FileCopyrightText: Copyright The Lance Authors -//! Geo Index +//! BKD Tree Index //! -//! Geo indices are spatial database structures for efficient spatial queries. +//! BKD (Bounding K-Dimensional) trees are spatial database structures for efficient spatial queries. //! They enable efficient filtering by location-based predicates. //! //! ## Requirements //! -//! Geo indices can only be created on fields with GeoArrow metadata. The field must: +//! BKD tree indices can only be created on fields with GeoArrow metadata. The field must: //! - Be a Struct data type //! - Have `ARROW:extension:name` metadata starting with `geoarrow.` (e.g., `geoarrow.point`, `geoarrow.polygon`) //! //! 
## Query Support //! -//! Geo indices are "inexact" filters - they can definitively exclude regions but may include +//! BKD tree indices are "inexact" filters - they can definitively exclude regions but may include //! false positives that require rechecking. //! +use super::bkd::{self, point_in_bbox, BKDLeafNode, BKDNode, BKDTreeBuilder, BKDTreeLookup}; use crate::pbold; -use super::bkd::{self, BKDTreeBuilder, BKDTreeLookup, point_in_bbox, BKDNode, BKDLeafNode}; use crate::scalar::expression::{GeoQueryParser, ScalarQueryParser}; use crate::scalar::registry::{ ScalarIndexPlugin, TrainingCriteria, TrainingOrdering, TrainingRequest, }; -use crate::scalar::{ - BuiltinIndexType, CreatedIndex, GeoQuery, ScalarIndexParams, UpdateCriteria, -}; +use crate::scalar::{BuiltinIndexType, CreatedIndex, GeoQuery, ScalarIndexParams, UpdateCriteria}; use crate::Any; use futures::TryStreamExt; use lance_core::cache::{CacheKey, LanceCache, WeakLanceCache}; use lance_core::{ROW_ADDR, ROW_ID}; use serde::{Deserialize, Serialize}; -use arrow_array::{ArrayRef, Float64Array, RecordBatch, UInt32Array, UInt8Array}; use arrow_array::cast::AsArray; +use arrow_array::{ArrayRef, Float64Array, RecordBatch, UInt32Array, UInt8Array}; use arrow_schema::{DataType, Field, Schema}; use datafusion::execution::SendableRecordBatchStream; use std::{collections::HashMap, sync::Arc}; -use crate::scalar::{AnyQuery, IndexStore, MetricsCollector, ScalarIndex, SearchResult}; use crate::scalar::FragReuseIndex; +use crate::scalar::{AnyQuery, IndexStore, MetricsCollector, ScalarIndex, SearchResult}; use crate::vector::VectorIndex; use crate::{Index, IndexType}; use async_trait::async_trait; @@ -54,11 +52,11 @@ const BKD_TREE_INNER_FILENAME: &str = "bkd_tree_inner.lance"; const BKD_TREE_LEAF_FILENAME: &str = "bkd_tree_leaf.lance"; const LEAF_GROUP_PREFIX: &str = "leaf_group_"; const DEFAULT_BATCHES_PER_LEAF_FILE: u32 = 5; // Default number of leaf batches per file -const GEO_INDEX_VERSION: u32 = 0; +const 
BKD_TREE_VERSION: u32 = 0; const MAX_POINTS_PER_LEAF_META_KEY: &str = "max_points_per_leaf"; const BATCHES_PER_FILE_META_KEY: &str = "batches_per_file"; const DEFAULT_MAX_POINTS_PER_LEAF: u32 = 100; // for test -// const DEFAULT_MAX_POINTS_PER_LEAF: u32 = 1024; // for production + // const DEFAULT_MAX_POINTS_PER_LEAF: u32 = 1024; // for production /// Get the file name for a leaf group fn leaf_group_filename(group_id: u32) -> String { @@ -100,8 +98,8 @@ impl CachedLeafData { } } -/// Geo index -pub struct GeoIndex { +/// BKD tree index for spatial queries +pub struct BkdTree { data_type: DataType, store: Arc, fri: Option>, @@ -110,9 +108,9 @@ pub struct GeoIndex { max_points_per_leaf: u32, } -impl std::fmt::Debug for GeoIndex { +impl std::fmt::Debug for BkdTree { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.debug_struct("GeoIndex") + f.debug_struct("BkdTree") .field("data_type", &self.data_type) .field("store", &self.store) .field("fri", &self.fri) @@ -123,13 +121,13 @@ impl std::fmt::Debug for GeoIndex { } } -impl DeepSizeOf for GeoIndex { +impl DeepSizeOf for BkdTree { fn deep_size_of_children(&self, context: &mut deepsize::Context) -> usize { self.bkd_tree.deep_size_of_children(context) + self.store.deep_size_of_children(context) } } -impl GeoIndex { +impl BkdTree { /// Load the geo index from storage pub async fn load( store: Arc, @@ -147,9 +145,7 @@ impl GeoIndex { // Load leaf metadata let leaf_file = store.open_index_file(BKD_TREE_LEAF_FILENAME).await?; - let leaf_data = leaf_file - .read_range(0..leaf_file.num_rows(), None) - .await?; + let leaf_data = leaf_file.read_range(0..leaf_file.num_rows(), None).await?; // Deserialize tree structure from both files let bkd_tree = BKDTreeLookup::from_record_batches(inner_data, leaf_data)?; @@ -184,26 +180,27 @@ impl GeoIndex { let file_id = leaf.file_id; let row_offset = leaf.row_offset; let num_rows = leaf.num_rows; - + // Use (file_id, row_offset) as cache key // Combine file_id and 
row_offset into a single u32 (file_id should be small) - let cache_key = BKDLeafKey { leaf_id: file_id * 100_000 + (row_offset as u32) }; + let cache_key = BKDLeafKey { + leaf_id: file_id * 100_000 + (row_offset as u32), + }; let store = self.store.clone(); let cached = self .index_cache .get_or_insert_with_key(cache_key, move || async move { metrics.record_part_load(); - + let filename = leaf_group_filename(file_id); - + // Open the leaf group file and read the specific row range let reader = store.open_index_file(&filename).await?; - let batch = reader.read_range( - row_offset as usize..(row_offset + num_rows) as usize, - None - ).await?; - + let batch = reader + .read_range(row_offset as usize..(row_offset + num_rows) as usize, None) + .await?; + Ok(CachedLeafData::new(batch)) }) .await?; @@ -223,19 +220,20 @@ impl GeoIndex { metrics: &dyn MetricsCollector, ) -> Result { let mut all_row_ids = RowIdTreeMap::new(); - + // Iterate through all nodes and search every leaf for node in self.bkd_tree.nodes.iter() { if let BKDNode::Leaf(leaf) = node { let leaf_row_ids = self.search_leaf(leaf, query_bbox, metrics).await?; - let row_ids: Option> = leaf_row_ids.row_ids() + let row_ids: Option> = leaf_row_ids + .row_ids() .map(|iter| iter.map(|row_addr| u64::from(row_addr)).collect()); if let Some(row_ids) = row_ids { all_row_ids.extend(row_ids); } } } - + Ok(all_row_ids) } @@ -249,22 +247,28 @@ impl GeoIndex { let leaf_data = self.load_leaf(leaf, metrics).await?; // Filter points within this leaf using iterators - let x_array = leaf_data.column(0).as_primitive::(); - let y_array = leaf_data.column(1).as_primitive::(); - let row_id_array = leaf_data.column(2).as_primitive::(); + let x_array = leaf_data + .column(0) + .as_primitive::(); + let y_array = leaf_data + .column(1) + .as_primitive::(); + let row_id_array = leaf_data + .column(2) + .as_primitive::(); let row_ids: Vec = x_array .iter() .zip(y_array.iter()) .zip(row_id_array.iter()) - .filter_map(|((x_opt, y_opt), 
row_id_opt)| { - match (x_opt, y_opt, row_id_opt) { + .filter_map( + |((x_opt, y_opt), row_id_opt)| match (x_opt, y_opt, row_id_opt) { (Some(x), Some(y), Some(row_id)) if point_in_bbox(x, y, &query_bbox) => { Some(row_id) } _ => None, - } - }) + }, + ) .collect(); let mut row_id_map = RowIdTreeMap::new(); @@ -274,7 +278,7 @@ impl GeoIndex { } #[async_trait] -impl Index for GeoIndex { +impl Index for BkdTree { fn as_any(&self) -> &dyn Any { self } @@ -285,7 +289,7 @@ impl Index for GeoIndex { fn as_vector_index(self: Arc) -> Result> { Err(Error::InvalidInput { - source: "GeoIndex is not a vector index".into(), + source: "BkdTree is not a vector index".into(), location: location!(), }) } @@ -297,12 +301,12 @@ impl Index for GeoIndex { fn statistics(&self) -> Result { Ok(serde_json::json!({ - "type": "geo", + "type": "bkdtree", })) } fn index_type(&self) -> IndexType { - IndexType::Geo + IndexType::BkdTree } async fn calculate_included_frags(&self) -> Result { @@ -312,17 +316,20 @@ impl Index for GeoIndex { } #[async_trait] -impl ScalarIndex for GeoIndex { +impl ScalarIndex for BkdTree { async fn search( &self, query: &dyn AnyQuery, metrics: &dyn MetricsCollector, ) -> Result { - let geo_query = query.as_any().downcast_ref::() - .ok_or_else(|| Error::InvalidInput { - source: "Geo index only supports GeoQuery".into(), - location: location!(), - })?; + let geo_query = + query + .as_any() + .downcast_ref::() + .ok_or_else(|| Error::InvalidInput { + source: "Geo index only supports GeoQuery".into(), + location: location!(), + })?; match geo_query { GeoQuery::Intersects(min_x, min_y, max_x, max_y) => { @@ -335,11 +342,10 @@ impl ScalarIndex for GeoIndex { let mut all_row_ids = RowIdTreeMap::new(); for leaf_node in &leaves { - let leaf_row_ids = self - .search_leaf(leaf_node, query_bbox, metrics) - .await?; + let leaf_row_ids = self.search_leaf(leaf_node, query_bbox, metrics).await?; // Collect row IDs from the leaf and add them to the result set - let row_ids: Option> = 
leaf_row_ids.row_ids() + let row_ids: Option> = leaf_row_ids + .row_ids() .map(|iter| iter.map(|row_addr| u64::from(row_addr)).collect()); if let Some(row_ids) = row_ids { all_row_ids.extend(row_ids); @@ -363,7 +369,7 @@ impl ScalarIndex for GeoIndex { _dest_store: &dyn IndexStore, ) -> Result { Err(Error::InvalidInput { - source: "GeoIndex does not support remap".into(), + source: "BkdTree does not support remap".into(), location: location!(), }) } @@ -375,7 +381,7 @@ impl ScalarIndex for GeoIndex { _dest_store: &dyn IndexStore, ) -> Result { Err(Error::InvalidInput { - source: "GeoIndex does not support update".into(), + source: "BkdTree does not support update".into(), location: location!(), }) } @@ -387,13 +393,13 @@ impl ScalarIndex for GeoIndex { } fn derive_index_params(&self) -> Result { - let params = serde_json::to_value(GeoIndexBuilderParams::default())?; - Ok(ScalarIndexParams::for_builtin(BuiltinIndexType::Geo).with_params(¶ms)) + let params = serde_json::to_value(BkdTreeBuilderParams::default())?; + Ok(ScalarIndexParams::for_builtin(BuiltinIndexType::BkdTree).with_params(¶ms)) } } #[derive(Debug, Clone, Serialize, Deserialize)] -pub struct GeoIndexBuilderParams { +pub struct BkdTreeBuilderParams { #[serde(default = "default_max_points_per_leaf")] pub max_points_per_leaf: u32, #[serde(default = "default_batches_per_file")] @@ -408,7 +414,7 @@ fn default_batches_per_file() -> u32 { DEFAULT_BATCHES_PER_LEAF_FILE } -impl Default for GeoIndexBuilderParams { +impl Default for BkdTreeBuilderParams { fn default() -> Self { Self { max_points_per_leaf: default_max_points_per_leaf(), @@ -417,7 +423,7 @@ impl Default for GeoIndexBuilderParams { } } -impl GeoIndexBuilderParams { +impl BkdTreeBuilderParams { pub fn new() -> Self { Self::default() } @@ -429,15 +435,15 @@ impl GeoIndexBuilderParams { } // A builder for geo index -pub struct GeoIndexBuilder { - options: GeoIndexBuilderParams, +pub struct BkdTreeBuilder { + options: BkdTreeBuilderParams, _items_type: 
DataType, // Accumulated points: (x, y, row_id) pub points: Vec<(f64, f64, u64)>, } -impl GeoIndexBuilder { - pub fn try_new(options: GeoIndexBuilderParams, items_type: DataType) -> Result { +impl BkdTreeBuilder { + pub fn try_new(options: BkdTreeBuilderParams, items_type: DataType) -> Result { Ok(Self { options, _items_type: items_type, @@ -452,7 +458,10 @@ impl GeoIndexBuilder { while let Some(batch) = batches_source.try_next().await? { // Extract GeoArrow point coordinates - let geom_array = batch.column(0).as_any().downcast_ref::() + let geom_array = batch + .column(0) + .as_any() + .downcast_ref::() .ok_or_else(|| Error::InvalidInput { source: "Expected Struct array for GeoArrow data".into(), location: location!(), @@ -470,11 +479,8 @@ impl GeoIndexBuilder { .as_primitive::(); for i in 0..batch.num_rows() { - self.points.push(( - x_array.value(i), - y_array.value(i), - row_ids.value(i), - )); + self.points + .push((x_array.value(i), y_array.value(i), row_ids.value(i))); } } @@ -509,7 +515,7 @@ impl GeoIndexBuilder { // Write tree structure to separate inner and leaf files let (inner_batch, leaf_metadata_batch) = self.serialize_tree_nodes(&tree_nodes)?; - + // Write inner nodes let mut inner_file = index_store .new_index_file(BKD_TREE_INNER_FILENAME, inner_batch.schema()) @@ -517,8 +523,14 @@ impl GeoIndexBuilder { inner_file.write_record_batch(inner_batch).await?; inner_file .finish_with_metadata(HashMap::from([ - (MAX_POINTS_PER_LEAF_META_KEY.to_string(), self.options.max_points_per_leaf.to_string()), - (BATCHES_PER_FILE_META_KEY.to_string(), self.options.batches_per_file.to_string()), + ( + MAX_POINTS_PER_LEAF_META_KEY.to_string(), + self.options.max_points_per_leaf.to_string(), + ), + ( + BATCHES_PER_FILE_META_KEY.to_string(), + self.options.batches_per_file.to_string(), + ), ])) .await?; @@ -526,7 +538,9 @@ impl GeoIndexBuilder { let mut leaf_meta_file = index_store .new_index_file(BKD_TREE_LEAF_FILENAME, leaf_metadata_batch.schema()) .await?; - 
leaf_meta_file.write_record_batch(leaf_metadata_batch).await?; + leaf_meta_file + .write_record_batch(leaf_metadata_batch) + .await?; leaf_meta_file.finish().await?; // Write actual leaf data grouped into files (multiple batches per file) @@ -535,61 +549,57 @@ impl GeoIndexBuilder { Field::new("y", DataType::Float64, false), Field::new(ROW_ID, DataType::UInt64, false), ])); - + let batches_per_file = self.options.batches_per_file; let num_groups = (leaf_batches.len() as u32 + batches_per_file - 1) / batches_per_file; - + for group_id in 0..num_groups { let start_idx = (group_id * batches_per_file) as usize; - let end_idx = ((group_id + 1) * batches_per_file).min(leaf_batches.len() as u32) as usize; + let end_idx = + ((group_id + 1) * batches_per_file).min(leaf_batches.len() as u32) as usize; let group_batches = &leaf_batches[start_idx..end_idx]; - + let filename = leaf_group_filename(group_id); - + let mut leaf_file = index_store .new_index_file(&filename, leaf_schema.clone()) .await?; - + for leaf_batch in group_batches.iter() { leaf_file.write_record_batch(leaf_batch.clone()).await?; } - + leaf_file.finish().await?; } - log::debug!( - "Wrote BKD tree with {} nodes", - tree_nodes.len() - ); + log::debug!("Wrote BKD tree with {} nodes", tree_nodes.len()); Ok(()) } // Serialize tree nodes to separate inner and leaf RecordBatches fn serialize_tree_nodes(&self, nodes: &[BKDNode]) -> Result<(RecordBatch, RecordBatch)> { - // Separate inner and leaf nodes with their indices let mut inner_nodes = Vec::new(); let mut leaf_nodes = Vec::new(); - + for (idx, node) in nodes.iter().enumerate() { match node { BKDNode::Inner(_) => inner_nodes.push((idx as u32, node)), BKDNode::Leaf(_) => leaf_nodes.push((idx as u32, node)), } } - + // Serialize inner nodes let inner_batch = Self::serialize_inner_nodes(&inner_nodes)?; - + // Serialize leaf nodes let leaf_batch = Self::serialize_leaf_nodes(&leaf_nodes)?; - + Ok((inner_batch, leaf_batch)) } - + fn serialize_inner_nodes(nodes: 
&[(u32, &BKDNode)]) -> Result { - let mut node_id_vals = Vec::with_capacity(nodes.len()); let mut min_x_vals = Vec::with_capacity(nodes.len()); let mut min_y_vals = Vec::with_capacity(nodes.len()); @@ -630,10 +640,10 @@ impl GeoIndexBuilder { Ok(RecordBatch::try_new(schema, columns)?) } - + fn serialize_leaf_nodes(nodes: &[(u32, &BKDNode)]) -> Result { use arrow_array::UInt64Array; - + let mut node_id_vals = Vec::with_capacity(nodes.len()); let mut min_x_vals = Vec::with_capacity(nodes.len()); let mut min_y_vals = Vec::with_capacity(nodes.len()); @@ -680,17 +690,17 @@ impl GeoIndexBuilder { } #[derive(Debug, Default)] -pub struct GeoIndexPlugin; +pub struct BkdTreePlugin; -impl GeoIndexPlugin { +impl BkdTreePlugin { async fn train_geo_index( batches_source: SendableRecordBatchStream, index_store: &dyn IndexStore, - options: Option, + options: Option, ) -> Result<()> { let value_type = batches_source.schema().field(0).data_type().clone(); - let mut builder = GeoIndexBuilder::try_new(options.unwrap_or_default(), value_type)?; + let mut builder = BkdTreeBuilder::try_new(options.unwrap_or_default(), value_type)?; builder.train(batches_source).await?; @@ -699,13 +709,13 @@ impl GeoIndexPlugin { } } -pub struct GeoIndexTrainingRequest { - pub params: GeoIndexBuilderParams, +pub struct BkdTreeTrainingRequest { + pub params: BkdTreeBuilderParams, pub criteria: TrainingCriteria, } -impl GeoIndexTrainingRequest { - pub fn new(params: GeoIndexBuilderParams) -> Self { +impl BkdTreeTrainingRequest { + pub fn new(params: BkdTreeBuilderParams) -> Self { Self { params, criteria: TrainingCriteria::new(TrainingOrdering::Addresses).with_row_addr(), @@ -713,7 +723,7 @@ impl GeoIndexTrainingRequest { } } -impl TrainingRequest for GeoIndexTrainingRequest { +impl TrainingRequest for BkdTreeTrainingRequest { fn as_any(&self) -> &dyn std::any::Any { self } @@ -723,7 +733,7 @@ impl TrainingRequest for GeoIndexTrainingRequest { } #[async_trait] -impl ScalarIndexPlugin for GeoIndexPlugin { 
+impl ScalarIndexPlugin for BkdTreePlugin { fn new_training_request( &self, params: &str, @@ -756,9 +766,9 @@ impl ScalarIndexPlugin for GeoIndexPlugin { }); } - let params = serde_json::from_str::(params)?; + let params = serde_json::from_str::(params)?; - Ok(Box::new(GeoIndexTrainingRequest::new(params))) + Ok(Box::new(BkdTreeTrainingRequest::new(params))) } fn provides_exact_answer(&self) -> bool { @@ -766,7 +776,7 @@ impl ScalarIndexPlugin for GeoIndexPlugin { } fn version(&self) -> u32 { - GEO_INDEX_VERSION + BKD_TREE_VERSION } fn new_query_parser( @@ -792,16 +802,16 @@ impl ScalarIndexPlugin for GeoIndexPlugin { } let request = (request as Box) - .downcast::() + .downcast::() .map_err(|_| Error::InvalidInput { source: "must provide training request created by new_training_request".into(), location: location!(), })?; Self::train_geo_index(data, index_store, Some(request.params)).await?; Ok(CreatedIndex { - index_details: prost_types::Any::from_msg(&pbold::GeoIndexDetails::default()) + index_details: prost_types::Any::from_msg(&crate::pb::BkdTreeIndexDetails::default()) .unwrap(), - index_version: GEO_INDEX_VERSION, + index_version: BKD_TREE_VERSION, }) } @@ -812,7 +822,7 @@ impl ScalarIndexPlugin for GeoIndexPlugin { frag_reuse_index: Option>, cache: &LanceCache, ) -> Result> { - Ok(GeoIndex::load(index_store, frag_reuse_index, cache).await? as Arc) + Ok(BkdTree::load(index_store, frag_reuse_index, cache).await? 
as Arc) } } @@ -830,7 +840,7 @@ mod tests { fn create_test_store() -> Arc { let tmpdir = TempObjDir::default(); - + let test_store = Arc::new(LanceIndexStore::new( Arc::new(ObjectStore::local()), tmpdir.clone(), @@ -840,11 +850,11 @@ mod tests { } /// Validates the BKD tree structure recursively - async fn validate_bkd_tree(index: &GeoIndex) -> Result<()> { + async fn validate_bkd_tree(index: &BkdTree) -> Result<()> { use crate::metrics::NoOpMetricsCollector; - + let tree = &index.bkd_tree; - + // Verify root exists if tree.nodes.is_empty() { return Err(Error::InvalidInput { @@ -856,7 +866,12 @@ mod tests { let root_id = tree.root_id as usize; if root_id >= tree.nodes.len() { return Err(Error::InvalidInput { - source: format!("Root node id {} out of bounds (tree has {} nodes)", root_id, tree.nodes.len()).into(), + source: format!( + "Root node id {} out of bounds (tree has {} nodes)", + root_id, + tree.nodes.len() + ) + .into(), location: location!(), }); } @@ -864,16 +879,30 @@ mod tests { // Count leaves and validate structure let mut visited = vec![false; tree.nodes.len()]; let mut leaf_count = 0u32; - + // Start with None for expected_split_dim - root can use any dimension // (typically starts with dimension 0, but we don't enforce it) let metrics = NoOpMetricsCollector {}; - validate_node_recursive(index, tree, root_id as u32, &mut visited, &mut leaf_count, None, None, &metrics).await?; + validate_node_recursive( + index, + tree, + root_id as u32, + &mut visited, + &mut leaf_count, + None, + None, + &metrics, + ) + .await?; // Verify all leaves were counted if leaf_count != tree.num_leaves { return Err(Error::InvalidInput { - source: format!("Leaf count mismatch: found {} leaves but tree claims {}", leaf_count, tree.num_leaves).into(), + source: format!( + "Leaf count mismatch: found {} leaves but tree claims {}", + leaf_count, tree.num_leaves + ) + .into(), location: location!(), }); } @@ -893,7 +922,7 @@ mod tests { /// Helper function to recursively 
validate a node and its descendants async fn validate_node_recursive( - index: &GeoIndex, + index: &BkdTree, tree: &BKDTreeLookup, node_id: u32, visited: &mut Vec, @@ -903,11 +932,16 @@ mod tests { metrics: &dyn MetricsCollector, ) -> Result<()> { let node_idx = node_id as usize; - + // Check node exists if node_idx >= tree.nodes.len() { return Err(Error::InvalidInput { - source: format!("Node id {} out of bounds (tree has {} nodes)", node_id, tree.nodes.len()).into(), + source: format!( + "Node id {} out of bounds (tree has {} nodes)", + node_id, + tree.nodes.len() + ) + .into(), location: location!(), }); } @@ -927,22 +961,36 @@ mod tests { // Validate bounds are well-formed if bounds[0] > bounds[2] || bounds[1] > bounds[3] { return Err(Error::InvalidInput { - source: format!("Node {} has invalid bounds: [{}, {}, {}, {}]", - node_id, bounds[0], bounds[1], bounds[2], bounds[3]).into(), + source: format!( + "Node {} has invalid bounds: [{}, {}, {}, {}]", + node_id, bounds[0], bounds[1], bounds[2], bounds[3] + ) + .into(), location: location!(), }); } // Verify child bounds are within parent bounds if let Some(parent_bounds) = parent_bounds { - if bounds[0] < parent_bounds[0] || bounds[1] < parent_bounds[1] || - bounds[2] > parent_bounds[2] || bounds[3] > parent_bounds[3] { + if bounds[0] < parent_bounds[0] + || bounds[1] < parent_bounds[1] + || bounds[2] > parent_bounds[2] + || bounds[3] > parent_bounds[3] + { return Err(Error::InvalidInput { source: format!( "Node {} bounds [{}, {}, {}, {}] exceed parent bounds [{}, {}, {}, {}]", - node_id, bounds[0], bounds[1], bounds[2], bounds[3], - parent_bounds[0], parent_bounds[1], parent_bounds[2], parent_bounds[3] - ).into(), + node_id, + bounds[0], + bounds[1], + bounds[2], + bounds[3], + parent_bounds[0], + parent_bounds[1], + parent_bounds[2], + parent_bounds[3] + ) + .into(), location: location!(), }); } @@ -953,20 +1001,33 @@ mod tests { // Validate split dimension if inner.split_dim > 1 { return 
Err(Error::InvalidInput { - source: format!("Node {} has invalid split_dim: {} (must be 0 or 1)", node_id, inner.split_dim).into(), + source: format!( + "Node {} has invalid split_dim: {} (must be 0 or 1)", + node_id, inner.split_dim + ) + .into(), location: location!(), }); } // Validate split value is within bounds - let min_val = if inner.split_dim == 0 { bounds[0] } else { bounds[1] }; - let max_val = if inner.split_dim == 0 { bounds[2] } else { bounds[3] }; + let min_val = if inner.split_dim == 0 { + bounds[0] + } else { + bounds[1] + }; + let max_val = if inner.split_dim == 0 { + bounds[2] + } else { + bounds[3] + }; if inner.split_value < min_val || inner.split_value > max_val { return Err(Error::InvalidInput { source: format!( "Node {} split_value {} is outside dimension bounds [{}, {}]", node_id, inner.split_value, min_val, max_val - ).into(), + ) + .into(), location: location!(), }); } @@ -976,16 +1037,16 @@ mod tests { // Which means left_max <= right_min (no overlap) let left_node = &tree.nodes[inner.left_child as usize]; let right_node = &tree.nodes[inner.right_child as usize]; - + let left_bounds = left_node.bounds(); let right_bounds = right_node.bounds(); - + let (left_max, right_min) = if inner.split_dim == 0 { - (left_bounds[2], right_bounds[0]) // max_x of left, min_x of right + (left_bounds[2], right_bounds[0]) // max_x of left, min_x of right } else { - (left_bounds[3], right_bounds[1]) // max_y of left, min_y of right + (left_bounds[3], right_bounds[1]) // max_y of left, min_y of right }; - + // Simple check: left_max should not be greater than right_min // If it is, points were not properly sorted before splitting! 
if left_max > right_min { @@ -1013,7 +1074,8 @@ mod tests { Some(bounds), Some((inner.split_dim, inner.split_value, true)), metrics, - )).await?; + )) + .await?; Box::pin(validate_node_recursive( index, tree, @@ -1023,7 +1085,8 @@ mod tests { Some(bounds), Some((inner.split_dim, inner.split_value, false)), metrics, - )).await?; + )) + .await?; } bkd::BKDNode::Leaf(leaf) => { // Validate leaf data @@ -1045,20 +1108,21 @@ mod tests { #[tokio::test] async fn test_geo_intersects_with_custom_max_points_per_leaf() { - use rand::{Rng, SeedableRng}; use rand::rngs::StdRng; - + use rand::{Rng, SeedableRng}; + // Test with different max points per leaf for max_points_per_leaf in [10, 50, 100, 200] { let test_store = create_test_store(); - - let params = GeoIndexBuilderParams { + + let params = BkdTreeBuilderParams { max_points_per_leaf, batches_per_file: 5, }; - let mut builder = GeoIndexBuilder::try_new(params, DataType::Struct(Vec::::new().into())) - .unwrap(); + let mut builder = + BkdTreeBuilder::try_new(params, DataType::Struct(Vec::::new().into())) + .unwrap(); // Create 500 RANDOM points (not sequential!) 
to catch sorting bugs let mut rng = StdRng::seed_from_u64(42); @@ -1072,17 +1136,17 @@ mod tests { builder.write_index(test_store.as_ref()).await.unwrap(); // Load index and verify - let index = GeoIndex::load(test_store.clone(), None, &LanceCache::no_cache()) + let index = BkdTree::load(test_store.clone(), None, &LanceCache::no_cache()) .await - .expect("Failed to load GeoIndex"); + .expect("Failed to load BkdTree"); assert_eq!(index.max_points_per_leaf, max_points_per_leaf); - - - + // Validate tree structure - validate_bkd_tree(&index).await - .expect(&format!("BKD tree validation failed for max_points_per_leaf={}", max_points_per_leaf)); + validate_bkd_tree(&index).await.expect(&format!( + "BKD tree validation failed for max_points_per_leaf={}", + max_points_per_leaf + )); } } @@ -1091,14 +1155,15 @@ mod tests { // Test with different batches_per_file configurations for batches_per_file in [1, 3, 5, 10, 20] { let test_store = create_test_store(); - - let params = GeoIndexBuilderParams { + + let params = BkdTreeBuilderParams { max_points_per_leaf: 50, batches_per_file, }; - let mut builder = GeoIndexBuilder::try_new(params, DataType::Struct(Vec::::new().into())) - .unwrap(); + let mut builder = + BkdTreeBuilder::try_new(params, DataType::Struct(Vec::::new().into())) + .unwrap(); // Create 500 points (BKD spatial partitioning determines actual leaf count) for i in 0..500 { @@ -1111,21 +1176,22 @@ mod tests { builder.write_index(test_store.as_ref()).await.unwrap(); // Load index and verify - let index = GeoIndex::load(test_store.clone(), None, &LanceCache::no_cache()) + let index = BkdTree::load(test_store.clone(), None, &LanceCache::no_cache()) .await - .expect("Failed to load GeoIndex"); + .expect("Failed to load BkdTree"); // Validate tree structure - validate_bkd_tree(&index).await - .expect(&format!("BKD tree validation failed for batches_per_file={}", batches_per_file)); - + validate_bkd_tree(&index).await.expect(&format!( + "BKD tree validation failed for 
batches_per_file={}", + batches_per_file + )); } } #[tokio::test] async fn test_geo_intersects_query_correctness_various_configs() { use crate::metrics::NoOpMetricsCollector; - + // Test query correctness with different configurations let configs = vec![ (10, 1), // Small leaves, one per file @@ -1135,14 +1201,15 @@ mod tests { for (max_points_per_leaf, batches_per_file) in configs { let test_store = create_test_store(); - - let params = GeoIndexBuilderParams { + + let params = BkdTreeBuilderParams { max_points_per_leaf, batches_per_file, }; - let mut builder = GeoIndexBuilder::try_new(params, DataType::Struct(Vec::::new().into())) - .unwrap(); + let mut builder = + BkdTreeBuilder::try_new(params, DataType::Struct(Vec::::new().into())) + .unwrap(); // Create a grid of points for x in 0..20 { @@ -1154,34 +1221,40 @@ mod tests { // Write and load index builder.write_index(test_store.as_ref()).await.unwrap(); - let index = GeoIndex::load(test_store.clone(), None, &LanceCache::no_cache()) + let index = BkdTree::load(test_store.clone(), None, &LanceCache::no_cache()) .await .unwrap(); - + // Validate tree structure - validate_bkd_tree(&index).await - .expect(&format!("BKD tree validation failed for max_points_per_leaf={}, batches_per_file={}", max_points_per_leaf, batches_per_file)); + validate_bkd_tree(&index).await.expect(&format!( + "BKD tree validation failed for max_points_per_leaf={}, batches_per_file={}", + max_points_per_leaf, batches_per_file + )); // Query: bbox [5, 5, 10, 10] should return points in that region let query = GeoQuery::Intersects(5.0, 5.0, 10.0, 10.0); - + let metrics = NoOpMetricsCollector {}; let result = index.search(&query, &metrics).await.unwrap(); - + // Should find points (5,5) to (10,10) = 6x6 = 36 points match result { crate::scalar::SearchResult::Exact(row_ids) => { assert_eq!(row_ids.len().unwrap_or(0), 36, "Expected 36 points for config max_points_per_leaf={}, batches_per_file={}, got {}", max_points_per_leaf, batches_per_file, 
row_ids.len().unwrap_or(0)); - + // Verify correct row IDs for x in 5..=10 { for y in 5..=10 { let expected_row_id = (x * 20 + y) as u64; - assert!(row_ids.contains(expected_row_id), - "Missing row_id {} for point ({}, {})", - expected_row_id, x, y); + assert!( + row_ids.contains(expected_row_id), + "Missing row_id {} for point ({}, {})", + expected_row_id, + x, + y + ); } } } @@ -1194,14 +1267,14 @@ mod tests { async fn test_geo_intersects_single_leaf() { // Edge case: all points fit in single leaf let test_store = create_test_store(); - - let params = GeoIndexBuilderParams { + + let params = BkdTreeBuilderParams { max_points_per_leaf: 1000, batches_per_file: 5, }; - let mut builder = GeoIndexBuilder::try_new(params, DataType::Struct(Vec::::new().into())) - .unwrap(); + let mut builder = + BkdTreeBuilder::try_new(params, DataType::Struct(Vec::::new().into())).unwrap(); // Create only 50 points for i in 0..50 { @@ -1209,18 +1282,19 @@ mod tests { } builder.write_index(test_store.as_ref()).await.unwrap(); - let index = GeoIndex::load(test_store.clone(), None, &LanceCache::no_cache()) + let index = BkdTree::load(test_store.clone(), None, &LanceCache::no_cache()) .await .unwrap(); // Should have only 1 leaf assert_eq!(index.bkd_tree.num_leaves, 1); assert_eq!(index.bkd_tree.nodes.len(), 1); - + // Validate tree structure - validate_bkd_tree(&index).await + validate_bkd_tree(&index) + .await .expect("BKD tree validation failed for single leaf test"); - + // Single leaf should be in file 0 at offset 0 match &index.bkd_tree.nodes[0] { BKDNode::Leaf(leaf) => { @@ -1236,14 +1310,14 @@ mod tests { async fn test_geo_intersects_many_small_leaves() { // Stress test: many small leaves, test file grouping let test_store = create_test_store(); - - let params = GeoIndexBuilderParams { - max_points_per_leaf: 5, // Very small leaves - batches_per_file: 3, // Few batches per file + + let params = BkdTreeBuilderParams { + max_points_per_leaf: 5, // Very small leaves + 
batches_per_file: 3, // Few batches per file }; - let mut builder = GeoIndexBuilder::try_new(params, DataType::Struct(Vec::::new().into())) - .unwrap(); + let mut builder = + BkdTreeBuilder::try_new(params, DataType::Struct(Vec::::new().into())).unwrap(); // Create 100 points (BKD spatial partitioning determines actual leaf count) for i in 0..100 { @@ -1251,15 +1325,16 @@ mod tests { } builder.write_index(test_store.as_ref()).await.unwrap(); - let index = GeoIndex::load(test_store.clone(), None, &LanceCache::no_cache()) + let index = BkdTree::load(test_store.clone(), None, &LanceCache::no_cache()) .await .unwrap(); // BKD tree creates more leaves due to spatial partitioning assert_eq!(index.bkd_tree.num_leaves, 32); - + // Validate tree structure - validate_bkd_tree(&index).await + validate_bkd_tree(&index) + .await .expect("BKD tree validation failed for many small leaves test"); // 32 leaves / 3 batches_per_file = 11 files (ceil) @@ -1270,7 +1345,7 @@ mod tests { .filter_map(|n| n.as_leaf()) .map(|l| l.file_id) .collect(); - + assert_eq!(file_ids.len(), 11); // Verify row offsets are cumulative within each file @@ -1282,15 +1357,15 @@ mod tests { .collect(); for file_id in 0..11 { - let leaves_in_file: Vec<_> = leaves - .iter() - .filter(|l| l.file_id == file_id) - .collect(); - + let leaves_in_file: Vec<_> = leaves.iter().filter(|l| l.file_id == file_id).collect(); + let mut expected_offset = 0u64; for leaf in leaves_in_file { - assert_eq!(leaf.row_offset, expected_offset, - "Incorrect offset in file {}", file_id); + assert_eq!( + leaf.row_offset, expected_offset, + "Incorrect offset in file {}", + file_id + ); expected_offset += leaf.num_rows; } } @@ -1300,7 +1375,7 @@ mod tests { async fn test_geo_index_data_integrity_after_serialization() { // Verify every point written can be read back exactly use crate::metrics::NoOpMetricsCollector; - + let configs = vec![ (10, 1), // Small leaves, one per file (50, 3), // Medium leaves, few per file @@ -1309,14 +1384,15 
@@ mod tests { for (max_points_per_leaf, batches_per_file) in configs { let test_store = create_test_store(); - - let params = GeoIndexBuilderParams { + + let params = BkdTreeBuilderParams { max_points_per_leaf, batches_per_file, }; - let mut builder = GeoIndexBuilder::try_new(params, DataType::Struct(Vec::::new().into())) - .unwrap(); + let mut builder = + BkdTreeBuilder::try_new(params, DataType::Struct(Vec::::new().into())) + .unwrap(); // Create test points with known values let mut original_points = Vec::new(); @@ -1332,7 +1408,7 @@ mod tests { builder.write_index(test_store.as_ref()).await.unwrap(); // Load index - let index = GeoIndex::load(test_store.clone(), None, &LanceCache::no_cache()) + let index = BkdTree::load(test_store.clone(), None, &LanceCache::no_cache()) .await .unwrap(); @@ -1343,7 +1419,7 @@ mod tests { for leaf in index.bkd_tree.nodes.iter().filter_map(|n| n.as_leaf()) { // Load the leaf data let leaf_data = index.load_leaf(leaf, &metrics).await.unwrap(); - + // Extract points from this leaf let x_array = leaf_data .column(0) @@ -1364,21 +1440,30 @@ mod tests { } // Verify all points were recovered - assert_eq!(recovered_points.len(), original_points.len(), - "Lost points with config max_points_per_leaf={}, batches_per_file={}", - max_points_per_leaf, batches_per_file); + assert_eq!( + recovered_points.len(), + original_points.len(), + "Lost points with config max_points_per_leaf={}, batches_per_file={}", + max_points_per_leaf, + batches_per_file + ); // Verify each point matches exactly for (original_x, original_y, row_id) in &original_points { - let (recovered_x, recovered_y) = recovered_points.get(row_id) + let (recovered_x, recovered_y) = recovered_points + .get(row_id) .expect(&format!("Missing row_id {} in recovered data", row_id)); - - assert_eq!(*recovered_x, *original_x, - "X coordinate mismatch for row_id {}: expected {}, got {}", - row_id, original_x, recovered_x); - assert_eq!(*recovered_y, *original_y, - "Y coordinate mismatch 
for row_id {}: expected {}, got {}", - row_id, original_y, recovered_y); + + assert_eq!( + *recovered_x, *original_x, + "X coordinate mismatch for row_id {}: expected {}, got {}", + row_id, original_x, recovered_x + ); + assert_eq!( + *recovered_y, *original_y, + "Y coordinate mismatch for row_id {}: expected {}, got {}", + row_id, original_y, recovered_y + ); } } } @@ -1387,14 +1472,14 @@ mod tests { async fn test_geo_index_no_duplicate_points() { // Ensure no points are duplicated during write/read let test_store = create_test_store(); - - let params = GeoIndexBuilderParams { - max_points_per_leaf: 7, // Odd number to test edge cases + + let params = BkdTreeBuilderParams { + max_points_per_leaf: 7, // Odd number to test edge cases batches_per_file: 3, }; - let mut builder = GeoIndexBuilder::try_new(params, DataType::Struct(Vec::::new().into())) - .unwrap(); + let mut builder = + BkdTreeBuilder::try_new(params, DataType::Struct(Vec::::new().into())).unwrap(); // Create 100 unique points for i in 0..100 { @@ -1402,7 +1487,7 @@ mod tests { } builder.write_index(test_store.as_ref()).await.unwrap(); - let index = GeoIndex::load(test_store.clone(), None, &LanceCache::no_cache()) + let index = BkdTree::load(test_store.clone(), None, &LanceCache::no_cache()) .await .unwrap(); @@ -1431,46 +1516,47 @@ mod tests { // Should be exactly 0..99 for i in 0..100 { - assert!(unique_row_ids.contains(&(i as u64)), - "Missing row_id {}", i); + assert!(unique_row_ids.contains(&(i as u64)), "Missing row_id {}", i); } } #[tokio::test] async fn test_geo_intersects_lazy_loading() { use crate::metrics::NoOpMetricsCollector; - + // Test that leaves are loaded lazily (not all at once) let test_store = create_test_store(); - - let params = GeoIndexBuilderParams { + + let params = BkdTreeBuilderParams { max_points_per_leaf: 10, batches_per_file: 5, }; - let mut builder = GeoIndexBuilder::try_new(params, DataType::Struct(Vec::::new().into())) - .unwrap(); + let mut builder = + 
BkdTreeBuilder::try_new(params, DataType::Struct(Vec::::new().into())).unwrap(); // Create 100 points in a grid for x in 0..10 { for y in 0..10 { - builder.points.push((x as f64 * 10.0, y as f64 * 10.0, (x * 10 + y) as u64)); + builder + .points + .push((x as f64 * 10.0, y as f64 * 10.0, (x * 10 + y) as u64)); } } builder.write_index(test_store.as_ref()).await.unwrap(); - let index = GeoIndex::load(test_store.clone(), None, &LanceCache::no_cache()) + let index = BkdTree::load(test_store.clone(), None, &LanceCache::no_cache()) .await .unwrap(); // Query a small region - should only load relevant leaves, not all of them let query = GeoQuery::Intersects(0.0, 0.0, 20.0, 20.0); let metrics = NoOpMetricsCollector {}; - + // Find which leaves would be touched let query_bbox = [0.0, 0.0, 20.0, 20.0]; let intersecting_leaves = index.bkd_tree.find_intersecting_leaves(query_bbox).unwrap(); - + // Verify we're not loading ALL leaves for a small query assert!( intersecting_leaves.len() < index.bkd_tree.num_leaves as usize, @@ -1479,10 +1565,10 @@ mod tests { intersecting_leaves.len(), index.bkd_tree.num_leaves ); - + // Execute the query let result = index.search(&query, &metrics).await.unwrap(); - + // Manually verify correctness: which points SHOULD be in bbox [0, 0, 20, 20]? 
let mut expected_row_ids = std::collections::HashSet::new(); for x in 0..10 { @@ -1495,32 +1581,42 @@ mod tests { } } } - + // Verify results match our manual calculation match result { crate::scalar::SearchResult::Exact(row_ids) => { let actual_count = row_ids.len().unwrap_or(0) as usize; - - assert_eq!(actual_count, expected_row_ids.len(), - "Expected {} points, got {}", - expected_row_ids.len(), actual_count); - + + assert_eq!( + actual_count, + expected_row_ids.len(), + "Expected {} points, got {}", + expected_row_ids.len(), + actual_count + ); + // Verify each returned row_id is in our expected set if let Some(iter) = row_ids.row_ids() { for row_addr in iter { let row_id = u64::from(row_addr); - assert!(expected_row_ids.contains(&row_id), - "Unexpected row_id {} in results", row_id); + assert!( + expected_row_ids.contains(&row_id), + "Unexpected row_id {} in results", + row_id + ); } } - + // Verify we didn't miss any expected points if let Some(iter) = row_ids.row_ids() { - let found_ids: std::collections::HashSet = + let found_ids: std::collections::HashSet = iter.map(|addr| u64::from(addr)).collect(); for expected_id in &expected_row_ids { - assert!(found_ids.contains(expected_id), - "Missing expected row_id {} in results", expected_id); + assert!( + found_ids.contains(expected_id), + "Missing expected row_id {} in results", + expected_id + ); } } } @@ -1532,54 +1628,46 @@ mod tests { async fn test_geo_intersects_invalid_coordinates() { // Test that NaN and infinity coordinates are rejected (like Lucene does) let test_store = create_test_store(); - - let params = GeoIndexBuilderParams { + + let params = BkdTreeBuilderParams { max_points_per_leaf: 10, batches_per_file: 5, }; - + // Test NaN in X coordinate - let mut builder = GeoIndexBuilder::try_new( - params.clone(), - DataType::Struct(Vec::::new().into()), - ) - .unwrap(); + let mut builder = + BkdTreeBuilder::try_new(params.clone(), DataType::Struct(Vec::::new().into())) + .unwrap(); 
builder.points.push((10.0, 10.0, 1)); builder.points.push((f64::NAN, 20.0, 2)); let result = builder.write_index(test_store.as_ref()).await; assert!(result.is_err()); assert!(result.unwrap_err().to_string().contains("NaN")); - + // Test NaN in Y coordinate - let mut builder = GeoIndexBuilder::try_new( - params.clone(), - DataType::Struct(Vec::::new().into()), - ) - .unwrap(); + let mut builder = + BkdTreeBuilder::try_new(params.clone(), DataType::Struct(Vec::::new().into())) + .unwrap(); builder.points.push((10.0, 10.0, 1)); builder.points.push((20.0, f64::NAN, 2)); let result = builder.write_index(test_store.as_ref()).await; assert!(result.is_err()); assert!(result.unwrap_err().to_string().contains("NaN")); - + // Test Infinity in X coordinate - let mut builder = GeoIndexBuilder::try_new( - params.clone(), - DataType::Struct(Vec::::new().into()), - ) - .unwrap(); + let mut builder = + BkdTreeBuilder::try_new(params.clone(), DataType::Struct(Vec::::new().into())) + .unwrap(); builder.points.push((10.0, 10.0, 1)); builder.points.push((f64::INFINITY, 20.0, 2)); let result = builder.write_index(test_store.as_ref()).await; assert!(result.is_err()); assert!(result.unwrap_err().to_string().contains("Infinite")); - + // Test Negative Infinity in Y coordinate - let mut builder = GeoIndexBuilder::try_new( - params.clone(), - DataType::Struct(Vec::::new().into()), - ) - .unwrap(); + let mut builder = + BkdTreeBuilder::try_new(params.clone(), DataType::Struct(Vec::::new().into())) + .unwrap(); builder.points.push((10.0, 10.0, 1)); builder.points.push((20.0, f64::NEG_INFINITY, 2)); let result = builder.write_index(test_store.as_ref()).await; @@ -1590,34 +1678,31 @@ mod tests { #[tokio::test] async fn test_geo_intersects_out_of_bounds_queries() { use crate::metrics::NoOpMetricsCollector; - + // Test queries completely outside the data bounds let test_store = create_test_store(); - - let params = GeoIndexBuilderParams { + + let params = BkdTreeBuilderParams { 
max_points_per_leaf: 10, batches_per_file: 5, }; - - let mut builder = GeoIndexBuilder::try_new( - params, - DataType::Struct(Vec::::new().into()), - ) - .unwrap(); - + + let mut builder = + BkdTreeBuilder::try_new(params, DataType::Struct(Vec::::new().into())).unwrap(); + // Add points in region 0-100, 0-100 for i in 0..50 { builder.points.push((i as f64, i as f64, i as u64)); } - + builder.write_index(test_store.as_ref()).await.unwrap(); - - let index = GeoIndex::load(test_store.clone(), None, &LanceCache::no_cache()) + + let index = BkdTree::load(test_store.clone(), None, &LanceCache::no_cache()) .await .unwrap(); - + let metrics = NoOpMetricsCollector {}; - + // Query completely out of bounds (far away) let query = GeoQuery::Intersects(1000.0, 1000.0, 2000.0, 2000.0); let result = index.search(&query, &metrics).await.unwrap(); @@ -1628,7 +1713,7 @@ mod tests { } _ => panic!("Expected Exact search result"), } - + // Query in negative space (if data is all positive) let query = GeoQuery::Intersects(-500.0, -500.0, -100.0, -100.0); let result = index.search(&query, &metrics).await.unwrap(); @@ -1644,21 +1729,18 @@ mod tests { #[tokio::test] async fn test_geo_intersects_extreme_coordinates() { use crate::metrics::NoOpMetricsCollector; - + // Test with extremely large and small (but valid) coordinates let test_store = create_test_store(); - - let params = GeoIndexBuilderParams { + + let params = BkdTreeBuilderParams { max_points_per_leaf: 10, batches_per_file: 5, }; - - let mut builder = GeoIndexBuilder::try_new( - params, - DataType::Struct(Vec::::new().into()), - ) - .unwrap(); - + + let mut builder = + BkdTreeBuilder::try_new(params, DataType::Struct(Vec::::new().into())).unwrap(); + // Add points with extreme coordinates builder.points.push((1e10, 1e10, 1)); builder.points.push((-1e10, -1e10, 2)); @@ -1666,34 +1748,36 @@ mod tests { builder.points.push((-1e-10, -1e-10, 4)); builder.points.push((f64::MAX / 2.0, f64::MAX / 2.0, 5)); builder.points.push((f64::MIN / 
2.0, f64::MIN / 2.0, 6)); - + builder.write_index(test_store.as_ref()).await.unwrap(); - - let index = GeoIndex::load(test_store.clone(), None, &LanceCache::no_cache()) + + let index = BkdTree::load(test_store.clone(), None, &LanceCache::no_cache()) .await .unwrap(); - + let metrics = NoOpMetricsCollector {}; - + // Query for large positive values let query = GeoQuery::Intersects(1e10 - 1.0, 1e10 - 1.0, 1e10 + 1.0, 1e10 + 1.0); let result = index.search(&query, &metrics).await.unwrap(); match result { crate::scalar::SearchResult::Exact(row_ids) => { - let actual_ids: Vec = row_ids.row_ids() + let actual_ids: Vec = row_ids + .row_ids() .map(|iter| iter.map(|addr| u64::from(addr)).collect()) .unwrap_or_default(); assert!(actual_ids.contains(&1)); } _ => panic!("Expected Exact search result"), } - + // Query for large negative values let query = GeoQuery::Intersects(-1e10 - 1.0, -1e10 - 1.0, -1e10 + 1.0, -1e10 + 1.0); let result = index.search(&query, &metrics).await.unwrap(); match result { crate::scalar::SearchResult::Exact(row_ids) => { - let actual_ids: Vec = row_ids.row_ids() + let actual_ids: Vec = row_ids + .row_ids() .map(|iter| iter.map(|addr| u64::from(addr)).collect()) .unwrap_or_default(); assert!(actual_ids.contains(&2)); @@ -1705,34 +1789,31 @@ mod tests { #[tokio::test] async fn test_geo_intersects_huge_query_bbox() { use crate::metrics::NoOpMetricsCollector; - + // Test query larger than all data bounds let test_store = create_test_store(); - - let params = GeoIndexBuilderParams { + + let params = BkdTreeBuilderParams { max_points_per_leaf: 10, batches_per_file: 5, }; - - let mut builder = GeoIndexBuilder::try_new( - params, - DataType::Struct(Vec::::new().into()), - ) - .unwrap(); - + + let mut builder = + BkdTreeBuilder::try_new(params, DataType::Struct(Vec::::new().into())).unwrap(); + // Add points in small region for i in 0..20 { builder.points.push((i as f64, i as f64, i as u64)); } - + builder.write_index(test_store.as_ref()).await.unwrap(); - - 
let index = GeoIndex::load(test_store.clone(), None, &LanceCache::no_cache()) + + let index = BkdTree::load(test_store.clone(), None, &LanceCache::no_cache()) .await .unwrap(); - + let metrics = NoOpMetricsCollector {}; - + // Query that encompasses ALL data let query = GeoQuery::Intersects(-1000.0, -1000.0, 1000.0, 1000.0); let result = index.search(&query, &metrics).await.unwrap(); @@ -1749,41 +1830,39 @@ mod tests { #[tokio::test] async fn test_geo_intersects_zero_size_query() { use crate::metrics::NoOpMetricsCollector; - + // Test query box with zero width/height (point query) let test_store = create_test_store(); - - let params = GeoIndexBuilderParams { + + let params = BkdTreeBuilderParams { max_points_per_leaf: 10, batches_per_file: 5, }; - - let mut builder = GeoIndexBuilder::try_new( - params, - DataType::Struct(Vec::::new().into()), - ) - .unwrap(); - + + let mut builder = + BkdTreeBuilder::try_new(params, DataType::Struct(Vec::::new().into())).unwrap(); + // Add some points builder.points.push((10.0, 10.0, 1)); builder.points.push((10.0, 10.0, 2)); // Duplicate at same location builder.points.push((20.0, 20.0, 3)); builder.points.push((30.0, 30.0, 4)); - + builder.write_index(test_store.as_ref()).await.unwrap(); - - let index = GeoIndex::load(test_store.clone(), None, &LanceCache::no_cache()) + + let index = BkdTree::load(test_store.clone(), None, &LanceCache::no_cache()) .await .unwrap(); - + let metrics = NoOpMetricsCollector {}; - + // Zero-width query (line) let query = GeoQuery::Intersects(10.0, 10.0, 10.0, 20.0); let result = index.search(&query, &metrics).await.unwrap(); match result { crate::scalar::SearchResult::Exact(row_ids) => { - let actual_ids: Vec = row_ids.row_ids() + let actual_ids: Vec = row_ids + .row_ids() .map(|iter| iter.map(|addr| u64::from(addr)).collect()) .unwrap_or_default(); // Should find points at (10, 10) @@ -1792,13 +1871,14 @@ mod tests { } _ => panic!("Expected Exact search result"), } - + // Zero-height query (line) let 
query = GeoQuery::Intersects(10.0, 10.0, 20.0, 10.0); let result = index.search(&query, &metrics).await.unwrap(); match result { crate::scalar::SearchResult::Exact(row_ids) => { - let actual_ids: Vec = row_ids.row_ids() + let actual_ids: Vec = row_ids + .row_ids() .map(|iter| iter.map(|addr| u64::from(addr)).collect()) .unwrap_or_default(); // Should find points at (10, 10) @@ -1807,13 +1887,14 @@ mod tests { } _ => panic!("Expected Exact search result"), } - + // Point query (zero width and height) let query = GeoQuery::Intersects(20.0, 20.0, 20.0, 20.0); let result = index.search(&query, &metrics).await.unwrap(); match result { crate::scalar::SearchResult::Exact(row_ids) => { - let actual_ids: Vec = row_ids.row_ids() + let actual_ids: Vec = row_ids + .row_ids() .map(|iter| iter.map(|addr| u64::from(addr)).collect()) .unwrap_or_default(); // Should find point at (20, 20) @@ -1826,10 +1907,10 @@ mod tests { #[tokio::test] async fn test_geo_intersects_large_scale_lazy_loading() { - use rand::{Rng, SeedableRng}; use rand::rngs::StdRng; + use rand::{Rng, SeedableRng}; use std::sync::atomic::{AtomicUsize, Ordering}; - + // Custom metrics collector to track I/O operations struct LoadTracker { part_loads: AtomicUsize, @@ -1842,16 +1923,16 @@ mod tests { fn record_index_loads(&self, _num_indices: usize) {} fn record_comparisons(&self, _num_comparisons: usize) {} } - + let test_store = create_test_store(); - - let params = GeoIndexBuilderParams { - max_points_per_leaf: 1000, // Realistic leaf size + + let params = BkdTreeBuilderParams { + max_points_per_leaf: 1000, // Realistic leaf size batches_per_file: 10, }; - let mut builder = GeoIndexBuilder::try_new(params, DataType::Struct(Vec::::new().into())) - .unwrap(); + let mut builder = + BkdTreeBuilder::try_new(params, DataType::Struct(Vec::::new().into())).unwrap(); // Create 1 million random points across a 1000x1000 grid let mut rng = StdRng::seed_from_u64(42); @@ -1862,15 +1943,17 @@ mod tests { } 
builder.write_index(test_store.as_ref()).await.unwrap(); - - let index = GeoIndex::load(test_store.clone(), None, &LanceCache::no_cache()) + + let index = BkdTree::load(test_store.clone(), None, &LanceCache::no_cache()) .await .unwrap(); // Run 100 random queries with load tracking - let metrics = LoadTracker { part_loads: AtomicUsize::new(0) }; + let metrics = LoadTracker { + part_loads: AtomicUsize::new(0), + }; let mut total_leaves_touched = 0; - + for _i in 0..100 { // Random query bbox with varying sizes (1x1 to 50x50 regions in 1000x1000 space) let width = rng.random_range(1.0..50.0); @@ -1879,29 +1962,30 @@ mod tests { let min_y = rng.random_range(0.0..(1000.0 - height)); let max_x = min_x + width; let max_y = min_y + height; - + let query = GeoQuery::Intersects(min_x, min_y, max_x, max_y); - + // Count leaves touched let query_bbox = [min_x, min_y, max_x, max_y]; let intersecting_leaves = index.bkd_tree.find_intersecting_leaves(query_bbox).unwrap(); total_leaves_touched += intersecting_leaves.len(); - + // Execute query let _result = index.search(&query, &metrics).await.unwrap(); } - + let total_io_ops = metrics.part_loads.load(Ordering::Relaxed); - + // CRITICAL: Verify lazy loading is working! // We should NOT load all leaves - only the ones intersecting query bboxes assert!( total_io_ops < index.bkd_tree.num_leaves as usize, "āŒ LAZY LOADING FAILED: Loaded {} leaves but index only has {} leaves! 
\ Should load much fewer than total.", - total_io_ops, index.bkd_tree.num_leaves + total_io_ops, + index.bkd_tree.num_leaves ); - + // Verify lazy loading is effective (< 10% of leaves touched on average per query) let avg_leaves_touched = total_leaves_touched as f64 / 100.0; let total_leaves = index.bkd_tree.num_leaves as f64; @@ -1916,24 +2000,26 @@ mod tests { async fn test_geo_index_points_in_correct_leaves() { // Verify points are in leaves with correct bounding boxes let test_store = create_test_store(); - - let params = GeoIndexBuilderParams { + + let params = BkdTreeBuilderParams { max_points_per_leaf: 10, batches_per_file: 5, }; - let mut builder = GeoIndexBuilder::try_new(params, DataType::Struct(Vec::::new().into())) - .unwrap(); + let mut builder = + BkdTreeBuilder::try_new(params, DataType::Struct(Vec::::new().into())).unwrap(); // Create a grid of points for x in 0..20 { for y in 0..20 { - builder.points.push((x as f64, y as f64, (x * 20 + y) as u64)); + builder + .points + .push((x as f64, y as f64, (x * 20 + y) as u64)); } } builder.write_index(test_store.as_ref()).await.unwrap(); - let index = GeoIndex::load(test_store.clone(), None, &LanceCache::no_cache()) + let index = BkdTree::load(test_store.clone(), None, &LanceCache::no_cache()) .await .unwrap(); @@ -1943,7 +2029,7 @@ mod tests { for leaf in index.bkd_tree.nodes.iter().filter_map(|n| n.as_leaf()) { let leaf_data = index.load_leaf(leaf, &metrics).await.unwrap(); - + let x_array = leaf_data .column(0) .as_primitive::(); @@ -1954,14 +2040,22 @@ mod tests { for i in 0..leaf_data.num_rows() { let x = x_array.value(i); let y = y_array.value(i); - + // Point must be within leaf's bounding box - assert!(x >= leaf.bounds[0] && x <= leaf.bounds[2], - "Point x={} outside leaf bounds [{}, {}]", - x, leaf.bounds[0], leaf.bounds[2]); - assert!(y >= leaf.bounds[1] && y <= leaf.bounds[3], - "Point y={} outside leaf bounds [{}, {}]", - y, leaf.bounds[1], leaf.bounds[3]); + assert!( + x >= leaf.bounds[0] && x 
<= leaf.bounds[2], + "Point x={} outside leaf bounds [{}, {}]", + x, + leaf.bounds[0], + leaf.bounds[2] + ); + assert!( + y >= leaf.bounds[1] && y <= leaf.bounds[3], + "Point y={} outside leaf bounds [{}, {}]", + y, + leaf.bounds[1], + leaf.bounds[3] + ); } } } @@ -1970,16 +2064,16 @@ mod tests { async fn test_geo_index_duplicate_coordinates_different_row_ids() { // Test that duplicate coordinates with different row_ids are all stored correctly use crate::metrics::NoOpMetricsCollector; - + let test_store = create_test_store(); - - let params = GeoIndexBuilderParams { + + let params = BkdTreeBuilderParams { max_points_per_leaf: 10, batches_per_file: 3, }; - let mut builder = GeoIndexBuilder::try_new(params, DataType::Struct(Vec::::new().into())) - .unwrap(); + let mut builder = + BkdTreeBuilder::try_new(params, DataType::Struct(Vec::::new().into())).unwrap(); // Create multiple points at the same coordinates with different row_ids // This simulates real-world scenarios where multiple records have the same location @@ -1987,14 +2081,14 @@ mod tests { // 5 points at location (10.0, 10.0) with different row_ids builder.points.push((10.0, 10.0, i as u64)); } - + // Add some other points at different locations for i in 20..50 { builder.points.push((i as f64, i as f64, i as u64)); } builder.write_index(test_store.as_ref()).await.unwrap(); - let index = GeoIndex::load(test_store.clone(), None, &LanceCache::no_cache()) + let index = BkdTree::load(test_store.clone(), None, &LanceCache::no_cache()) .await .unwrap(); @@ -2002,17 +2096,23 @@ mod tests { let query = GeoQuery::Intersects(9.0, 9.0, 11.0, 11.0); let metrics = NoOpMetricsCollector {}; let result = index.search(&query, &metrics).await.unwrap(); - + match result { crate::scalar::SearchResult::Exact(row_ids) => { // Should find all 20 points at (10.0, 10.0) - assert_eq!(row_ids.len().unwrap_or(0), 20, - "Expected 20 duplicate coordinate entries"); - + assert_eq!( + row_ids.len().unwrap_or(0), + 20, + "Expected 20 
duplicate coordinate entries" + ); + // Verify all row_ids 0..19 are present for i in 0..20 { - assert!(row_ids.contains(i as u64), - "Missing row_id {} for duplicate coordinate", i); + assert!( + row_ids.contains(i as u64), + "Missing row_id {} for duplicate coordinate", + i + ); } } _ => panic!("Expected Exact search result"), @@ -2031,8 +2131,12 @@ mod tests { } } - assert_eq!(all_row_ids.len(), 50, "Should store all 50 points including duplicates"); - + assert_eq!( + all_row_ids.len(), + 50, + "Should store all 50 points including duplicates" + ); + let unique_row_ids: std::collections::HashSet = all_row_ids.iter().copied().collect(); assert_eq!(unique_row_ids.len(), 50, "All 50 row_ids should be unique"); } @@ -2041,16 +2145,16 @@ mod tests { async fn test_geo_index_many_duplicates_at_same_location() { // Test with many duplicate coordinates to ensure they don't cause issues use crate::metrics::NoOpMetricsCollector; - + let test_store = create_test_store(); - - let params = GeoIndexBuilderParams { + + let params = BkdTreeBuilderParams { max_points_per_leaf: 20, batches_per_file: 5, }; - let mut builder = GeoIndexBuilder::try_new(params, DataType::Struct(Vec::::new().into())) - .unwrap(); + let mut builder = + BkdTreeBuilder::try_new(params, DataType::Struct(Vec::::new().into())).unwrap(); // Create 100 points all at the exact same location for i in 0..100 { @@ -2058,28 +2162,35 @@ mod tests { } builder.write_index(test_store.as_ref()).await.unwrap(); - let index = GeoIndex::load(test_store.clone(), None, &LanceCache::no_cache()) + let index = BkdTree::load(test_store.clone(), None, &LanceCache::no_cache()) .await .unwrap(); // Validate tree structure - validate_bkd_tree(&index).await + validate_bkd_tree(&index) + .await .expect("BKD tree validation failed with many duplicates"); // Query should return all 100 points let query = GeoQuery::Intersects(49.0, 49.0, 51.0, 51.0); let metrics = NoOpMetricsCollector {}; let result = index.search(&query, 
&metrics).await.unwrap(); - + match result { crate::scalar::SearchResult::Exact(row_ids) => { - assert_eq!(row_ids.len().unwrap_or(0), 100, - "Expected all 100 duplicate coordinate entries"); - + assert_eq!( + row_ids.len().unwrap_or(0), + 100, + "Expected all 100 duplicate coordinate entries" + ); + // Verify all row_ids are present for i in 0..100 { - assert!(row_ids.contains(i as u64), - "Missing row_id {} in duplicate set", i); + assert!( + row_ids.contains(i as u64), + "Missing row_id {} in duplicate set", + i + ); } } _ => panic!("Expected Exact search result"), @@ -2090,24 +2201,19 @@ mod tests { async fn test_geo_index_duplicates_across_multiple_locations() { // Test duplicates at multiple different locations use crate::metrics::NoOpMetricsCollector; - + let test_store = create_test_store(); - - let params = GeoIndexBuilderParams { + + let params = BkdTreeBuilderParams { max_points_per_leaf: 15, batches_per_file: 3, }; - let mut builder = GeoIndexBuilder::try_new(params, DataType::Struct(Vec::::new().into())) - .unwrap(); + let mut builder = + BkdTreeBuilder::try_new(params, DataType::Struct(Vec::::new().into())).unwrap(); // Create clusters of duplicates at different locations - let locations = vec![ - (10.0, 10.0), - (20.0, 20.0), - (30.0, 30.0), - (40.0, 40.0), - ]; + let locations = vec![(10.0, 10.0), (20.0, 20.0), (30.0, 30.0), (40.0, 40.0)]; let mut row_id = 0u64; for (x, y) in &locations { @@ -2127,12 +2233,13 @@ mod tests { let total_points = row_id; builder.write_index(test_store.as_ref()).await.unwrap(); - let index = GeoIndex::load(test_store.clone(), None, &LanceCache::no_cache()) + let index = BkdTree::load(test_store.clone(), None, &LanceCache::no_cache()) .await .unwrap(); // Validate tree structure - validate_bkd_tree(&index).await + validate_bkd_tree(&index) + .await .expect("BKD tree validation failed with duplicates at multiple locations"); let metrics = NoOpMetricsCollector {}; @@ -2141,18 +2248,27 @@ mod tests { for (i, (x, y)) in 
locations.iter().enumerate() { let query = GeoQuery::Intersects(x - 1.0, y - 1.0, x + 1.0, y + 1.0); let result = index.search(&query, &metrics).await.unwrap(); - + match result { crate::scalar::SearchResult::Exact(row_ids) => { - assert_eq!(row_ids.len().unwrap_or(0), 10, - "Expected 10 duplicates at location {}", i); - + assert_eq!( + row_ids.len().unwrap_or(0), + 10, + "Expected 10 duplicates at location {}", + i + ); + // Verify the correct row_ids for this cluster let expected_start = (i * 10) as u64; let expected_end = expected_start + 10; for expected_id in expected_start..expected_end { - assert!(row_ids.contains(expected_id), - "Missing row_id {} in cluster at ({}, {})", expected_id, x, y); + assert!( + row_ids.contains(expected_id), + "Missing row_id {} in cluster at ({}, {})", + expected_id, + x, + y + ); } } _ => panic!("Expected Exact search result"), @@ -2172,24 +2288,28 @@ mod tests { } } - assert_eq!(all_row_ids.len(), total_points as usize, - "Should store all {} points including duplicates", total_points); + assert_eq!( + all_row_ids.len(), + total_points as usize, + "Should store all {} points including duplicates", + total_points + ); } #[tokio::test] async fn test_geo_index_duplicate_handling_with_leaf_splits() { // Test that duplicates are handled correctly when they cause leaf splits use crate::metrics::NoOpMetricsCollector; - + let test_store = create_test_store(); - - let params = GeoIndexBuilderParams { - max_points_per_leaf: 10, // Small leaf size to force splits + + let params = BkdTreeBuilderParams { + max_points_per_leaf: 10, // Small leaf size to force splits batches_per_file: 3, }; - let mut builder = GeoIndexBuilder::try_new(params, DataType::Struct(Vec::::new().into())) - .unwrap(); + let mut builder = + BkdTreeBuilder::try_new(params, DataType::Struct(Vec::::new().into())).unwrap(); // Create 50 points at the same location - should span multiple leaves for i in 0..50 { @@ -2197,28 +2317,35 @@ mod tests { } 
builder.write_index(test_store.as_ref()).await.unwrap(); - let index = GeoIndex::load(test_store.clone(), None, &LanceCache::no_cache()) + let index = BkdTree::load(test_store.clone(), None, &LanceCache::no_cache()) .await .unwrap(); // Validate tree structure - validate_bkd_tree(&index).await + validate_bkd_tree(&index) + .await .expect("BKD tree validation failed with duplicate-induced splits"); // Query should return all 50 points let query = GeoQuery::Intersects(24.0, 24.0, 26.0, 26.0); let metrics = NoOpMetricsCollector {}; let result = index.search(&query, &metrics).await.unwrap(); - + match result { crate::scalar::SearchResult::Exact(row_ids) => { - assert_eq!(row_ids.len().unwrap_or(0), 50, - "Expected all 50 duplicate entries after leaf splits"); - + assert_eq!( + row_ids.len().unwrap_or(0), + 50, + "Expected all 50 duplicate entries after leaf splits" + ); + // Verify all row_ids are present for i in 0..50 { - assert!(row_ids.contains(i as u64), - "Missing row_id {} after leaf split", i); + assert!( + row_ids.contains(i as u64), + "Missing row_id {} after leaf split", + i + ); } } _ => panic!("Expected Exact search result"), @@ -2229,18 +2356,18 @@ mod tests { async fn test_geo_index_mixed_duplicates_and_unique_points() { // Realistic test: mix of unique points and duplicates use crate::metrics::NoOpMetricsCollector; - use rand::{Rng, SeedableRng}; use rand::rngs::StdRng; - + use rand::{Rng, SeedableRng}; + let test_store = create_test_store(); - - let params = GeoIndexBuilderParams { + + let params = BkdTreeBuilderParams { max_points_per_leaf: 20, batches_per_file: 5, }; - let mut builder = GeoIndexBuilder::try_new(params, DataType::Struct(Vec::::new().into())) - .unwrap(); + let mut builder = + BkdTreeBuilder::try_new(params, DataType::Struct(Vec::::new().into())).unwrap(); let mut rng = StdRng::seed_from_u64(123); let mut row_id = 0u64; @@ -2270,12 +2397,13 @@ mod tests { let total_points = 170; builder.write_index(test_store.as_ref()).await.unwrap(); 
- let index = GeoIndex::load(test_store.clone(), None, &LanceCache::no_cache()) + let index = BkdTree::load(test_store.clone(), None, &LanceCache::no_cache()) .await .unwrap(); // Validate tree structure - validate_bkd_tree(&index).await + validate_bkd_tree(&index) + .await .expect("BKD tree validation failed with mixed duplicates and unique points"); let metrics = NoOpMetricsCollector {}; @@ -2283,16 +2411,22 @@ mod tests { // Query the duplicate cluster let query = GeoQuery::Intersects(49.0, 49.0, 51.0, 51.0); let result = index.search(&query, &metrics).await.unwrap(); - + match result { crate::scalar::SearchResult::Exact(row_ids) => { - assert_eq!(row_ids.len().unwrap_or(0), 20, - "Expected 20 duplicate entries at (50, 50)"); - + assert_eq!( + row_ids.len().unwrap_or(0), + 20, + "Expected 20 duplicate entries at (50, 50)" + ); + // Verify the duplicate row_ids (100..119) for expected_id in 100..120 { - assert!(row_ids.contains(expected_id as u64), - "Missing row_id {} in duplicate cluster", expected_id); + assert!( + row_ids.contains(expected_id as u64), + "Missing row_id {} in duplicate cluster", + expected_id + ); } } _ => panic!("Expected Exact search result"), @@ -2311,13 +2445,19 @@ mod tests { } } - assert_eq!(all_row_ids.len(), total_points, - "Should store all {} points", total_points); - + assert_eq!( + all_row_ids.len(), + total_points, + "Should store all {} points", + total_points + ); + // Verify all row_ids are unique (no double-counting) let unique_row_ids: std::collections::HashSet = all_row_ids.iter().copied().collect(); - assert_eq!(unique_row_ids.len(), total_points, - "All row_ids should be unique"); + assert_eq!( + unique_row_ids.len(), + total_points, + "All row_ids should be unique" + ); } } - diff --git a/rust/lance-index/src/scalar/geo/mod.rs b/rust/lance-index/src/scalar/geo/mod.rs index ddf26b02724..a998c156784 100644 --- a/rust/lance-index/src/scalar/geo/mod.rs +++ b/rust/lance-index/src/scalar/geo/mod.rs @@ -4,12 +4,11 @@ //! 
Geographic indexing module //! //! This module contains implementations for spatial/geographic indexing: -//! - BKD Tree: Block K-Dimensional tree for efficient spatial partitioning -//! - Geo Index: Geographic index built on top of BKD trees for GeoArrow data +//! - BKD: Block K-Dimensional tree for efficient spatial partitioning (core data structure) +//! - BkdTree: Geographic index built on top of BKD trees for GeoArrow data pub mod bkd; -pub mod geoindex; +pub mod bkdtree; pub use bkd::*; -pub use geoindex::*; - +pub use bkdtree::*; diff --git a/rust/lance-index/src/scalar/registry.rs b/rust/lance-index/src/scalar/registry.rs index 016e92ee438..b083f7fb6f9 100644 --- a/rust/lance-index/src/scalar/registry.rs +++ b/rust/lance-index/src/scalar/registry.rs @@ -15,7 +15,7 @@ use crate::{ frag_reuse::FragReuseIndex, scalar::{ bitmap::BitmapIndexPlugin, bloomfilter::BloomFilterIndexPlugin, btree::BTreeIndexPlugin, - expression::ScalarQueryParser, geo::GeoIndexPlugin, inverted::InvertedIndexPlugin, + expression::ScalarQueryParser, geo::BkdTreePlugin, inverted::InvertedIndexPlugin, json::JsonIndexPlugin, label_list::LabelListIndexPlugin, ngram::NGramIndexPlugin, zonemap::ZoneMapIndexPlugin, CreatedIndex, IndexStore, ScalarIndex, }, @@ -201,7 +201,7 @@ impl ScalarIndexPluginRegistry { registry.add_plugin::(); registry.add_plugin::(); registry.add_plugin::(); - registry.add_plugin::(); + registry.add_plugin::(); registry.add_plugin::(); registry.add_plugin::(); registry.add_plugin::(); diff --git a/rust/lance/src/index/create.rs b/rust/lance/src/index/create.rs index c9be843dd5e..06d94e201d8 100644 --- a/rust/lance/src/index/create.rs +++ b/rust/lance/src/index/create.rs @@ -152,7 +152,7 @@ impl<'a> CreateIndexBuilder<'a> { | IndexType::ZoneMap | IndexType::BloomFilter | IndexType::LabelList - | IndexType::Geo, + | IndexType::BkdTree, LANCE_SCALAR_INDEX, ) => { let base_params = ScalarIndexParams::for_builtin(self.index_type.try_into()?); diff --git 
a/test_geoarrow_geo_index.py b/test_geoarrow_geo_index.py deleted file mode 100644 index 67538437a22..00000000000 --- a/test_geoarrow_geo_index.py +++ /dev/null @@ -1,301 +0,0 @@ -#!/usr/bin/env python3 -""" -Test script for GeoArrow Point geo index functionality in Lance. - -This script tests: -1. Creating GeoArrow Point data -2. Writing to Lance dataset -3. Creating a geo index on GeoArrow Point column -4. Querying with spatial filters -5. Verifying the geo index is used -""" - - -import numpy as np -import pyarrow as pa -import lance -import os -import shutil -import logging -from geoarrow.pyarrow import point - -# Enable Rust logging -os.environ['RUST_LOG'] = 'lance_index=debug' -logging.basicConfig(level=logging.DEBUG) - - -def main(): - print("šŸŒ Testing GeoArrow Point Geo Index in Lance") - print("=" * 50) - - - # Clean slate - dataset_path = "/Users/jay.narale/work/Uber/geo_index_test" - if os.path.exists(dataset_path): - shutil.rmtree(dataset_path) - print(f"āœ… Cleaned up existing dataset: {dataset_path}") - - - # Step 1: Create GeoArrow Point data with enough points to test tree structure - print("\nšŸ”µ Step 1: Creating GeoArrow Point data (5000+ points)") - - # Generate random points across the US - # US bounding box approximately: lng [-125, -65], lat [25, 50] - np.random.seed(42) # For reproducibility - num_points = 5000 - - lng_vals = np.random.uniform(-125, -65, num_points) - lat_vals = np.random.uniform(25, 50, num_points) - - # Add some known cities at the beginning for testing - known_cities = [ - {"id": 1, "city": "San Francisco", "lng": -122.4194, "lat": 37.7749, "population": 883305}, - {"id": 2, "city": "Los Angeles", "lng": -118.2437, "lat": 34.0522, "population": 3898747}, - {"id": 3, "city": "New York", "lng": -74.0060, "lat": 40.7128, "population": 8336817}, - {"id": 4, "city": "Chicago", "lng": -87.6298, "lat": 41.8781, "population": 2746388}, - {"id": 5, "city": "Houston", "lng": -95.3698, "lat": 29.7604, "population": 2304580}, - ] 
- - # Replace first 5 points with known cities - for i, city in enumerate(known_cities): - lng_vals[i] = city["lng"] - lat_vals[i] = city["lat"] - - start_location = point().from_geobuffers(None, lng_vals, lat_vals) - - # Create IDs and city names - ids = list(range(1, num_points + 1)) - cities = [known_cities[i]["city"] if i < len(known_cities) else f"Point_{i+1}" - for i in range(num_points)] - populations = [known_cities[i]["population"] if i < len(known_cities) else np.random.randint(10000, 1000000) - for i in range(num_points)] - - table = pa.table({ - "id": ids, - "city": cities, - "start_location": start_location, - "population": populations - }) - - - print(f"āœ… Created GeoArrow Point data with {num_points} points") - print("šŸ“Š Table schema:") - print(table.schema) - print(f"šŸ“ Point column type: {table.schema.field('start_location').type}") - print(f"šŸ“ Point column metadata: {table.schema.field('start_location').metadata}") - print(f"šŸ“ Known cities: {[c['city'] for c in known_cities]}") - - - # Step 2: Write to Lance dataset - print("\nšŸ”µ Step 2: Writing to Lance dataset") - try: - geo_ds = lance.write_dataset(table, dataset_path) - print("āœ… Successfully wrote GeoArrow data to Lance dataset") - - - # Verify data was written correctly - loaded_table = geo_ds.to_table() - print(f"šŸ“Š Dataset has {len(loaded_table)} rows") - print("šŸ“Š Dataset schema:") - print(loaded_table.schema) - - - except Exception as e: - print(f"āŒ Failed to write dataset: {e}") - return - - - # Step 3: Create geo index - print("\nšŸ”µ Step 3: Creating geo index on GeoArrow Point column") - try: - geo_ds.create_scalar_index(column="start_location", index_type="GEO") - print("āœ… Successfully created geo index") - - - # Check what indexes exist - indexes = geo_ds.list_indices() - print("šŸ“Š Available indexes:") - for idx in indexes: - print(f" - {idx}") - - - except Exception as e: - print(f"āŒ Failed to create geo index: {e}") - return - - - # Step 4: Test 
st_intersects spatial query with broad bbox (both cities) - print("\nšŸ”µ Step 4: Testing st_intersects spatial query with broad bbox (both cities)") - - - - - # First, run EXPLAIN ANALYZE to see the execution plan - explain_sql = """ - EXPLAIN ANALYZE SELECT id, city, population - FROM dataset - WHERE st_intersects(start_location, bbox(-125, 30, -115, 45)) - """ - - - print("\nšŸ“‹ Running EXPLAIN ANALYZE...") - explain_query = geo_ds.sql(explain_sql).build() - explain_result = explain_query.to_batch_records() - - - if explain_result: - explain_table = pa.Table.from_batches(explain_result) - print("šŸ” EXPLAIN ANALYZE Result:") - print(f"Schema: {explain_table.schema}") - print(f"Rows: {len(explain_table)}") - - - # Print the execution plan - for i in range(len(explain_table)): - for j, column in enumerate(explain_table.columns): - col_name = explain_table.schema.field(j).name - value = column.to_pylist()[i] - print(f"šŸ“Š {col_name}: {value}") - - - # Check if geo index was used - if len(explain_table) > 0: - # Column 1 contains the actual plan, column 0 is just the plan type - plan_text = str(explain_table.column(1).to_pylist()[0]) - if "ScalarIndexQuery" in plan_text or "start_location_idx" in plan_text: - print("āœ… šŸŒ GEO INDEX WAS USED!") - if "start_location_idx" in plan_text: - print("āœ… šŸŒ Found geo index reference: start_location_idx") - if "ST_Intersects" in plan_text: - print("āœ… šŸŒ Spatial query detected: ST_Intersects") - # Extract performance metrics - import re - if "output_rows=" in plan_text: - rows_match = re.search(r'output_rows=(\d+)', plan_text) - if rows_match: - print(f"āœ… šŸŒ Index returned {rows_match.group(1)} rows") - if "search_time=" in plan_text: - time_match = re.search(r'search_time=([^,\]]+)', plan_text) - if time_match: - print(f"āœ… šŸŒ Index search time: {time_match.group(1)}") - else: - print("āš ļø Geo index was not detected in execution plan") - print(f"šŸ“‹ Full plan: {plan_text}") - - - # Now run the actual 
query and get complete results - print("\nšŸ“‹ Running actual query...") - actual_sql = """ - SELECT id, city, population - FROM dataset - WHERE st_intersects(start_location, bbox(-125, 30, -115, 45)) - """ - query = geo_ds.sql(actual_sql).build() - result = query.to_batch_records() - - - if result: - table = pa.Table.from_batches(result) - print("āœ… Query Results:") - print(f"šŸ“Š Schema: {table.schema}") - print(f"šŸ“Š Number of rows: {len(table)}") - - - # Print first few results - max_rows_to_print = min(10, len(table)) - for i in range(max_rows_to_print): - row_data = {} - for j, column in enumerate(table.columns): - col_name = table.schema.field(j).name - value = column.to_pylist()[i] - row_data[col_name] = value - print(f"šŸ“ Row {i}: {row_data}") - if len(table) > max_rows_to_print: - print(f"... and {len(table) - max_rows_to_print} more rows") - - - cities = table.column('city').to_pylist() - print(f"\nāœ… Found {len(cities)} results with broad bbox") - print(f"šŸ“Š Known cities in results: {[c for c in cities if c in ['San Francisco', 'Los Angeles', 'New York', 'Chicago', 'Houston']]}") - - # With 5000 random points and a broad western US bbox, we should get hundreds/thousands of results - assert len(cities) > 100, f"Expected many results (>100) from broad bbox, got {len(cities)}" - assert 'San Francisco' in cities, "Expected San Francisco in results" - assert 'Los Angeles' in cities, "Expected Los Angeles in results" - print(f"āœ… Verified SF and LA are in the {len(cities)} results") - else: - print("āš ļø No results returned") - - - # Step 4b: Test with tight bbox (only San Francisco) - print("\nšŸ”µ Step 4b: Testing st_intersects with tight bbox (only San Francisco)") - # SF is at (-122.4194, 37.7749), so use a tight box around it - tight_sql = """ - SELECT id, city, population - FROM dataset - WHERE st_intersects(start_location, bbox(-123, 37, -122, 38)) - """ - tight_query = geo_ds.sql(tight_sql).build() - tight_result = 
tight_query.to_batch_records() - - if tight_result: - tight_table = pa.Table.from_batches(tight_result) - print("āœ… Query Results:") - print(f"šŸ“Š Number of rows: {len(tight_table)}") - - # Print first few results - max_rows_to_print = min(10, len(tight_table)) - for i in range(max_rows_to_print): - row_data = {} - for j, column in enumerate(tight_table.columns): - col_name = tight_table.schema.field(j).name - value = column.to_pylist()[i] - row_data[col_name] = value - print(f"šŸ“ Row {i}: {row_data}") - if len(tight_table) > max_rows_to_print: - print(f"... and {len(tight_table) - max_rows_to_print} more rows") - - cities = tight_table.column('city').to_pylist() - print(f"\nāœ… Found {len(cities)} results with tight bbox") - known_cities_found = [c for c in cities if c in ['San Francisco', 'Los Angeles', 'New York', 'Chicago', 'Houston']] - print(f"šŸ“Š Known cities in results: {known_cities_found}") - - # The tight bbox around SF should include SF, and might include some random points - assert 'San Francisco' in cities, "Expected San Francisco in results" - assert len(known_cities_found) == 1 and known_cities_found[0] == 'San Francisco', \ - f"Expected only San Francisco from known cities, got {known_cities_found}" - print(f"āœ… Verified only SF is in the known cities, total results: {len(cities)}") - else: - print("āš ļø No results returned") - - - - - - - # Step 5: Check index files - print("\nšŸ”µ Step 5: Verifying index files") - try: - import glob - index_files = glob.glob(f"{dataset_path}/_indices/*") - print(f"šŸ“‚ Index directories: {len(index_files)}") - - - for idx_dir in index_files: - files = glob.glob(f"{idx_dir}/*") - print(f"šŸ“‚ Files in {idx_dir}:") - for f in files: - file_size = os.path.getsize(f) - print(f" - {os.path.basename(f)} ({file_size} bytes)") - - - except Exception as e: - print(f"āŒ Failed to check index files: {e}") - - - print("\nšŸŽ‰ Test completed!") - print("=" * 50) - - -if __name__ == "__main__": - main() \ No newline 
at end of file