diff --git a/protos/index.proto b/protos/index.proto index c6d6370f906..9e8f6ca9b7d 100644 --- a/protos/index.proto +++ b/protos/index.proto @@ -188,4 +188,6 @@ message JsonIndexDetails { string path = 1; google.protobuf.Any target_details = 2; } -message BloomFilterIndexDetails {} \ No newline at end of file +message BloomFilterIndexDetails {} + +message BkdTreeIndexDetails {} \ No newline at end of file diff --git a/python/python/lance/dataset.py b/python/python/lance/dataset.py index e0b7f638846..f39b48ac7e0 100644 --- a/python/python/lance/dataset.py +++ b/python/python/lance/dataset.py @@ -2345,6 +2345,7 @@ def create_scalar_index( "LABEL_LIST", "INVERTED", "BLOOMFILTER", + "BKDTREE", ]: raise NotImplementedError( ( @@ -2392,11 +2393,27 @@ def create_scalar_index( f"INVERTED index column {column} must be string, large string" " or list of strings, but got {value_type}" ) - - if pa.types.is_duration(field_type): - raise TypeError( - f"Scalar index column {column} cannot currently be a duration" - ) + elif index_type == "BKDTREE": + # Accept struct for GeoArrow point data + if pa.types.is_struct(field_type): + field_names = [field.name for field in field_type] + if set(field_names) == {"x", "y"}: + # This is geoarrow point data - allow it + pass + else: + raise TypeError( + f"BKDTREE index column {column} must be a struct with x,y fields for point data. " + f"Got struct with fields: {field_names}" + ) + else: + raise TypeError( + f"BKDTREE index column {column} must be a struct type. 
" + f"Got field type: {field_type}" + ) + if pa.types.is_duration(field_type): + raise TypeError( + f"Scalar index column {column} cannot currently be a duration" + ) elif isinstance(index_type, IndexConfig): config = json.dumps(index_type.parameters) kwargs["config"] = indices.IndexConfig(index_type.index_type, config) diff --git a/python/python/tests/test_bkdtree_index.py b/python/python/tests/test_bkdtree_index.py new file mode 100644 index 00000000000..65e336bf820 --- /dev/null +++ b/python/python/tests/test_bkdtree_index.py @@ -0,0 +1,288 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright The Lance Authors + +""" +Tests for BKD Tree spatial index on GeoArrow Point data. + +This module tests: +1. Creating GeoArrow Point data +2. Writing to Lance dataset +3. Creating a BKD tree index on GeoArrow Point column +4. Querying with spatial filters +5. Verifying the index is used in query execution +""" + +import os + +import numpy as np +import pyarrow as pa +import pytest + +import lance + +geoarrow = pytest.importorskip("geoarrow.pyarrow") + + +@pytest.fixture +def geoarrow_data(): + """Create GeoArrow Point test data with known cities and random points.""" + np.random.seed(42) + num_points = 5000 + + # Generate random points across the US + # US bounding box: lng [-125, -65], lat [25, 50] + lng_vals = np.random.uniform(-125, -65, num_points) + lat_vals = np.random.uniform(25, 50, num_points) + + # Add known cities at the beginning for testing + known_cities = [ + { + "id": 1, + "city": "San Francisco", + "lng": -122.4194, + "lat": 37.7749, + "population": 883305, + }, + { + "id": 2, + "city": "Los Angeles", + "lng": -118.2437, + "lat": 34.0522, + "population": 3898747, + }, + { + "id": 3, + "city": "New York", + "lng": -74.0060, + "lat": 40.7128, + "population": 8336817, + }, + { + "id": 4, + "city": "Chicago", + "lng": -87.6298, + "lat": 41.8781, + "population": 2746388, + }, + { + "id": 5, + "city": "Houston", + "lng": -95.3698, + 
"lat": 29.7604, + "population": 2304580, + }, + ] + + # Replace first 5 points with known cities + for i, city in enumerate(known_cities): + lng_vals[i] = city["lng"] + lat_vals[i] = city["lat"] + + start_location = geoarrow.point().from_geobuffers(None, lng_vals, lat_vals) + + # Create IDs and city names + ids = list(range(1, num_points + 1)) + cities = [ + known_cities[i]["city"] if i < len(known_cities) else f"Point_{i + 1}" + for i in range(num_points) + ] + populations = [ + known_cities[i]["population"] + if i < len(known_cities) + else np.random.randint(10000, 1000000) + for i in range(num_points) + ] + + table = pa.table( + { + "id": ids, + "city": cities, + "start_location": start_location, + "population": populations, + } + ) + + return table, known_cities + + +def test_write_geoarrow_to_lance(tmp_path, geoarrow_data): + """Test writing GeoArrow Point data to Lance dataset.""" + table, _ = geoarrow_data + dataset_path = tmp_path / "geo_dataset" + + ds = lance.write_dataset(table, dataset_path) + + # Verify data was written correctly + loaded_table = ds.to_table() + assert len(loaded_table) == len(table) + assert loaded_table.schema.equals(table.schema) + + +def test_create_bkdtree_index(tmp_path, geoarrow_data): + """Test creating a BKD tree index on GeoArrow Point column.""" + table, _ = geoarrow_data + dataset_path = tmp_path / "geo_dataset" + + ds = lance.write_dataset(table, dataset_path) + + # Create BKD tree index + ds.create_scalar_index(column="start_location", index_type="BKDTREE") + + # Verify index was created + indexes = ds.list_indices() + assert len(indexes) > 0 + + # Check that index files exist + index_dir = dataset_path / "_indices" + assert index_dir.exists() + index_files = list(index_dir.rglob("*")) + assert len(index_files) > 0 + + +def test_spatial_query_broad_bbox(tmp_path, geoarrow_data): + """Test spatial query with broad bounding box covering multiple cities.""" + table, known_cities = geoarrow_data + dataset_path = tmp_path / 
"geo_dataset" + + ds = lance.write_dataset(table, dataset_path) + ds.create_scalar_index(column="start_location", index_type="BKDTREE") + + # Query with broad bbox covering western US + # Should include San Francisco and Los Angeles + sql = """ + SELECT id, city, population + FROM dataset + WHERE st_intersects(start_location, bbox(-125, 30, -115, 45)) + """ + + query = ds.sql(sql).build() + result = query.to_batch_records() + + assert result is not None + result_table = pa.Table.from_batches(result) + + # Should get many results with random points + assert len(result_table) > 100 + + # Should include known cities in the bbox + cities = result_table.column("city").to_pylist() + assert "San Francisco" in cities + assert "Los Angeles" in cities + + +def test_spatial_query_tight_bbox(tmp_path, geoarrow_data): + """Test spatial query with tight bounding box around single city.""" + table, known_cities = geoarrow_data + dataset_path = tmp_path / "geo_dataset" + + ds = lance.write_dataset(table, dataset_path) + ds.create_scalar_index(column="start_location", index_type="BKDTREE") + + # Query with tight bbox around San Francisco only + # SF is at (-122.4194, 37.7749) + sql = """ + SELECT id, city, population + FROM dataset + WHERE st_intersects(start_location, bbox(-123, 37, -122, 38)) + """ + + query = ds.sql(sql).build() + result = query.to_batch_records() + + assert result is not None + result_table = pa.Table.from_batches(result) + + # Should include San Francisco + cities = result_table.column("city").to_pylist() + assert "San Francisco" in cities + + # From known cities, should only include San Francisco + known_cities_in_result = [ + c + for c in cities + if c in ["San Francisco", "Los Angeles", "New York", "Chicago", "Houston"] + ] + assert known_cities_in_result == ["San Francisco"] + + +def test_spatial_query_uses_index(tmp_path, geoarrow_data): + """Test that spatial queries use the BKD tree index via EXPLAIN ANALYZE.""" + table, _ = geoarrow_data + dataset_path 
= tmp_path / "geo_dataset" + + ds = lance.write_dataset(table, dataset_path) + ds.create_scalar_index(column="start_location", index_type="BKDTREE") + + # Run EXPLAIN ANALYZE to verify index usage + explain_sql = """ + EXPLAIN ANALYZE SELECT id, city, population + FROM dataset + WHERE st_intersects(start_location, bbox(-125, 30, -115, 45)) + """ + + query = ds.sql(explain_sql).build() + result = query.to_batch_records() + + assert result is not None + explain_table = pa.Table.from_batches(result) + assert len(explain_table) > 0 + + # Check if index was used in the execution plan + # The plan is in the second column + plan_text = str(explain_table.column(1).to_pylist()[0]) + + # Look for evidence of index usage + # (The exact string may vary, adjust based on actual output) + assert "ScalarIndexQuery" in plan_text or "start_location_idx" in plan_text, ( + f"Index not detected in execution plan: {plan_text}" + ) + + +def test_spatial_query_empty_result(tmp_path, geoarrow_data): + """Test spatial query with bbox that doesn't intersect any points.""" + table, _ = geoarrow_data + dataset_path = tmp_path / "geo_dataset" + + ds = lance.write_dataset(table, dataset_path) + ds.create_scalar_index(column="start_location", index_type="BKDTREE") + + # Query with bbox outside the US (e.g., over the Pacific Ocean) + sql = """ + SELECT id, city, population + FROM dataset + WHERE st_intersects(start_location, bbox(-180, -10, -175, -5)) + """ + + query = ds.sql(sql).build() + result = query.to_batch_records() + + # Should return empty result or very small result + if result: + result_table = pa.Table.from_batches(result) + assert len(result_table) == 0 + + +def test_index_file_structure(tmp_path, geoarrow_data): + """Test that BKD tree index creates expected file structure.""" + table, _ = geoarrow_data + dataset_path = tmp_path / "geo_dataset" + + ds = lance.write_dataset(table, dataset_path) + ds.create_scalar_index(column="start_location", index_type="BKDTREE") + + index_dir = 
dataset_path / "_indices" + assert index_dir.exists() + + # Check for index subdirectories + index_subdirs = [d for d in index_dir.iterdir() if d.is_dir()] + assert len(index_subdirs) > 0 + + # Check that index files exist and have content + for subdir in index_subdirs: + files = list(subdir.glob("*")) + assert len(files) > 0 + + # Verify files have content + for f in files: + if f.is_file(): + assert f.stat().st_size > 0 diff --git a/python/python/tests/test_optimize.py b/python/python/tests/test_optimize.py index 8bf12db91ae..e334111280e 100644 --- a/python/python/tests/test_optimize.py +++ b/python/python/tests/test_optimize.py @@ -1,5 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright The Lance Authors +# SPDX-FileCopyrightText: Copyright The Lance Authors import pickle import random import re diff --git a/python/src/dataset.rs b/python/src/dataset.rs index 8c8f086ff9b..b4e3aa9816b 100644 --- a/python/src/dataset.rs +++ b/python/src/dataset.rs @@ -1601,6 +1601,7 @@ impl Dataset { "BLOOMFILTER" => IndexType::BloomFilter, "LABEL_LIST" => IndexType::LabelList, "INVERTED" | "FTS" => IndexType::Inverted, + "BKDTREE" => IndexType::BkdTree, "IVF_FLAT" | "IVF_PQ" | "IVF_SQ" | "IVF_RQ" | "IVF_HNSW_FLAT" | "IVF_HNSW_PQ" | "IVF_HNSW_SQ" => IndexType::Vector, _ => { @@ -1702,6 +1703,10 @@ impl Dataset { } Box::new(params) } + "BKDTREE" => Box::new(ScalarIndexParams { + index_type: "bkdtree".to_string(), + params: None, + }), _ => { let column_type = match self.ds.schema().field(columns[0]) { Some(f) => f.data_type().clone(), diff --git a/rust/lance-datafusion/src/udf.rs b/rust/lance-datafusion/src/udf.rs index 24366077c66..1449cf32ef7 100644 --- a/rust/lance-datafusion/src/udf.rs +++ b/rust/lance-datafusion/src/udf.rs @@ -26,6 +26,131 @@ pub fn register_functions(ctx: &SessionContext) { ctx.register_udf(json::json_get_bool_udf()); ctx.register_udf(json::json_array_contains_udf()); ctx.register_udf(json::json_array_length_udf()); + + // 
GEO functions + ctx.register_udf(ST_INTERSECTS_UDF.clone()); + ctx.register_udf(ST_WITHIN_UDF.clone()); + ctx.register_udf(BBOX_UDF.clone()); +} + +static ST_WITHIN_UDF: LazyLock = LazyLock::new(st_within); +static BBOX_UDF: LazyLock = LazyLock::new(bbox); +static ST_INTERSECTS_UDF: LazyLock = LazyLock::new(st_intersects); + +fn st_intersects() -> ScalarUDF { + let function = Arc::new(make_scalar_function( + |_args: &[ArrayRef]| { + // Throw an error indicating that a spatial index is required + Err(datafusion::error::DataFusionError::Execution( + "st_intersects requires a spatial index. Please create a spatial index on the geometry column using dataset.create_scalar_index(column='your_column', index_type='BKDTREE') before running spatial queries.".to_string(), + )) + }, + vec![], + )); + + create_udf( + "st_intersects", + vec![ + DataType::Struct( + vec![ + Arc::new(arrow_schema::Field::new("x", DataType::Float64, false)), + Arc::new(arrow_schema::Field::new("y", DataType::Float64, false)), + ] + .into(), + ), + DataType::Struct( + vec![ + Arc::new(arrow_schema::Field::new("xmin", DataType::Float64, false)), + Arc::new(arrow_schema::Field::new("ymin", DataType::Float64, false)), + Arc::new(arrow_schema::Field::new("xmax", DataType::Float64, false)), + Arc::new(arrow_schema::Field::new("ymax", DataType::Float64, false)), + ] + .into(), + ), + ], // GeoArrow Point struct, GeoArrow Box struct + DataType::Boolean, + Volatility::Immutable, + function, + ) +} + +fn st_within() -> ScalarUDF { + let function = Arc::new(make_scalar_function( + |_args: &[ArrayRef]| { + // Throw an error indicating that a spatial index is required + Err(datafusion::error::DataFusionError::Execution( + "st_within requires a spatial index. 
Please create a spatial index on the geometry column using dataset.create_scalar_index(column='your_column', index_type='BKDTREE') before running spatial queries.".to_string(), + )) + }, + vec![], + )); + + create_udf( + "st_within", + vec![ + DataType::Struct( + vec![ + Arc::new(arrow_schema::Field::new("x", DataType::Float64, false)), + Arc::new(arrow_schema::Field::new("y", DataType::Float64, false)), + ] + .into(), + ), + DataType::Struct( + vec![ + Arc::new(arrow_schema::Field::new("xmin", DataType::Float64, false)), + Arc::new(arrow_schema::Field::new("ymin", DataType::Float64, false)), + Arc::new(arrow_schema::Field::new("xmax", DataType::Float64, false)), + Arc::new(arrow_schema::Field::new("ymax", DataType::Float64, false)), + ] + .into(), + ), + ], // GeoArrow Point struct, GeoArrow Box struct + DataType::Boolean, + Volatility::Immutable, + function, + ) +} + +/// BBOX function that creates a bounding box from four numeric arguments. +/// This function is used internally by spatial queries and doesn't perform actual computation. +/// It's intercepted by Lance's geo query parser for index optimization. +/// +/// Usage in SQL: +/// ```sql +/// SELECT * FROM table WHERE ST_Intersects(geometry_column, BBOX(-180, -90, 180, 90)) +/// ``` +fn bbox() -> ScalarUDF { + let function = Arc::new(make_scalar_function( + |_args: &[ArrayRef]| { + // This UDF should never be called because BBOX functions are intercepted by the query parser + // If this executes, it means no spatial index exists + Err(datafusion::error::DataFusionError::Execution( + "BBOX function requires a spatial index. 
Please create a spatial index on the geometry column using dataset.create_scalar_index(column='your_column', index_type='BKDTREE') before running spatial queries.".to_string(), + )) + }, + vec![], + )); + + create_udf( + "bbox", + vec![ + DataType::Float64, + DataType::Float64, + DataType::Float64, + DataType::Float64, + ], // min_x, min_y, max_x, max_y + DataType::Struct( + vec![ + Arc::new(arrow_schema::Field::new("xmin", DataType::Float64, false)), + Arc::new(arrow_schema::Field::new("ymin", DataType::Float64, false)), + Arc::new(arrow_schema::Field::new("xmax", DataType::Float64, false)), + Arc::new(arrow_schema::Field::new("ymax", DataType::Float64, false)), + ] + .into(), + ), // Returns a GeoArrow Box struct + Volatility::Immutable, + function, + ) } /// This method checks whether a string contains all specified tokens. The tokens are separated by diff --git a/rust/lance-index/Cargo.toml b/rust/lance-index/Cargo.toml index 605e40aa792..6a2220fdbef 100644 --- a/rust/lance-index/Cargo.toml +++ b/rust/lance-index/Cargo.toml @@ -98,9 +98,6 @@ pprof.workspace = true # docs.rs uses an older version of Ubuntu that does not have the necessary protoc version features = ["protoc"] -[profile.bench] -debug = true - [[bench]] name = "find_partitions" harness = false @@ -149,5 +146,9 @@ harness = false name = "rq" harness = false +[[bench]] +name = "geoindex" +harness = false + [lints] workspace = true diff --git a/rust/lance-index/benches/geoindex.rs b/rust/lance-index/benches/geoindex.rs new file mode 100644 index 00000000000..a311fab4736 --- /dev/null +++ b/rust/lance-index/benches/geoindex.rs @@ -0,0 +1,273 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright The Lance Authors + +//! 
Benchmarks for geo index vs brute force scanning + +use arrow_schema::{DataType, Field}; +use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion}; +use lance_core::cache::LanceCache; +use lance_core::utils::tempfile::TempObjDir; +use lance_io::object_store::ObjectStore; +use rand::rngs::StdRng; +use rand::{Rng, SeedableRng}; +use std::sync::Arc; + +use lance_index::metrics::{MetricsCollector, NoOpMetricsCollector}; +use lance_index::scalar::geo::geoindex::{GeoIndex, GeoIndexBuilder, GeoIndexBuilderParams}; +use lance_index::scalar::lance_format::LanceIndexStore; +use lance_index::scalar::{GeoQuery, ScalarIndex}; +use std::sync::atomic::{AtomicUsize, Ordering}; + +struct LeafCounter { + leaves_visited: AtomicUsize, +} + +impl LeafCounter { + fn new() -> Self { + Self { + leaves_visited: AtomicUsize::new(0), + } + } + + fn get_count(&self) -> usize { + self.leaves_visited.load(Ordering::Relaxed) + } + + fn reset(&self) { + self.leaves_visited.store(0, Ordering::Relaxed); + } +} + +impl MetricsCollector for LeafCounter { + fn record_parts_loaded(&self, num_parts: usize) { + self.leaves_visited.fetch_add(num_parts, Ordering::Relaxed); + } + + fn record_index_loads(&self, _num_indexes: usize) {} + fn record_comparisons(&self, _num_comparisons: usize) {} +} + +fn create_test_points(num_points: usize) -> Vec<(f64, f64, u64)> { + let mut rng = StdRng::seed_from_u64(42); + let mut points = Vec::with_capacity(num_points); + + for i in 0..num_points { + let x = rng.random_range(0.0..1000.0); + let y = rng.random_range(0.0..1000.0); + points.push((x, y, i as u64)); + } + + points +} + +async fn create_geo_index( + points: Vec<(f64, f64, u64)>, +) -> (Arc, Arc, TempObjDir) { + let tmpdir = TempObjDir::default(); + let store = Arc::new(LanceIndexStore::new( + Arc::new(ObjectStore::local()), + tmpdir.clone(), + Arc::new(LanceCache::no_cache()), + )); + + let params = GeoIndexBuilderParams { + max_points_per_leaf: 1000, + batches_per_file: 10, + }; + + let mut 
builder = + GeoIndexBuilder::try_new(params, DataType::Struct(Vec::::new().into())).unwrap(); + + builder.points = points; + builder.write_index(store.as_ref()).await.unwrap(); + + let index = GeoIndex::load(store.clone(), None, &LanceCache::no_cache()) + .await + .unwrap(); + + (index, store, tmpdir) +} + +fn bench_geo_index_intersects_pruning(c: &mut Criterion) { + let mut group = c.benchmark_group("geo_intersects_pruning_vs_scan_all"); + group.sample_size(10); // Reduce samples for faster benchmarks + group.measurement_time(std::time::Duration::from_secs(60)); + + // Compares BKD tree spatial pruning vs scanning all leaves for intersects queries + // Both approaches do lazy loading - we're measuring pruning efficiency + + // Test with different dataset sizes + for num_points in [10_000, 100_000, 1_000_000] { + let points = create_test_points(num_points); + + // Create index (do this once outside the benchmark) + let rt = tokio::runtime::Runtime::new().unwrap(); + let (index, _store, _tmpdir) = rt.block_on(create_geo_index(points.clone())); + + // Generate random queries + let mut rng = StdRng::seed_from_u64(42); + let queries: Vec<[f64; 4]> = (0..10) + .map(|_| { + let width = rng.random_range(10.0..50.0); + let height = rng.random_range(10.0..50.0); + let min_x = rng.random_range(0.0..(1000.0 - width)); + let min_y = rng.random_range(0.0..(1000.0 - height)); + [min_x, min_y, min_x + width, min_y + height] + }) + .collect(); + + // Benchmark: scan all leaves (no spatial pruning) + let scan_all_counter = Arc::new(LeafCounter::new()); + group.bench_with_input( + BenchmarkId::new("scan_all_leaves", num_points), + &(&index, &queries, scan_all_counter.clone()), + |b, (index, queries, counter)| { + let rt = tokio::runtime::Runtime::new().unwrap(); + b.iter(|| { + counter.reset(); + rt.block_on(async { + for query_bbox in queries.iter() { + let _result = index + .search_all_leaves(*query_bbox, counter.as_ref()) + .await + .unwrap(); + } + }); + }); + }, + ); + + // 
Benchmark: BKD tree with spatial pruning + let pruned_counter = Arc::new(LeafCounter::new()); + group.bench_with_input( + BenchmarkId::new("with_pruning", num_points), + &(&index, &queries, pruned_counter.clone()), + |b, (index, queries, counter)| { + let rt = tokio::runtime::Runtime::new().unwrap(); + b.iter(|| { + counter.reset(); + rt.block_on(async { + for query_bbox in queries.iter() { + let query = GeoQuery::Intersects( + query_bbox[0], + query_bbox[1], + query_bbox[2], + query_bbox[3], + ); + let _result = index.search(&query, counter.as_ref()).await.unwrap(); + } + }); + }); + }, + ); + + // Print statistics with speedup estimate + let scan_avg = scan_all_counter.get_count() as f64 / queries.len() as f64; + let pruned_avg = pruned_counter.get_count() as f64 / queries.len() as f64; + + // Quick timing for speedup calculation + let rt = tokio::runtime::Runtime::new().unwrap(); + let scan_start = std::time::Instant::now(); + for query_bbox in queries.iter() { + rt.block_on(async { + let _ = index + .search_all_leaves(*query_bbox, &NoOpMetricsCollector) + .await; + }); + } + let scan_time = scan_start.elapsed().as_secs_f64() / queries.len() as f64; + + let prune_start = std::time::Instant::now(); + for query_bbox in queries.iter() { + rt.block_on(async { + let query = GeoQuery::Intersects( + query_bbox[0], + query_bbox[1], + query_bbox[2], + query_bbox[3], + ); + let _ = index.search(&query, &NoOpMetricsCollector).await; + }); + } + let prune_time = prune_start.elapsed().as_secs_f64() / queries.len() as f64; + + println!( + "\n {} points: {} leaves | scan_all: {:.1} leaves/query ({:.1}ms) | with_pruning: {:.1} leaves/query ({:.1}ms) | {:.1}x faster, {:.1}% reduction", + num_points, + index.num_leaves(), + scan_avg, + scan_time * 1000.0, + pruned_avg, + prune_time * 1000.0, + scan_time / prune_time, + 100.0 * (1.0 - pruned_avg / scan_avg) + ); + } + + group.finish(); +} + +fn bench_geo_intersects_query_size(c: &mut Criterion) { + let mut group = 
c.benchmark_group("geo_intersects_query_size"); + group.sample_size(10); + + // Fixed dataset size, varying intersects query box sizes + let points = create_test_points(1_000_000); + let rt = tokio::runtime::Runtime::new().unwrap(); + let (index, _store, _tmpdir) = rt.block_on(create_geo_index(points)); + + let mut rng = StdRng::seed_from_u64(42); + + // Test different query sizes + for query_size in [1.0, 10.0, 50.0, 100.0, 200.0] { + let queries: Vec<[f64; 4]> = (0..10) + .map(|_| { + let min_x = rng.random_range(0.0..(1000.0 - query_size)); + let min_y = rng.random_range(0.0..(1000.0 - query_size)); + [min_x, min_y, min_x + query_size, min_y + query_size] + }) + .collect(); + + let leaf_counter = Arc::new(LeafCounter::new()); + group.bench_with_input( + BenchmarkId::new("query_size", format!("{}x{}", query_size, query_size)), + &(&index, &queries, leaf_counter.clone()), + |b, (index, queries, counter)| { + let rt = tokio::runtime::Runtime::new().unwrap(); + b.iter(|| { + counter.reset(); + rt.block_on(async { + for query_bbox in queries.iter() { + let query = GeoQuery::Intersects( + query_bbox[0], + query_bbox[1], + query_bbox[2], + query_bbox[3], + ); + let _result = index.search(&query, counter.as_ref()).await.unwrap(); + } + }); + }); + }, + ); + + let avg_leaves = leaf_counter.get_count() as f64 / queries.len() as f64; + println!( + " Query {}x{}: avg {:.1} leaves visited (out of {} total, {:.1}% selectivity)", + query_size, + query_size, + avg_leaves, + index.num_leaves(), + 100.0 * avg_leaves / index.num_leaves() as f64 + ); + } + + group.finish(); +} + +criterion_group!( + benches, + bench_geo_index_intersects_pruning, + bench_geo_intersects_query_size +); +criterion_main!(benches); diff --git a/rust/lance-index/src/lib.rs b/rust/lance-index/src/lib.rs index 26184cd47ff..cec4cafd244 100644 --- a/rust/lance-index/src/lib.rs +++ b/rust/lance-index/src/lib.rs @@ -108,6 +108,8 @@ pub enum IndexType { BloomFilter = 9, // Bloom filter + BkdTree = 10, // BKD 
Tree + // 100+ and up for vector index. /// Flat vector index. Vector = 100, // Legacy vector index, alias to IvfPq @@ -130,6 +132,7 @@ impl std::fmt::Display for IndexType { Self::NGram => write!(f, "NGram"), Self::FragmentReuse => write!(f, "FragmentReuse"), Self::MemWal => write!(f, "MemWal"), + Self::BkdTree => write!(f, "BkdTree"), Self::ZoneMap => write!(f, "ZoneMap"), Self::BloomFilter => write!(f, "BloomFilter"), Self::Vector | Self::IvfPq => write!(f, "IVF_PQ"), @@ -156,6 +159,7 @@ impl TryFrom for IndexType { v if v == Self::Inverted as i32 => Ok(Self::Inverted), v if v == Self::FragmentReuse as i32 => Ok(Self::FragmentReuse), v if v == Self::MemWal as i32 => Ok(Self::MemWal), + v if v == Self::BkdTree as i32 => Ok(Self::BkdTree), v if v == Self::ZoneMap as i32 => Ok(Self::ZoneMap), v if v == Self::BloomFilter as i32 => Ok(Self::BloomFilter), v if v == Self::Vector as i32 => Ok(Self::Vector), @@ -214,6 +218,7 @@ impl IndexType { | Self::NGram | Self::ZoneMap | Self::BloomFilter + | Self::BkdTree ) } @@ -252,7 +257,7 @@ impl IndexType { Self::MemWal => 0, Self::ZoneMap => 0, Self::BloomFilter => 0, - + Self::BkdTree => 0, // for now all vector indices are built by the same builder, // so they share the same version. 
Self::Vector diff --git a/rust/lance-index/src/scalar.rs b/rust/lance-index/src/scalar.rs index 69b5ee35cf0..9700635cfa4 100644 --- a/rust/lance-index/src/scalar.rs +++ b/rust/lance-index/src/scalar.rs @@ -33,6 +33,7 @@ pub mod bloomfilter; pub mod btree; pub mod expression; pub mod flat; +pub mod geo; pub mod inverted; pub mod json; pub mod label_list; @@ -61,6 +62,7 @@ pub enum BuiltinIndexType { ZoneMap, BloomFilter, Inverted, + BkdTree, } impl BuiltinIndexType { @@ -73,6 +75,7 @@ impl BuiltinIndexType { Self::ZoneMap => "zonemap", Self::Inverted => "inverted", Self::BloomFilter => "bloomfilter", + Self::BkdTree => "bkdtree", } } } @@ -89,6 +92,7 @@ impl TryFrom for BuiltinIndexType { IndexType::ZoneMap => Ok(Self::ZoneMap), IndexType::Inverted => Ok(Self::Inverted), IndexType::BloomFilter => Ok(Self::BloomFilter), + IndexType::BkdTree => Ok(Self::BkdTree), _ => Err(Error::Index { message: "Invalid index type".to_string(), location: location!(), @@ -587,6 +591,48 @@ pub enum TokenQuery { TokensContains(String), } +/// A query that a GeoIndex can satisfy +#[derive(Debug, Clone, PartialEq)] +pub enum GeoQuery { + /// Retrieve all row ids where the geometry intersects with the given bounding box + /// Format: (min_x, min_y, max_x, max_y) + Intersects(f64, f64, f64, f64), +} + +impl AnyQuery for GeoQuery { + fn as_any(&self) -> &dyn Any { + self + } + + fn format(&self, col: &str) -> String { + match self { + Self::Intersects(min_x, min_y, max_x, max_y) => { + format!( + "st_intersects({}, bbox({}, {}, {}, {}))", + col, min_x, min_y, max_x, max_y + ) + } + } + } + + fn to_expr(&self, _col: String) -> Expr { + match self { + Self::Intersects(_min_x, _min_y, _max_x, _max_y) => { + // For now, return a placeholder expression + // This would need to be a proper st_intersects UDF call + Expr::Literal(ScalarValue::Boolean(Some(true)), None) + } + } + } + + fn dyn_eq(&self, other: &dyn AnyQuery) -> bool { + match other.as_any().downcast_ref::() { + Some(o) => self == o, + 
None => false, + } + } +} + /// A query that a BloomFilter index can satisfy /// /// This is a subset of SargableQuery that only includes operations that bloom filters diff --git a/rust/lance-index/src/scalar/expression.rs b/rust/lance-index/src/scalar/expression.rs index 047dddeaa93..26c000b1ec8 100644 --- a/rust/lance-index/src/scalar/expression.rs +++ b/rust/lance-index/src/scalar/expression.rs @@ -18,8 +18,8 @@ use datafusion_expr::{ }; use super::{ - AnyQuery, BloomFilterQuery, LabelListQuery, MetricsCollector, SargableQuery, ScalarIndex, - SearchResult, TextQuery, TokenQuery, + AnyQuery, BloomFilterQuery, GeoQuery, LabelListQuery, MetricsCollector, SargableQuery, + ScalarIndex, SearchResult, TextQuery, TokenQuery, }; use futures::join; use lance_core::{utils::mask::RowIdMask, Error, Result}; @@ -665,6 +665,122 @@ impl ScalarQueryParser for FtsQueryParser { } } +/// A parser for geo indices that handles spatial queries +#[derive(Debug, Clone)] +pub struct GeoQueryParser { + index_name: String, +} + +impl GeoQueryParser { + pub fn new(index_name: String) -> Self { + Self { index_name } + } + + /// Extract bounding box coordinates from a bbox() function call + /// Expected format: bbox(min_x, min_y, max_x, max_y) + fn extract_bbox(&self, expr: &Expr) -> Option<(f64, f64, f64, f64)> { + match expr { + Expr::ScalarFunction(ScalarFunction { func, args }) => { + if func.name() == "bbox" && args.len() == 4 { + // Extract the four coordinates + let min_x = maybe_scalar(&args[0], &DataType::Float64)?; + let min_y = maybe_scalar(&args[1], &DataType::Float64)?; + let max_x = maybe_scalar(&args[2], &DataType::Float64)?; + let max_y = maybe_scalar(&args[3], &DataType::Float64)?; + + // Convert to f64 + let min_x = match min_x { + ScalarValue::Float64(Some(v)) => v, + ScalarValue::Float32(Some(v)) => v as f64, + ScalarValue::Int64(Some(v)) => v as f64, + ScalarValue::Int32(Some(v)) => v as f64, + _ => return None, + }; + let min_y = match min_y { + 
ScalarValue::Float64(Some(v)) => v, + ScalarValue::Float32(Some(v)) => v as f64, + ScalarValue::Int64(Some(v)) => v as f64, + ScalarValue::Int32(Some(v)) => v as f64, + _ => return None, + }; + let max_x = match max_x { + ScalarValue::Float64(Some(v)) => v, + ScalarValue::Float32(Some(v)) => v as f64, + ScalarValue::Int64(Some(v)) => v as f64, + ScalarValue::Int32(Some(v)) => v as f64, + _ => return None, + }; + let max_y = match max_y { + ScalarValue::Float64(Some(v)) => v, + ScalarValue::Float32(Some(v)) => v as f64, + ScalarValue::Int64(Some(v)) => v as f64, + ScalarValue::Int32(Some(v)) => v as f64, + _ => return None, + }; + + return Some((min_x, min_y, max_x, max_y)); + } + None + } + _ => None, + } + } +} + +impl ScalarQueryParser for GeoQueryParser { + fn visit_between( + &self, + _: &str, + _: &Bound, + _: &Bound, + ) -> Option { + None + } + + fn visit_in_list(&self, _: &str, _: &[ScalarValue]) -> Option { + None + } + + fn visit_is_bool(&self, _: &str, _: bool) -> Option { + None + } + + fn visit_is_null(&self, _: &str) -> Option { + None + } + + fn visit_comparison( + &self, + _: &str, + _: &ScalarValue, + _: &Operator, + ) -> Option { + None + } + + fn visit_scalar_function( + &self, + column: &str, + _data_type: &DataType, + func: &ScalarUDF, + args: &[Expr], + ) -> Option { + // Handle st_intersects(geometry_column, bbox(...)) + if func.name() == "st_intersects" && args.len() == 2 { + // The second argument should be a bbox() call + if let Some((min_x, min_y, max_x, max_y)) = self.extract_bbox(&args[1]) { + let query = GeoQuery::Intersects(min_x, min_y, max_x, max_y); + return Some(IndexedExpression::index_query( + column.to_string(), + self.index_name.clone(), + Arc::new(query), + )); + } + } + None + } +} + impl IndexedExpression { /// Create an expression that only does refine fn refine_only(refine_expr: Expr) -> Self { diff --git a/rust/lance-index/src/scalar/geo/bkd.rs b/rust/lance-index/src/scalar/geo/bkd.rs new file mode 100644 index 
00000000000..7b0ebcd026a --- /dev/null +++ b/rust/lance-index/src/scalar/geo/bkd.rs @@ -0,0 +1,882 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright The Lance Authors + +//! BKD Tree (Block K-Dimensional Tree) implementation +//! +//! A BKD tree is a spatial index structure for efficiently indexing and querying +//! multi-dimensional points. It's similar to a KD-tree but optimized for disk storage +//! by grouping multiple points into leaf blocks. +//! +//! ## Algorithm +//! +//! Based on Lucene's BKD tree implementation: +//! - Recursively splits points by alternating dimensions (X, Y) +//! - Splits at median to create balanced tree +//! - Groups points into leaves of configurable size +//! - Stores tree structure separately from leaf data for lazy loading + +use arrow_array::cast::AsArray; +use arrow_array::{ArrayRef, Float64Array, RecordBatch, UInt64Array}; +use arrow_schema::{DataType, Field, Schema}; +use deepsize::DeepSizeOf; +use lance_core::{Result, ROW_ID}; +use snafu::location; +use std::sync::Arc; + +// Schema field names +const NODE_ID: &str = "node_id"; +const MIN_X: &str = "min_x"; +const MIN_Y: &str = "min_y"; +const MAX_X: &str = "max_x"; +const MAX_Y: &str = "max_y"; +const SPLIT_DIM: &str = "split_dim"; +const SPLIT_VALUE: &str = "split_value"; +const LEFT_CHILD: &str = "left_child"; +const RIGHT_CHILD: &str = "right_child"; +const FILE_ID: &str = "file_id"; +const ROW_OFFSET: &str = "row_offset"; +const NUM_ROWS: &str = "num_rows"; + +/// Schema for inner node metadata +pub fn inner_node_schema() -> Arc { + Arc::new(Schema::new(vec![ + Field::new(NODE_ID, DataType::UInt32, false), + Field::new(MIN_X, DataType::Float64, false), + Field::new(MIN_Y, DataType::Float64, false), + Field::new(MAX_X, DataType::Float64, false), + Field::new(MAX_Y, DataType::Float64, false), + Field::new(SPLIT_DIM, DataType::UInt8, false), + Field::new(SPLIT_VALUE, DataType::Float64, false), + Field::new(LEFT_CHILD, DataType::UInt32, false), 
+ Field::new(RIGHT_CHILD, DataType::UInt32, false), + ])) +} + +/// Schema for leaf node metadata +pub fn leaf_node_schema() -> Arc { + Arc::new(Schema::new(vec![ + Field::new(NODE_ID, DataType::UInt32, false), + Field::new(MIN_X, DataType::Float64, false), + Field::new(MIN_Y, DataType::Float64, false), + Field::new(MAX_X, DataType::Float64, false), + Field::new(MAX_Y, DataType::Float64, false), + Field::new(FILE_ID, DataType::UInt32, false), + Field::new(ROW_OFFSET, DataType::UInt64, false), + Field::new(NUM_ROWS, DataType::UInt64, false), + ])) +} + +/// BKD Tree node - either an inner node (with children) or a leaf node (with data location) +#[derive(Debug, Clone, DeepSizeOf)] +pub enum BKDNode { + Inner(BKDInnerNode), + Leaf(BKDLeafNode), +} + +impl BKDNode { + pub fn bounds(&self) -> [f64; 4] { + match self { + BKDNode::Inner(inner) => inner.bounds, + BKDNode::Leaf(leaf) => leaf.bounds, + } + } + + pub fn is_leaf(&self) -> bool { + matches!(self, BKDNode::Leaf(_)) + } + + pub fn as_inner(&self) -> Option<&BKDInnerNode> { + match self { + BKDNode::Inner(inner) => Some(inner), + _ => None, + } + } + + pub fn as_leaf(&self) -> Option<&BKDLeafNode> { + match self { + BKDNode::Leaf(leaf) => Some(leaf), + _ => None, + } + } +} + +/// Inner node in BKD tree - contains split information and child pointers +#[derive(Debug, Clone, DeepSizeOf)] +pub struct BKDInnerNode { + /// Bounding box: [min_x, min_y, max_x, max_y] + pub bounds: [f64; 4], + /// Split dimension: 0=X, 1=Y + pub split_dim: u8, + /// Split value along split_dim + pub split_value: f64, + /// Left child node index + pub left_child: u32, + /// Right child node index + pub right_child: u32, +} + +/// Leaf node in BKD tree - contains location of actual point data +#[derive(Debug, Clone, DeepSizeOf)] +pub struct BKDLeafNode { + /// Bounding box: [min_x, min_y, max_x, max_y] + pub bounds: [f64; 4], + /// Which leaf group file this leaf is in + /// Corresponds to leaf_group_{file_id}.lance + pub file_id: u32, + 
/// Row offset within the leaf group file + pub row_offset: u64, + /// Number of rows in this leaf batch + pub num_rows: u64, +} + +/// In-memory BKD tree structure for efficient spatial queries +#[derive(Debug, DeepSizeOf)] +pub struct BKDTreeLookup { + pub nodes: Vec, + pub root_id: u32, + pub num_leaves: u32, +} + +impl BKDTreeLookup { + pub fn new(nodes: Vec, root_id: u32, num_leaves: u32) -> Self { + Self { + nodes, + root_id, + num_leaves, + } + } + + /// Find all leaf nodes that intersect with the query bounding box + /// Returns references to the intersecting leaf nodes + pub fn find_intersecting_leaves(&self, query_bbox: [f64; 4]) -> Result> { + let mut leaves = Vec::new(); + let mut stack = vec![self.root_id]; + let mut _nodes_visited = 0; + + while let Some(node_id) = stack.pop() { + if node_id as usize >= self.nodes.len() { + continue; + } + + let node = &self.nodes[node_id as usize]; + _nodes_visited += 1; + + // Check if node's bounding box intersects with query bbox + let intersects = bboxes_intersect(&node.bounds(), &query_bbox); + if !intersects { + continue; + } + + match node { + BKDNode::Leaf(leaf) => { + // Leaf node - add to results + leaves.push(leaf); + } + BKDNode::Inner(inner) => { + // Inner node - traverse children + stack.push(inner.left_child); + stack.push(inner.right_child); + } + } + } + + Ok(leaves) + } + + /// Deserialize from separate inner and leaf RecordBatches + /// Both batches include node_id field to preserve tree structure + pub fn from_record_batches(inner_batch: RecordBatch, leaf_batch: RecordBatch) -> Result { + if inner_batch.num_rows() == 0 && leaf_batch.num_rows() == 0 { + return Ok(Self::new(vec![], 0, 0)); + } + + // Helper to get column by name + let get_col = |batch: &RecordBatch, name: &str| -> Result { + batch + .schema() + .column_with_name(name) + .map(|(idx, _)| idx) + .ok_or_else(|| lance_core::Error::Internal { + message: format!("Missing column '{}' in BKD tree batch", name), + location: location!(), + }) 
+ }; + + // Determine total number of nodes (max node_id + 1) + let max_node_id = { + let mut max_id = 0u32; + + if inner_batch.num_rows() > 0 { + let col_idx = get_col(&inner_batch, NODE_ID)?; + let node_ids = inner_batch + .column(col_idx) + .as_primitive::(); + for i in 0..inner_batch.num_rows() { + max_id = max_id.max(node_ids.value(i)); + } + } + + if leaf_batch.num_rows() > 0 { + let col_idx = get_col(&leaf_batch, NODE_ID)?; + let node_ids = leaf_batch + .column(col_idx) + .as_primitive::(); + for i in 0..leaf_batch.num_rows() { + max_id = max_id.max(node_ids.value(i)); + } + } + + max_id + }; + + // Create sparse array of nodes (filled with dummy data initially) + let mut nodes = vec![ + BKDNode::Leaf(BKDLeafNode { + bounds: [0.0, 0.0, 0.0, 0.0], + file_id: 0, + row_offset: 0, + num_rows: 0, + }); + (max_node_id + 1) as usize + ]; + + let mut num_leaves = 0; + + // Fill in inner nodes + if inner_batch.num_rows() > 0 { + let node_ids = inner_batch + .column(get_col(&inner_batch, NODE_ID)?) + .as_primitive::(); + let min_x = inner_batch + .column(get_col(&inner_batch, MIN_X)?) + .as_primitive::(); + let min_y = inner_batch + .column(get_col(&inner_batch, MIN_Y)?) + .as_primitive::(); + let max_x = inner_batch + .column(get_col(&inner_batch, MAX_X)?) + .as_primitive::(); + let max_y = inner_batch + .column(get_col(&inner_batch, MAX_Y)?) + .as_primitive::(); + let split_dim = inner_batch + .column(get_col(&inner_batch, SPLIT_DIM)?) + .as_primitive::(); + let split_value = inner_batch + .column(get_col(&inner_batch, SPLIT_VALUE)?) + .as_primitive::(); + let left_child = inner_batch + .column(get_col(&inner_batch, LEFT_CHILD)?) + .as_primitive::(); + let right_child = inner_batch + .column(get_col(&inner_batch, RIGHT_CHILD)?) 
+ .as_primitive::(); + + for i in 0..inner_batch.num_rows() { + let node_id = node_ids.value(i) as usize; + nodes[node_id] = BKDNode::Inner(BKDInnerNode { + bounds: [ + min_x.value(i), + min_y.value(i), + max_x.value(i), + max_y.value(i), + ], + split_dim: split_dim.value(i), + split_value: split_value.value(i), + left_child: left_child.value(i), + right_child: right_child.value(i), + }); + } + } + + // Fill in leaf nodes + if leaf_batch.num_rows() > 0 { + let node_ids = leaf_batch + .column(get_col(&leaf_batch, NODE_ID)?) + .as_primitive::(); + let min_x = leaf_batch + .column(get_col(&leaf_batch, MIN_X)?) + .as_primitive::(); + let min_y = leaf_batch + .column(get_col(&leaf_batch, MIN_Y)?) + .as_primitive::(); + let max_x = leaf_batch + .column(get_col(&leaf_batch, MAX_X)?) + .as_primitive::(); + let max_y = leaf_batch + .column(get_col(&leaf_batch, MAX_Y)?) + .as_primitive::(); + let file_id = leaf_batch + .column(get_col(&leaf_batch, FILE_ID)?) + .as_primitive::(); + let row_offset = leaf_batch + .column(get_col(&leaf_batch, ROW_OFFSET)?) + .as_primitive::(); + let num_rows = leaf_batch + .column(get_col(&leaf_batch, NUM_ROWS)?) 
+ .as_primitive::(); + + for i in 0..leaf_batch.num_rows() { + let node_id = node_ids.value(i) as usize; + nodes[node_id] = BKDNode::Leaf(BKDLeafNode { + bounds: [ + min_x.value(i), + min_y.value(i), + max_x.value(i), + max_y.value(i), + ], + file_id: file_id.value(i), + row_offset: row_offset.value(i), + num_rows: num_rows.value(i), + }); + num_leaves += 1; + } + } + + Ok(Self::new(nodes, 0, num_leaves)) + } +} + +/// Check if two bounding boxes intersect +pub fn bboxes_intersect(bbox1: &[f64; 4], bbox2: &[f64; 4]) -> bool { + // bbox format: [min_x, min_y, max_x, max_y] + !(bbox1[2] < bbox2[0] || bbox1[0] > bbox2[2] || bbox1[3] < bbox2[1] || bbox1[1] > bbox2[3]) +} + +/// Check if a point is within a bounding box +pub fn point_in_bbox(x: f64, y: f64, bbox: &[f64; 4]) -> bool { + x >= bbox[0] && x <= bbox[2] && y >= bbox[1] && y <= bbox[3] +} + +/// BKD Tree builder following Lucene's bulk-loading algorithm +pub struct BKDTreeBuilder { + leaf_size: usize, +} + +impl BKDTreeBuilder { + pub fn new(leaf_size: usize) -> Self { + Self { leaf_size } + } + + /// Build a BKD tree from points + /// Returns (tree_nodes, leaf_batches) + pub fn build( + &self, + points: &mut [(f64, f64, u64)], + batches_per_file: u32, + ) -> Result<(Vec, Vec)> { + if points.is_empty() { + return Ok((vec![], vec![])); + } + + let mut leaf_counter = 0u32; + let mut all_nodes = Vec::new(); + let mut all_leaf_batches = Vec::new(); + + self.build_recursive( + points, + 0, // depth + &mut leaf_counter, + batches_per_file, + &mut all_nodes, + &mut all_leaf_batches, + )?; + + // Post-process: Update leaf nodes with correct row offsets + // Leaves are created in order during build_recursive, so we can + // calculate cumulative offsets based on actual batch sizes + let mut current_file_id = 0u32; + let mut row_offset_in_file = 0u64; + let mut batches_in_current_file = 0u32; + let mut leaf_idx = 0; + + for node in all_nodes.iter_mut() { + if let BKDNode::Leaf(leaf) = node { + if leaf_idx < 
all_leaf_batches.len() { + let batch_num_rows = all_leaf_batches[leaf_idx].num_rows() as u64; + + // Check if we need to move to next file + if batches_in_current_file >= batches_per_file && batches_per_file > 0 { + current_file_id += 1; + row_offset_in_file = 0; + batches_in_current_file = 0; + } + + // Update leaf with correct metadata + leaf.file_id = current_file_id; + leaf.row_offset = row_offset_in_file; + leaf.num_rows = batch_num_rows; + + // Advance for next leaf + row_offset_in_file += batch_num_rows; + batches_in_current_file += 1; + leaf_idx += 1; + } + } + } + + Ok((all_nodes, all_leaf_batches)) + } + + /// Recursively build BKD tree following Lucene's algorithm + fn build_recursive( + &self, + points: &mut [(f64, f64, u64)], + depth: u32, + leaf_counter: &mut u32, + batches_per_file: u32, + all_nodes: &mut Vec, + all_leaf_batches: &mut Vec, + ) -> Result { + // Base case: create leaf node + if points.len() <= self.leaf_size { + let node_id = all_nodes.len() as u32; + let _leaf_id = *leaf_counter; + *leaf_counter += 1; + + // Calculate bounding box for this leaf + let (min_x, min_y, max_x, max_y) = calculate_bounds(points); + + // Create leaf batch + let leaf_batch = create_leaf_batch(points)?; + let num_rows = leaf_batch.num_rows() as u64; + all_leaf_batches.push(leaf_batch); + + // Create leaf node (file_id, row_offset will be set in post-processing) + all_nodes.push(BKDNode::Leaf(BKDLeafNode { + bounds: [min_x, min_y, max_x, max_y], + file_id: 0, // Will be updated in post-processing + row_offset: 0, // Will be updated in post-processing + num_rows, + })); + + return Ok(node_id); + } + + // Recursive case: split and build subtrees + let split_dim = (depth % 2) as u8; // Alternate between X (0) and Y (1) + + // TODO: Replace with radix selection for O(n) median finding (Lucene's approach) + // Current: O(n log n) sorting at each level = O(n log² n) total + // Target: O(n) radix select at each level = O(n log n) total + // See: 
https://github.com/apache/lucene/blob/main/lucene/core/src/java/org/apache/lucene/util/bkd/BKDRadixSelector.java + + // Sort points by the split dimension + if split_dim == 0 { + points.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap_or(std::cmp::Ordering::Equal)); + } else { + points.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal)); + } + + // Split at median + let mid = points.len() / 2; + let split_value = if split_dim == 0 { + points[mid].0 + } else { + points[mid].1 + }; + + // Calculate bounds for this node (before splitting the slice) + let (min_x, min_y, max_x, max_y) = calculate_bounds(points); + + // Reserve space for this inner node (placeholder - we'll update it after building children) + let node_id = all_nodes.len() as u32; + all_nodes.push(BKDNode::Inner(BKDInnerNode { + bounds: [min_x, min_y, max_x, max_y], + split_dim, + split_value, + left_child: 0, // Placeholder + right_child: 0, // Placeholder + })); + + // Recursively build left and right subtrees + let (left_points, right_points) = points.split_at_mut(mid); + + let left_child_id = self.build_recursive( + left_points, + depth + 1, + leaf_counter, + batches_per_file, + all_nodes, + all_leaf_batches, + )?; + + let right_child_id = self.build_recursive( + right_points, + depth + 1, + leaf_counter, + batches_per_file, + all_nodes, + all_leaf_batches, + )?; + + // Update the inner node with actual child pointers + if let BKDNode::Inner(inner) = &mut all_nodes[node_id as usize] { + inner.left_child = left_child_id; + inner.right_child = right_child_id; + } + + Ok(node_id) + } +} + +/// Calculate bounding box for a set of points +fn calculate_bounds(points: &[(f64, f64, u64)]) -> (f64, f64, f64, f64) { + let mut min_x = f64::MAX; + let mut min_y = f64::MAX; + let mut max_x = f64::MIN; + let mut max_y = f64::MIN; + + for (x, y, _) in points { + min_x = min_x.min(*x); + min_y = min_y.min(*y); + max_x = max_x.max(*x); + max_y = max_y.max(*y); + } + + (min_x, min_y, max_x, max_y) +} 
+ +/// Create a leaf batch from points +fn create_leaf_batch(points: &[(f64, f64, u64)]) -> Result { + let x_vals: Vec = points.iter().map(|(x, _, _)| *x).collect(); + let y_vals: Vec = points.iter().map(|(_, y, _)| *y).collect(); + let row_ids: Vec = points.iter().map(|(_, _, r)| *r).collect(); + + let schema = Arc::new(Schema::new(vec![ + Field::new("x", DataType::Float64, false), + Field::new("y", DataType::Float64, false), + Field::new(ROW_ID, DataType::UInt64, false), + ])); + + let batch = RecordBatch::try_new( + schema, + vec![ + Arc::new(Float64Array::from(x_vals)) as ArrayRef, + Arc::new(Float64Array::from(y_vals)) as ArrayRef, + Arc::new(UInt64Array::from(row_ids)) as ArrayRef, + ], + )?; + + Ok(batch) +} + +#[cfg(test)] +mod tests { + use super::*; + use arrow_array::{Float64Array, UInt32Array, UInt64Array, UInt8Array}; + + /// Test serialization and deserialization of a simple BKD tree + #[test] + fn test_serialize_deserialize_simple_tree() { + // Create a simple tree: 1 inner node with 2 leaf children + let nodes = vec![ + BKDNode::Inner(BKDInnerNode { + bounds: [-10.0, -10.0, 10.0, 10.0], + split_dim: 0, + split_value: 0.0, + left_child: 1, + right_child: 2, + }), + BKDNode::Leaf(BKDLeafNode { + bounds: [-10.0, -10.0, 0.0, 10.0], + file_id: 0, + row_offset: 0, + num_rows: 100, + }), + BKDNode::Leaf(BKDLeafNode { + bounds: [0.0, -10.0, 10.0, 10.0], + file_id: 0, + row_offset: 100, + num_rows: 100, + }), + ]; + + // Separate inner and leaf nodes + let inner_nodes: Vec<(u32, &BKDNode)> = nodes + .iter() + .enumerate() + .filter(|(_, n)| matches!(n, BKDNode::Inner(_))) + .map(|(i, n)| (i as u32, n)) + .collect(); + + let leaf_nodes: Vec<(u32, &BKDNode)> = nodes + .iter() + .enumerate() + .filter(|(_, n)| matches!(n, BKDNode::Leaf(_))) + .map(|(i, n)| (i as u32, n)) + .collect(); + + // Serialize inner nodes + let mut inner_node_ids = Vec::new(); + let mut inner_min_x = Vec::new(); + let mut inner_min_y = Vec::new(); + let mut inner_max_x = Vec::new(); + 
let mut inner_max_y = Vec::new(); + let mut inner_split_dim = Vec::new(); + let mut inner_split_value = Vec::new(); + let mut inner_left_child = Vec::new(); + let mut inner_right_child = Vec::new(); + + for (idx, node) in &inner_nodes { + if let BKDNode::Inner(inner) = node { + inner_node_ids.push(*idx); + inner_min_x.push(inner.bounds[0]); + inner_min_y.push(inner.bounds[1]); + inner_max_x.push(inner.bounds[2]); + inner_max_y.push(inner.bounds[3]); + inner_split_dim.push(inner.split_dim); + inner_split_value.push(inner.split_value); + inner_left_child.push(inner.left_child); + inner_right_child.push(inner.right_child); + } + } + + let inner_batch = RecordBatch::try_new( + inner_node_schema(), + vec![ + Arc::new(UInt32Array::from(inner_node_ids)), + Arc::new(Float64Array::from(inner_min_x)), + Arc::new(Float64Array::from(inner_min_y)), + Arc::new(Float64Array::from(inner_max_x)), + Arc::new(Float64Array::from(inner_max_y)), + Arc::new(UInt8Array::from(inner_split_dim)), + Arc::new(Float64Array::from(inner_split_value)), + Arc::new(UInt32Array::from(inner_left_child)), + Arc::new(UInt32Array::from(inner_right_child)), + ], + ) + .unwrap(); + + // Serialize leaf nodes + let mut leaf_node_ids = Vec::new(); + let mut leaf_min_x = Vec::new(); + let mut leaf_min_y = Vec::new(); + let mut leaf_max_x = Vec::new(); + let mut leaf_max_y = Vec::new(); + let mut leaf_file_ids = Vec::new(); + let mut leaf_row_offsets = Vec::new(); + let mut leaf_num_rows = Vec::new(); + + for (idx, node) in &leaf_nodes { + if let BKDNode::Leaf(leaf) = node { + leaf_node_ids.push(*idx); + leaf_min_x.push(leaf.bounds[0]); + leaf_min_y.push(leaf.bounds[1]); + leaf_max_x.push(leaf.bounds[2]); + leaf_max_y.push(leaf.bounds[3]); + leaf_file_ids.push(leaf.file_id); + leaf_row_offsets.push(leaf.row_offset); + leaf_num_rows.push(leaf.num_rows); + } + } + + let leaf_batch = RecordBatch::try_new( + leaf_node_schema(), + vec![ + Arc::new(UInt32Array::from(leaf_node_ids)), + 
Arc::new(Float64Array::from(leaf_min_x)), + Arc::new(Float64Array::from(leaf_min_y)), + Arc::new(Float64Array::from(leaf_max_x)), + Arc::new(Float64Array::from(leaf_max_y)), + Arc::new(UInt32Array::from(leaf_file_ids)), + Arc::new(UInt64Array::from(leaf_row_offsets)), + Arc::new(UInt64Array::from(leaf_num_rows)), + ], + ) + .unwrap(); + + // Deserialize + let tree = BKDTreeLookup::from_record_batches(inner_batch, leaf_batch).unwrap(); + + // Verify structure + assert_eq!(tree.nodes.len(), 3); + assert_eq!(tree.root_id, 0); + assert_eq!(tree.num_leaves, 2); + + // Verify root (inner node) + match &tree.nodes[0] { + BKDNode::Inner(inner) => { + assert_eq!(inner.bounds, [-10.0, -10.0, 10.0, 10.0]); + assert_eq!(inner.split_dim, 0); + assert_eq!(inner.split_value, 0.0); + assert_eq!(inner.left_child, 1); + assert_eq!(inner.right_child, 2); + } + _ => panic!("Expected inner node at index 0"), + } + + // Verify left leaf + match &tree.nodes[1] { + BKDNode::Leaf(leaf) => { + assert_eq!(leaf.bounds, [-10.0, -10.0, 0.0, 10.0]); + assert_eq!(leaf.file_id, 0); + assert_eq!(leaf.row_offset, 0); + assert_eq!(leaf.num_rows, 100); + } + _ => panic!("Expected leaf node at index 1"), + } + + // Verify right leaf + match &tree.nodes[2] { + BKDNode::Leaf(leaf) => { + assert_eq!(leaf.bounds, [0.0, -10.0, 10.0, 10.0]); + assert_eq!(leaf.file_id, 0); + assert_eq!(leaf.row_offset, 100); + assert_eq!(leaf.num_rows, 100); + } + _ => panic!("Expected leaf node at index 2"), + } + } + + /// Test serialization of empty tree + #[test] + fn test_serialize_deserialize_empty_tree() { + let inner_batch = RecordBatch::new_empty(inner_node_schema()); + let leaf_batch = RecordBatch::new_empty(leaf_node_schema()); + + let tree = BKDTreeLookup::from_record_batches(inner_batch, leaf_batch).unwrap(); + + assert_eq!(tree.nodes.len(), 0); + assert_eq!(tree.num_leaves, 0); + } + + /// Test bbox intersection logic + #[test] + fn test_bboxes_intersect() { + // Overlapping boxes + assert!(bboxes_intersect( + 
&[0.0, 0.0, 10.0, 10.0], + &[5.0, 5.0, 15.0, 15.0] + )); + + // Touching boxes + assert!(bboxes_intersect( + &[0.0, 0.0, 10.0, 10.0], + &[10.0, 0.0, 20.0, 10.0] + )); + + // Fully contained + assert!(bboxes_intersect( + &[0.0, 0.0, 10.0, 10.0], + &[2.0, 2.0, 8.0, 8.0] + )); + + // Non-overlapping (left) + assert!(!bboxes_intersect( + &[0.0, 0.0, 10.0, 10.0], + &[-20.0, 0.0, -10.0, 10.0] + )); + + // Non-overlapping (above) + assert!(!bboxes_intersect( + &[0.0, 0.0, 10.0, 10.0], + &[0.0, 20.0, 10.0, 30.0] + )); + } + + /// Test point in bbox logic + #[test] + fn test_point_in_bbox() { + let bbox = [0.0, 0.0, 10.0, 10.0]; + + // Inside + assert!(point_in_bbox(5.0, 5.0, &bbox)); + + // On boundary + assert!(point_in_bbox(0.0, 0.0, &bbox)); + assert!(point_in_bbox(10.0, 10.0, &bbox)); + + // Outside + assert!(!point_in_bbox(-1.0, 5.0, &bbox)); + assert!(!point_in_bbox(11.0, 5.0, &bbox)); + assert!(!point_in_bbox(5.0, -1.0, &bbox)); + assert!(!point_in_bbox(5.0, 11.0, &bbox)); + } + + /// Test find_intersecting_leaves with simple tree + #[test] + fn test_find_intersecting_leaves() { + // Create a simple tree: 1 inner node with 2 leaf children + let nodes = vec![ + BKDNode::Inner(BKDInnerNode { + bounds: [-10.0, -10.0, 10.0, 10.0], + split_dim: 0, + split_value: 0.0, + left_child: 1, + right_child: 2, + }), + BKDNode::Leaf(BKDLeafNode { + bounds: [-10.0, -10.0, 0.0, 10.0], + file_id: 0, + row_offset: 0, + num_rows: 100, + }), + BKDNode::Leaf(BKDLeafNode { + bounds: [0.0, -10.0, 10.0, 10.0], + file_id: 0, + row_offset: 100, + num_rows: 100, + }), + ]; + + let tree = BKDTreeLookup::new(nodes, 0, 2); + + // Query that intersects only left leaf + let leaves = tree + .find_intersecting_leaves([-10.0, -10.0, -5.0, 10.0]) + .unwrap(); + assert_eq!(leaves.len(), 1); + assert_eq!(leaves[0].file_id, 0); + assert_eq!(leaves[0].row_offset, 0); + + // Query that intersects only right leaf + let leaves = tree + .find_intersecting_leaves([5.0, -10.0, 10.0, 10.0]) + .unwrap(); + 
assert_eq!(leaves.len(), 1); + assert_eq!(leaves[0].file_id, 0); + assert_eq!(leaves[0].row_offset, 100); + + // Query that intersects both leaves + let leaves = tree + .find_intersecting_leaves([-5.0, -10.0, 5.0, 10.0]) + .unwrap(); + assert_eq!(leaves.len(), 2); + + // Query that intersects no leaves + let leaves = tree + .find_intersecting_leaves([20.0, 20.0, 30.0, 30.0]) + .unwrap(); + assert_eq!(leaves.len(), 0); + } + + /// Test tree building with small dataset + #[test] + fn test_build_tree_small() { + let mut points = vec![ + (-5.0, -5.0, 0), + (-4.0, -4.0, 1), + (4.0, 4.0, 2), + (5.0, 5.0, 3), + ]; + + let builder = BKDTreeBuilder::new(2); // leaf_size = 2 + let (nodes, batches) = builder.build(&mut points, 5).unwrap(); + + // Should have: 1 root + 2 leaves = 3 nodes + assert_eq!(nodes.len(), 3); + assert_eq!(batches.len(), 2); + + // Verify root is inner node + assert!(matches!(nodes[0], BKDNode::Inner(_))); + + // Verify we have 2 leaf nodes + let leaf_count = nodes.iter().filter(|n| n.is_leaf()).count(); + assert_eq!(leaf_count, 2); + + // Verify each batch has correct size + assert_eq!(batches[0].num_rows(), 2); + assert_eq!(batches[1].num_rows(), 2); + } +} diff --git a/rust/lance-index/src/scalar/geo/bkdtree.rs b/rust/lance-index/src/scalar/geo/bkdtree.rs new file mode 100644 index 00000000000..6396a5128b5 --- /dev/null +++ b/rust/lance-index/src/scalar/geo/bkdtree.rs @@ -0,0 +1,2463 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright The Lance Authors + +//! BKD Tree Index +//! +//! BKD (Bounding K-Dimensional) trees are spatial database structures for efficient spatial queries. +//! They enable efficient filtering by location-based predicates. +//! +//! ## Requirements +//! +//! BKD tree indices can only be created on fields with GeoArrow metadata. The field must: +//! - Be a Struct data type +//! - Have `ARROW:extension:name` metadata starting with `geoarrow.` (e.g., `geoarrow.point`, `geoarrow.polygon`) +//! +//! 
## Query Support +//! +//! BKD tree indices are "inexact" filters - they can definitively exclude regions but may include +//! false positives that require rechecking. +//! + +use super::bkd::{self, point_in_bbox, BKDLeafNode, BKDNode, BKDTreeBuilder, BKDTreeLookup}; +use crate::pbold; +use crate::scalar::expression::{GeoQueryParser, ScalarQueryParser}; +use crate::scalar::registry::{ + ScalarIndexPlugin, TrainingCriteria, TrainingOrdering, TrainingRequest, +}; +use crate::scalar::{BuiltinIndexType, CreatedIndex, GeoQuery, ScalarIndexParams, UpdateCriteria}; +use crate::Any; +use futures::TryStreamExt; +use lance_core::cache::{CacheKey, LanceCache, WeakLanceCache}; +use lance_core::{ROW_ADDR, ROW_ID}; +use serde::{Deserialize, Serialize}; + +use arrow_array::cast::AsArray; +use arrow_array::{ArrayRef, Float64Array, RecordBatch, UInt32Array, UInt8Array}; +use arrow_schema::{DataType, Field, Schema}; +use datafusion::execution::SendableRecordBatchStream; +use std::{collections::HashMap, sync::Arc}; + +use crate::scalar::FragReuseIndex; +use crate::scalar::{AnyQuery, IndexStore, MetricsCollector, ScalarIndex, SearchResult}; +use crate::vector::VectorIndex; +use crate::{Index, IndexType}; +use async_trait::async_trait; +use deepsize::DeepSizeOf; +use lance_core::Result; +use lance_core::{utils::mask::RowIdTreeMap, Error}; +use roaring::RoaringBitmap; +use snafu::location; + +const BKD_TREE_INNER_FILENAME: &str = "bkd_tree_inner.lance"; +const BKD_TREE_LEAF_FILENAME: &str = "bkd_tree_leaf.lance"; +const LEAF_GROUP_PREFIX: &str = "leaf_group_"; +const DEFAULT_BATCHES_PER_LEAF_FILE: u32 = 5; // Default number of leaf batches per file +const BKD_TREE_VERSION: u32 = 0; +const MAX_POINTS_PER_LEAF_META_KEY: &str = "max_points_per_leaf"; +const BATCHES_PER_FILE_META_KEY: &str = "batches_per_file"; +const DEFAULT_MAX_POINTS_PER_LEAF: u32 = 100; // for test + // const DEFAULT_MAX_POINTS_PER_LEAF: u32 = 1024; // for production + +/// Get the file name for a leaf group +fn 
leaf_group_filename(group_id: u32) -> String { + format!("{}{}.lance", LEAF_GROUP_PREFIX, group_id) +} + +/// Cache key for BKD leaf nodes +#[derive(Debug, Clone)] +struct BKDLeafKey { + leaf_id: u32, +} + +impl CacheKey for BKDLeafKey { + type ValueType = CachedLeafData; + + fn key(&self) -> std::borrow::Cow<'_, str> { + format!("bkd-leaf-{}", self.leaf_id).into() + } +} + +/// Cached leaf data +#[derive(Debug, Clone)] +struct CachedLeafData(RecordBatch); + +impl DeepSizeOf for CachedLeafData { + fn deep_size_of_children(&self, _context: &mut deepsize::Context) -> usize { + // Approximate size of RecordBatch + self.0.get_array_memory_size() + } +} + +impl CachedLeafData { + fn new(batch: RecordBatch) -> Self { + Self(batch) + } + + fn into_inner(self) -> RecordBatch { + self.0 + } +} + +/// BKD tree index for spatial queries +pub struct BkdTree { + data_type: DataType, + store: Arc, + fri: Option>, + index_cache: WeakLanceCache, + bkd_tree: Arc, + max_points_per_leaf: u32, +} + +impl std::fmt::Debug for BkdTree { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("BkdTree") + .field("data_type", &self.data_type) + .field("store", &self.store) + .field("fri", &self.fri) + .field("index_cache", &self.index_cache) + .field("bkd_tree", &self.bkd_tree) + .field("max_points_per_leaf", &self.max_points_per_leaf) + .finish() + } +} + +impl DeepSizeOf for BkdTree { + fn deep_size_of_children(&self, context: &mut deepsize::Context) -> usize { + self.bkd_tree.deep_size_of_children(context) + self.store.deep_size_of_children(context) + } +} + +impl BkdTree { + /// Load the geo index from storage + pub async fn load( + store: Arc, + fri: Option>, + index_cache: &LanceCache, + ) -> Result> + where + Self: Sized, + { + // Load inner nodes + let inner_file = store.open_index_file(BKD_TREE_INNER_FILENAME).await?; + let inner_data = inner_file + .read_range(0..inner_file.num_rows(), None) + .await?; + + // Load leaf metadata + let leaf_file = 
store.open_index_file(BKD_TREE_LEAF_FILENAME).await?; + let leaf_data = leaf_file.read_range(0..leaf_file.num_rows(), None).await?; + + // Deserialize tree structure from both files + let bkd_tree = BKDTreeLookup::from_record_batches(inner_data, leaf_data)?; + + // Extract metadata from inner file + let schema = inner_file.schema(); + let max_points_per_leaf = schema + .metadata + .get(MAX_POINTS_PER_LEAF_META_KEY) + .and_then(|s| s.parse().ok()) + .unwrap_or(DEFAULT_MAX_POINTS_PER_LEAF); + + // Get data type from schema + let data_type = schema.fields[0].data_type().clone(); + + Ok(Arc::new(Self { + data_type, + store, + fri, + index_cache: WeakLanceCache::from(index_cache), + bkd_tree: Arc::new(bkd_tree), + max_points_per_leaf, + })) + } + + /// Load a specific leaf from storage + async fn load_leaf( + &self, + leaf: &BKDLeafNode, + metrics: &dyn MetricsCollector, + ) -> Result { + let file_id = leaf.file_id; + let row_offset = leaf.row_offset; + let num_rows = leaf.num_rows; + + // Use (file_id, row_offset) as cache key + // Combine file_id and row_offset into a single u32 (file_id should be small) + let cache_key = BKDLeafKey { + leaf_id: file_id * 100_000 + (row_offset as u32), + }; + let store = self.store.clone(); + + let cached = self + .index_cache + .get_or_insert_with_key(cache_key, move || async move { + metrics.record_part_load(); + + let filename = leaf_group_filename(file_id); + + // Open the leaf group file and read the specific row range + let reader = store.open_index_file(&filename).await?; + let batch = reader + .read_range(row_offset as usize..(row_offset + num_rows) as usize, None) + .await?; + + Ok(CachedLeafData::new(batch)) + }) + .await?; + + Ok(cached.as_ref().clone().into_inner()) + } + + /// Get the number of leaves in the BKD tree (useful for benchmarking) + pub fn num_leaves(&self) -> usize { + self.bkd_tree.num_leaves as usize + } + + /// Search all leaves without using BKD tree pruning (only useful for benchmarking) + pub async fn 
search_all_leaves( + &self, + query_bbox: [f64; 4], + metrics: &dyn MetricsCollector, + ) -> Result { + let mut all_row_ids = RowIdTreeMap::new(); + + // Iterate through all nodes and search every leaf + for node in self.bkd_tree.nodes.iter() { + if let BKDNode::Leaf(leaf) = node { + let leaf_row_ids = self.search_leaf(leaf, query_bbox, metrics).await?; + let row_ids: Option> = leaf_row_ids + .row_ids() + .map(|iter| iter.map(|row_addr| u64::from(row_addr)).collect()); + if let Some(row_ids) = row_ids { + all_row_ids.extend(row_ids); + } + } + } + + Ok(all_row_ids) + } + + /// Search a specific leaf for points within the query bbox + async fn search_leaf( + &self, + leaf: &BKDLeafNode, + query_bbox: [f64; 4], + metrics: &dyn MetricsCollector, + ) -> Result { + let leaf_data = self.load_leaf(leaf, metrics).await?; + + // Filter points within this leaf using iterators + let x_array = leaf_data + .column(0) + .as_primitive::(); + let y_array = leaf_data + .column(1) + .as_primitive::(); + let row_id_array = leaf_data + .column(2) + .as_primitive::(); + + let row_ids: Vec = x_array + .iter() + .zip(y_array.iter()) + .zip(row_id_array.iter()) + .filter_map( + |((x_opt, y_opt), row_id_opt)| match (x_opt, y_opt, row_id_opt) { + (Some(x), Some(y), Some(row_id)) if point_in_bbox(x, y, &query_bbox) => { + Some(row_id) + } + _ => None, + }, + ) + .collect(); + + let mut row_id_map = RowIdTreeMap::new(); + row_id_map.extend(row_ids); + Ok(row_id_map) + } +} + +#[async_trait] +impl Index for BkdTree { + fn as_any(&self) -> &dyn Any { + self + } + + fn as_index(self: Arc) -> Arc { + self + } + + fn as_vector_index(self: Arc) -> Result> { + Err(Error::InvalidInput { + source: "BkdTree is not a vector index".into(), + location: location!(), + }) + } + + async fn prewarm(&self) -> Result<()> { + // No-op: geo index uses lazy loading + Ok(()) + } + + fn statistics(&self) -> Result { + Ok(serde_json::json!({ + "type": "bkdtree", + })) + } + + fn index_type(&self) -> IndexType { + 
IndexType::BkdTree
    }

    async fn calculate_included_frags(&self) -> Result<RoaringBitmap> {
        // Fragment tracking is not implemented for this index yet.
        Ok(RoaringBitmap::new())
    }
}

#[async_trait]
impl ScalarIndex for BkdTree {
    async fn search(
        &self,
        query: &dyn AnyQuery,
        metrics: &dyn MetricsCollector,
    ) -> Result<SearchResult> {
        let geo_query = query
            .as_any()
            .downcast_ref::<GeoQuery>()
            .ok_or_else(|| Error::InvalidInput {
                source: "Geo index only supports GeoQuery".into(),
                location: location!(),
            })?;

        match geo_query {
            GeoQuery::Intersects(min_x, min_y, max_x, max_y) => {
                let query_bbox = [*min_x, *min_y, *max_x, *max_y];

                // Step 1: prune with the in-memory tree traversal.
                let candidate_leaves = self.bkd_tree.find_intersecting_leaves(query_bbox)?;

                // Step 2: lazily load each candidate leaf and keep only the
                // points actually inside the query box.
                let mut matched = RowIdTreeMap::new();
                for leaf in &candidate_leaves {
                    let leaf_hits = self.search_leaf(leaf, query_bbox, metrics).await?;
                    let collected: Option<Vec<u64>> = leaf_hits
                        .row_ids()
                        .map(|iter| iter.map(u64::from).collect());
                    if let Some(ids) = collected {
                        matched.extend(ids);
                    }
                }

                // Exact, not AtMost: points were already re-checked against
                // the bbox inside search_leaf.
                Ok(SearchResult::Exact(matched))
            }
        }
    }

    fn can_remap(&self) -> bool {
        false
    }

    /// Remap the row ids, creating a new remapped version of this index in `dest_store`
    async fn remap(
        &self,
        _mapping: &HashMap<u64, Option<u64>>,
        _dest_store: &dyn IndexStore,
    ) -> Result<CreatedIndex> {
        Err(Error::InvalidInput {
            source: "BkdTree does not support remap".into(),
            location: location!(),
        })
    }

    /// Add new data, creating an updated version of the index in `dest_store`
    async fn update(
        &self,
        _new_data: SendableRecordBatchStream,
        _dest_store: &dyn IndexStore,
    ) -> Result<CreatedIndex> {
        Err(Error::InvalidInput {
            source: "BkdTree does not support update".into(),
            location: location!(),
        })
    }

    fn
update_criteria(&self) -> UpdateCriteria { + UpdateCriteria::only_new_data( + TrainingCriteria::new(TrainingOrdering::Addresses).with_row_addr(), + ) + } + + fn derive_index_params(&self) -> Result { + let params = serde_json::to_value(BkdTreeBuilderParams::default())?; + Ok(ScalarIndexParams::for_builtin(BuiltinIndexType::BkdTree).with_params(¶ms)) + } +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct BkdTreeBuilderParams { + #[serde(default = "default_max_points_per_leaf")] + pub max_points_per_leaf: u32, + #[serde(default = "default_batches_per_file")] + pub batches_per_file: u32, +} + +fn default_max_points_per_leaf() -> u32 { + DEFAULT_MAX_POINTS_PER_LEAF +} + +fn default_batches_per_file() -> u32 { + DEFAULT_BATCHES_PER_LEAF_FILE +} + +impl Default for BkdTreeBuilderParams { + fn default() -> Self { + Self { + max_points_per_leaf: default_max_points_per_leaf(), + batches_per_file: default_batches_per_file(), + } + } +} + +impl BkdTreeBuilderParams { + pub fn new() -> Self { + Self::default() + } + + pub fn with_max_points_per_leaf(mut self, max_points_per_leaf: u32) -> Self { + self.max_points_per_leaf = max_points_per_leaf; + self + } +} + +// A builder for geo index +pub struct BkdTreeBuilder { + options: BkdTreeBuilderParams, + _items_type: DataType, + // Accumulated points: (x, y, row_id) + pub points: Vec<(f64, f64, u64)>, +} + +impl BkdTreeBuilder { + pub fn try_new(options: BkdTreeBuilderParams, items_type: DataType) -> Result { + Ok(Self { + options, + _items_type: items_type, + points: Vec::new(), + }) + } + + pub async fn train(&mut self, batches_source: SendableRecordBatchStream) -> Result<()> { + assert!(batches_source.schema().field_with_name(ROW_ADDR).is_ok()); + + let mut batches_source = batches_source; + + while let Some(batch) = batches_source.try_next().await? 
{ + // Extract GeoArrow point coordinates + let geom_array = batch + .column(0) + .as_any() + .downcast_ref::() + .ok_or_else(|| Error::InvalidInput { + source: "Expected Struct array for GeoArrow data".into(), + location: location!(), + })?; + + let x_array = geom_array + .column(0) + .as_primitive::(); + let y_array = geom_array + .column(1) + .as_primitive::(); + let row_ids = batch + .column_by_name(ROW_ADDR) + .unwrap() + .as_primitive::(); + + for i in 0..batch.num_rows() { + self.points + .push((x_array.value(i), y_array.value(i), row_ids.value(i))); + } + } + + log::debug!("Accumulated {} points for BKD tree", self.points.len()); + + Ok(()) + } + + pub async fn write_index(mut self, index_store: &dyn IndexStore) -> Result<()> { + if self.points.is_empty() { + return Ok(()); + } + + // Validate coordinates (reject NaN and Infinity like Lucene does) + for (x, y, row_id) in &self.points { + if x.is_nan() || y.is_nan() { + return Err(Error::InvalidInput { + source: format!("Cannot index NaN coordinates (row_id={})", row_id).into(), + location: location!(), + }); + } + if x.is_infinite() || y.is_infinite() { + return Err(Error::InvalidInput { + source: format!("Cannot index Infinite coordinates (row_id={})", row_id).into(), + location: location!(), + }); + } + } + + // Build BKD tree + let (tree_nodes, leaf_batches) = self.build_bkd_tree()?; + + // Write tree structure to separate inner and leaf files + let (inner_batch, leaf_metadata_batch) = self.serialize_tree_nodes(&tree_nodes)?; + + // Write inner nodes + let mut inner_file = index_store + .new_index_file(BKD_TREE_INNER_FILENAME, inner_batch.schema()) + .await?; + inner_file.write_record_batch(inner_batch).await?; + inner_file + .finish_with_metadata(HashMap::from([ + ( + MAX_POINTS_PER_LEAF_META_KEY.to_string(), + self.options.max_points_per_leaf.to_string(), + ), + ( + BATCHES_PER_FILE_META_KEY.to_string(), + self.options.batches_per_file.to_string(), + ), + ])) + .await?; + + // Write leaf metadata + let 
mut leaf_meta_file = index_store + .new_index_file(BKD_TREE_LEAF_FILENAME, leaf_metadata_batch.schema()) + .await?; + leaf_meta_file + .write_record_batch(leaf_metadata_batch) + .await?; + leaf_meta_file.finish().await?; + + // Write actual leaf data grouped into files (multiple batches per file) + let leaf_schema = Arc::new(Schema::new(vec![ + Field::new("x", DataType::Float64, false), + Field::new("y", DataType::Float64, false), + Field::new(ROW_ID, DataType::UInt64, false), + ])); + + let batches_per_file = self.options.batches_per_file; + let num_groups = (leaf_batches.len() as u32 + batches_per_file - 1) / batches_per_file; + + for group_id in 0..num_groups { + let start_idx = (group_id * batches_per_file) as usize; + let end_idx = + ((group_id + 1) * batches_per_file).min(leaf_batches.len() as u32) as usize; + let group_batches = &leaf_batches[start_idx..end_idx]; + + let filename = leaf_group_filename(group_id); + + let mut leaf_file = index_store + .new_index_file(&filename, leaf_schema.clone()) + .await?; + + for leaf_batch in group_batches.iter() { + leaf_file.write_record_batch(leaf_batch.clone()).await?; + } + + leaf_file.finish().await?; + } + + log::debug!("Wrote BKD tree with {} nodes", tree_nodes.len()); + + Ok(()) + } + + // Serialize tree nodes to separate inner and leaf RecordBatches + fn serialize_tree_nodes(&self, nodes: &[BKDNode]) -> Result<(RecordBatch, RecordBatch)> { + // Separate inner and leaf nodes with their indices + let mut inner_nodes = Vec::new(); + let mut leaf_nodes = Vec::new(); + + for (idx, node) in nodes.iter().enumerate() { + match node { + BKDNode::Inner(_) => inner_nodes.push((idx as u32, node)), + BKDNode::Leaf(_) => leaf_nodes.push((idx as u32, node)), + } + } + + // Serialize inner nodes + let inner_batch = Self::serialize_inner_nodes(&inner_nodes)?; + + // Serialize leaf nodes + let leaf_batch = Self::serialize_leaf_nodes(&leaf_nodes)?; + + Ok((inner_batch, leaf_batch)) + } + + fn serialize_inner_nodes(nodes: &[(u32, 
&BKDNode)]) -> Result { + let mut node_id_vals = Vec::with_capacity(nodes.len()); + let mut min_x_vals = Vec::with_capacity(nodes.len()); + let mut min_y_vals = Vec::with_capacity(nodes.len()); + let mut max_x_vals = Vec::with_capacity(nodes.len()); + let mut max_y_vals = Vec::with_capacity(nodes.len()); + let mut split_dim_vals = Vec::with_capacity(nodes.len()); + let mut split_value_vals = Vec::with_capacity(nodes.len()); + let mut left_child_vals = Vec::with_capacity(nodes.len()); + let mut right_child_vals = Vec::with_capacity(nodes.len()); + + for (idx, node) in nodes { + if let BKDNode::Inner(inner) = node { + node_id_vals.push(*idx); + min_x_vals.push(inner.bounds[0]); + min_y_vals.push(inner.bounds[1]); + max_x_vals.push(inner.bounds[2]); + max_y_vals.push(inner.bounds[3]); + split_dim_vals.push(inner.split_dim); + split_value_vals.push(inner.split_value); + left_child_vals.push(inner.left_child); + right_child_vals.push(inner.right_child); + } + } + + let schema = bkd::inner_node_schema(); + + let columns: Vec = vec![ + Arc::new(UInt32Array::from(node_id_vals)), + Arc::new(Float64Array::from(min_x_vals)), + Arc::new(Float64Array::from(min_y_vals)), + Arc::new(Float64Array::from(max_x_vals)), + Arc::new(Float64Array::from(max_y_vals)), + Arc::new(UInt8Array::from(split_dim_vals)), + Arc::new(Float64Array::from(split_value_vals)), + Arc::new(UInt32Array::from(left_child_vals)), + Arc::new(UInt32Array::from(right_child_vals)), + ]; + + Ok(RecordBatch::try_new(schema, columns)?) 
+ } + + fn serialize_leaf_nodes(nodes: &[(u32, &BKDNode)]) -> Result { + use arrow_array::UInt64Array; + + let mut node_id_vals = Vec::with_capacity(nodes.len()); + let mut min_x_vals = Vec::with_capacity(nodes.len()); + let mut min_y_vals = Vec::with_capacity(nodes.len()); + let mut max_x_vals = Vec::with_capacity(nodes.len()); + let mut max_y_vals = Vec::with_capacity(nodes.len()); + let mut file_id_vals = Vec::with_capacity(nodes.len()); + let mut row_offset_vals = Vec::with_capacity(nodes.len()); + let mut num_rows_vals = Vec::with_capacity(nodes.len()); + + for (idx, node) in nodes { + if let BKDNode::Leaf(leaf) = node { + node_id_vals.push(*idx); + min_x_vals.push(leaf.bounds[0]); + min_y_vals.push(leaf.bounds[1]); + max_x_vals.push(leaf.bounds[2]); + max_y_vals.push(leaf.bounds[3]); + file_id_vals.push(leaf.file_id); + row_offset_vals.push(leaf.row_offset); + num_rows_vals.push(leaf.num_rows); + } + } + + let schema = bkd::leaf_node_schema(); + + let columns: Vec = vec![ + Arc::new(UInt32Array::from(node_id_vals)), + Arc::new(Float64Array::from(min_x_vals)), + Arc::new(Float64Array::from(min_y_vals)), + Arc::new(Float64Array::from(max_x_vals)), + Arc::new(Float64Array::from(max_y_vals)), + Arc::new(UInt32Array::from(file_id_vals)), + Arc::new(UInt64Array::from(row_offset_vals)), + Arc::new(UInt64Array::from(num_rows_vals)), + ]; + + Ok(RecordBatch::try_new(schema, columns)?) 
+ } + + // Build BKD tree using the BKDTreeBuilder + fn build_bkd_tree(&mut self) -> Result<(Vec, Vec)> { + let builder = BKDTreeBuilder::new(self.options.max_points_per_leaf as usize); + builder.build(&mut self.points, self.options.batches_per_file) + } +} + +#[derive(Debug, Default)] +pub struct BkdTreePlugin; + +impl BkdTreePlugin { + async fn train_geo_index( + batches_source: SendableRecordBatchStream, + index_store: &dyn IndexStore, + options: Option, + ) -> Result<()> { + let value_type = batches_source.schema().field(0).data_type().clone(); + + let mut builder = BkdTreeBuilder::try_new(options.unwrap_or_default(), value_type)?; + + builder.train(batches_source).await?; + + builder.write_index(index_store).await?; + Ok(()) + } +} + +pub struct BkdTreeTrainingRequest { + pub params: BkdTreeBuilderParams, + pub criteria: TrainingCriteria, +} + +impl BkdTreeTrainingRequest { + pub fn new(params: BkdTreeBuilderParams) -> Self { + Self { + params, + criteria: TrainingCriteria::new(TrainingOrdering::Addresses).with_row_addr(), + } + } +} + +impl TrainingRequest for BkdTreeTrainingRequest { + fn as_any(&self) -> &dyn std::any::Any { + self + } + fn criteria(&self) -> &TrainingCriteria { + &self.criteria + } +} + +#[async_trait] +impl ScalarIndexPlugin for BkdTreePlugin { + fn new_training_request( + &self, + params: &str, + field: &Field, + ) -> Result> { + // Check that the field is a Struct type + if !matches!(field.data_type(), DataType::Struct(_)) { + return Err(Error::InvalidInput { + source: "A geo index can only be created on a Struct field.".into(), + location: location!(), + }); + } + + // Check for GeoArrow metadata + let is_geoarrow = field + .metadata() + .get("ARROW:extension:name") + .map(|name| name.starts_with("geoarrow.")) + .unwrap_or(false); + + if !is_geoarrow { + return Err(Error::InvalidInput { + source: format!( + "Geo index requires GeoArrow metadata on field '{}'. 
\ + The field must have 'ARROW:extension:name' metadata starting with 'geoarrow.'", + field.name() + ) + .into(), + location: location!(), + }); + } + + let params = serde_json::from_str::(params)?; + + Ok(Box::new(BkdTreeTrainingRequest::new(params))) + } + + fn provides_exact_answer(&self) -> bool { + true // We do exact point-in-bbox filtering in search_leaf + } + + fn version(&self) -> u32 { + BKD_TREE_VERSION + } + + fn new_query_parser( + &self, + index_name: String, + _index_details: &prost_types::Any, + ) -> Option> { + Some(Box::new(GeoQueryParser::new(index_name))) + } + + async fn train_index( + &self, + data: SendableRecordBatchStream, + index_store: &dyn IndexStore, + request: Box, + fragment_ids: Option>, + ) -> Result { + if fragment_ids.is_some() { + return Err(Error::InvalidInput { + source: "Geo index does not support fragment training".into(), + location: location!(), + }); + } + + let request = (request as Box) + .downcast::() + .map_err(|_| Error::InvalidInput { + source: "must provide training request created by new_training_request".into(), + location: location!(), + })?; + Self::train_geo_index(data, index_store, Some(request.params)).await?; + Ok(CreatedIndex { + index_details: prost_types::Any::from_msg(&crate::pb::BkdTreeIndexDetails::default()) + .unwrap(), + index_version: BKD_TREE_VERSION, + }) + } + + async fn load_index( + &self, + index_store: Arc, + _index_details: &prost_types::Any, + frag_reuse_index: Option>, + cache: &LanceCache, + ) -> Result> { + Ok(BkdTree::load(index_store, frag_reuse_index, cache).await? 
as Arc) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use std::sync::Arc; + + use arrow_schema::{DataType, Field}; + use lance_core::cache::LanceCache; + use lance_core::utils::tempfile::TempObjDir; + use lance_io::object_store::ObjectStore; + + use crate::scalar::lance_format::LanceIndexStore; + + fn create_test_store() -> Arc { + let tmpdir = TempObjDir::default(); + + let test_store = Arc::new(LanceIndexStore::new( + Arc::new(ObjectStore::local()), + tmpdir.clone(), + Arc::new(LanceCache::no_cache()), + )); + return test_store; + } + + /// Validates the BKD tree structure recursively + async fn validate_bkd_tree(index: &BkdTree) -> Result<()> { + use crate::metrics::NoOpMetricsCollector; + + let tree = &index.bkd_tree; + + // Verify root exists + if tree.nodes.is_empty() { + return Err(Error::InvalidInput { + source: "BKD tree has no nodes".into(), + location: location!(), + }); + } + + let root_id = tree.root_id as usize; + if root_id >= tree.nodes.len() { + return Err(Error::InvalidInput { + source: format!( + "Root node id {} out of bounds (tree has {} nodes)", + root_id, + tree.nodes.len() + ) + .into(), + location: location!(), + }); + } + + // Count leaves and validate structure + let mut visited = vec![false; tree.nodes.len()]; + let mut leaf_count = 0u32; + + // Start with None for expected_split_dim - root can use any dimension + // (typically starts with dimension 0, but we don't enforce it) + let metrics = NoOpMetricsCollector {}; + validate_node_recursive( + index, + tree, + root_id as u32, + &mut visited, + &mut leaf_count, + None, + None, + &metrics, + ) + .await?; + + // Verify all leaves were counted + if leaf_count != tree.num_leaves { + return Err(Error::InvalidInput { + source: format!( + "Leaf count mismatch: found {} leaves but tree claims {}", + leaf_count, tree.num_leaves + ) + .into(), + location: location!(), + }); + } + + // Check for orphaned nodes (nodes not reachable from root) + for (idx, was_visited) in 
visited.iter().enumerate() { + if !was_visited { + return Err(Error::InvalidInput { + source: format!("Node {} is orphaned (not reachable from root)", idx).into(), + location: location!(), + }); + } + } + + Ok(()) + } + + /// Helper function to recursively validate a node and its descendants + async fn validate_node_recursive( + index: &BkdTree, + tree: &BKDTreeLookup, + node_id: u32, + visited: &mut Vec, + leaf_count: &mut u32, + parent_bounds: Option<[f64; 4]>, + _parent_split: Option<(u8, f64, bool)>, // (split_dim, split_value, is_left_child) from parent + metrics: &dyn MetricsCollector, + ) -> Result<()> { + let node_idx = node_id as usize; + + // Check node exists + if node_idx >= tree.nodes.len() { + return Err(Error::InvalidInput { + source: format!( + "Node id {} out of bounds (tree has {} nodes)", + node_id, + tree.nodes.len() + ) + .into(), + location: location!(), + }); + } + + // Mark as visited + if visited[node_idx] { + return Err(Error::InvalidInput { + source: format!("Node {} visited multiple times (cycle detected)", node_id).into(), + location: location!(), + }); + } + visited[node_idx] = true; + + let node = &tree.nodes[node_idx]; + let bounds = node.bounds(); + + // Validate bounds are well-formed + if bounds[0] > bounds[2] || bounds[1] > bounds[3] { + return Err(Error::InvalidInput { + source: format!( + "Node {} has invalid bounds: [{}, {}, {}, {}]", + node_id, bounds[0], bounds[1], bounds[2], bounds[3] + ) + .into(), + location: location!(), + }); + } + + // Verify child bounds are within parent bounds + if let Some(parent_bounds) = parent_bounds { + if bounds[0] < parent_bounds[0] + || bounds[1] < parent_bounds[1] + || bounds[2] > parent_bounds[2] + || bounds[3] > parent_bounds[3] + { + return Err(Error::InvalidInput { + source: format!( + "Node {} bounds [{}, {}, {}, {}] exceed parent bounds [{}, {}, {}, {}]", + node_id, + bounds[0], + bounds[1], + bounds[2], + bounds[3], + parent_bounds[0], + parent_bounds[1], + parent_bounds[2], + 
parent_bounds[3] + ) + .into(), + location: location!(), + }); + } + } + + match node { + BKDNode::Inner(inner) => { + // Validate split dimension + if inner.split_dim > 1 { + return Err(Error::InvalidInput { + source: format!( + "Node {} has invalid split_dim: {} (must be 0 or 1)", + node_id, inner.split_dim + ) + .into(), + location: location!(), + }); + } + + // Validate split value is within bounds + let min_val = if inner.split_dim == 0 { + bounds[0] + } else { + bounds[1] + }; + let max_val = if inner.split_dim == 0 { + bounds[2] + } else { + bounds[3] + }; + if inner.split_value < min_val || inner.split_value > max_val { + return Err(Error::InvalidInput { + source: format!( + "Node {} split_value {} is outside dimension bounds [{}, {}]", + node_id, inner.split_value, min_val, max_val + ) + .into(), + location: location!(), + }); + } + + // Validate that children are properly sorted by split dimension + // If points are sorted, left_max <= split_value <= right_min + // Which means left_max <= right_min (no overlap) + let left_node = &tree.nodes[inner.left_child as usize]; + let right_node = &tree.nodes[inner.right_child as usize]; + + let left_bounds = left_node.bounds(); + let right_bounds = right_node.bounds(); + + let (left_max, right_min) = if inner.split_dim == 0 { + (left_bounds[2], right_bounds[0]) // max_x of left, min_x of right + } else { + (left_bounds[3], right_bounds[1]) // max_y of left, min_y of right + }; + + // Simple check: left_max should not be greater than right_min + // If it is, points were not properly sorted before splitting! + if left_max > right_min { + return Err(Error::InvalidInput { + source: format!( + "Node {} (split_dim={}, split_value={}): left child max {}={} > right child min {}={}. 
\ + Points were not properly sorted before splitting!", + node_id, inner.split_dim, inner.split_value, + if inner.split_dim == 0 { "x" } else { "y" }, + left_max, + if inner.split_dim == 0 { "x" } else { "y" }, + right_min + ).into(), + location: location!(), + }); + } + + // Recursively validate children + Box::pin(validate_node_recursive( + index, + tree, + inner.left_child, + visited, + leaf_count, + Some(bounds), + Some((inner.split_dim, inner.split_value, true)), + metrics, + )) + .await?; + Box::pin(validate_node_recursive( + index, + tree, + inner.right_child, + visited, + leaf_count, + Some(bounds), + Some((inner.split_dim, inner.split_value, false)), + metrics, + )) + .await?; + } + bkd::BKDNode::Leaf(leaf) => { + // Validate leaf data + if leaf.num_rows == 0 { + return Err(Error::InvalidInput { + source: format!("Leaf node {} has zero rows", node_id).into(), + location: location!(), + }); + } + + // No extra validation needed for leaves - will be checked from parent + + *leaf_count += 1; + } + } + + Ok(()) + } + + #[tokio::test] + async fn test_geo_intersects_with_custom_max_points_per_leaf() { + use rand::rngs::StdRng; + use rand::{Rng, SeedableRng}; + + // Test with different max points per leaf + for max_points_per_leaf in [10, 50, 100, 200] { + let test_store = create_test_store(); + + let params = BkdTreeBuilderParams { + max_points_per_leaf, + batches_per_file: 5, + }; + + let mut builder = + BkdTreeBuilder::try_new(params, DataType::Struct(Vec::::new().into())) + .unwrap(); + + // Create 500 RANDOM points (not sequential!) 
to catch sorting bugs + let mut rng = StdRng::seed_from_u64(42); + for i in 0..500 { + let x = rng.random_range(0.0..100.0); + let y = rng.random_range(0.0..100.0); + builder.points.push((x, y, i as u64)); + } + + // Write index + builder.write_index(test_store.as_ref()).await.unwrap(); + + // Load index and verify + let index = BkdTree::load(test_store.clone(), None, &LanceCache::no_cache()) + .await + .expect("Failed to load BkdTree"); + + assert_eq!(index.max_points_per_leaf, max_points_per_leaf); + + // Validate tree structure + validate_bkd_tree(&index).await.expect(&format!( + "BKD tree validation failed for max_points_per_leaf={}", + max_points_per_leaf + )); + } + } + + #[tokio::test] + async fn test_geo_index_with_custom_batches_per_file() { + // Test with different batches_per_file configurations + for batches_per_file in [1, 3, 5, 10, 20] { + let test_store = create_test_store(); + + let params = BkdTreeBuilderParams { + max_points_per_leaf: 50, + batches_per_file, + }; + + let mut builder = + BkdTreeBuilder::try_new(params, DataType::Struct(Vec::::new().into())) + .unwrap(); + + // Create 500 points (BKD spatial partitioning determines actual leaf count) + for i in 0..500 { + let x = (i % 100) as f64; + let y = (i / 100) as f64; + builder.points.push((x, y, i as u64)); + } + + // Write index + builder.write_index(test_store.as_ref()).await.unwrap(); + + // Load index and verify + let index = BkdTree::load(test_store.clone(), None, &LanceCache::no_cache()) + .await + .expect("Failed to load BkdTree"); + + // Validate tree structure + validate_bkd_tree(&index).await.expect(&format!( + "BKD tree validation failed for batches_per_file={}", + batches_per_file + )); + } + } + + #[tokio::test] + async fn test_geo_intersects_query_correctness_various_configs() { + use crate::metrics::NoOpMetricsCollector; + + // Test query correctness with different configurations + let configs = vec![ + (10, 1), // Small leaves, one per file + (50, 5), // Default-ish + (100, 
10), // Larger leaves, many per file + ]; + + for (max_points_per_leaf, batches_per_file) in configs { + let test_store = create_test_store(); + + let params = BkdTreeBuilderParams { + max_points_per_leaf, + batches_per_file, + }; + + let mut builder = + BkdTreeBuilder::try_new(params, DataType::Struct(Vec::::new().into())) + .unwrap(); + + // Create a grid of points + for x in 0..20 { + for y in 0..20 { + let row_id = (x * 20 + y) as u64; + builder.points.push((x as f64, y as f64, row_id)); + } + } + + // Write and load index + builder.write_index(test_store.as_ref()).await.unwrap(); + let index = BkdTree::load(test_store.clone(), None, &LanceCache::no_cache()) + .await + .unwrap(); + + // Validate tree structure + validate_bkd_tree(&index).await.expect(&format!( + "BKD tree validation failed for max_points_per_leaf={}, batches_per_file={}", + max_points_per_leaf, batches_per_file + )); + + // Query: bbox [5, 5, 10, 10] should return points in that region + let query = GeoQuery::Intersects(5.0, 5.0, 10.0, 10.0); + + let metrics = NoOpMetricsCollector {}; + let result = index.search(&query, &metrics).await.unwrap(); + + // Should find points (5,5) to (10,10) = 6x6 = 36 points + match result { + crate::scalar::SearchResult::Exact(row_ids) => { + assert_eq!(row_ids.len().unwrap_or(0), 36, + "Expected 36 points for config max_points_per_leaf={}, batches_per_file={}, got {}", + max_points_per_leaf, batches_per_file, row_ids.len().unwrap_or(0)); + + // Verify correct row IDs + for x in 5..=10 { + for y in 5..=10 { + let expected_row_id = (x * 20 + y) as u64; + assert!( + row_ids.contains(expected_row_id), + "Missing row_id {} for point ({}, {})", + expected_row_id, + x, + y + ); + } + } + } + _ => panic!("Expected Exact search result"), + } + } + } + + #[tokio::test] + async fn test_geo_intersects_single_leaf() { + // Edge case: all points fit in single leaf + let test_store = create_test_store(); + + let params = BkdTreeBuilderParams { + max_points_per_leaf: 1000, + 
batches_per_file: 5, + }; + + let mut builder = + BkdTreeBuilder::try_new(params, DataType::Struct(Vec::::new().into())).unwrap(); + + // Create only 50 points + for i in 0..50 { + builder.points.push((i as f64, i as f64, i as u64)); + } + + builder.write_index(test_store.as_ref()).await.unwrap(); + let index = BkdTree::load(test_store.clone(), None, &LanceCache::no_cache()) + .await + .unwrap(); + + // Should have only 1 leaf + assert_eq!(index.bkd_tree.num_leaves, 1); + assert_eq!(index.bkd_tree.nodes.len(), 1); + + // Validate tree structure + validate_bkd_tree(&index) + .await + .expect("BKD tree validation failed for single leaf test"); + + // Single leaf should be in file 0 at offset 0 + match &index.bkd_tree.nodes[0] { + BKDNode::Leaf(leaf) => { + assert_eq!(leaf.file_id, 0); + assert_eq!(leaf.row_offset, 0); + assert_eq!(leaf.num_rows, 50); + } + _ => panic!("Expected leaf node"), + } + } + + #[tokio::test] + async fn test_geo_intersects_many_small_leaves() { + // Stress test: many small leaves, test file grouping + let test_store = create_test_store(); + + let params = BkdTreeBuilderParams { + max_points_per_leaf: 5, // Very small leaves + batches_per_file: 3, // Few batches per file + }; + + let mut builder = + BkdTreeBuilder::try_new(params, DataType::Struct(Vec::::new().into())).unwrap(); + + // Create 100 points (BKD spatial partitioning determines actual leaf count) + for i in 0..100 { + builder.points.push((i as f64, i as f64, i as u64)); + } + + builder.write_index(test_store.as_ref()).await.unwrap(); + let index = BkdTree::load(test_store.clone(), None, &LanceCache::no_cache()) + .await + .unwrap(); + + // BKD tree creates more leaves due to spatial partitioning + assert_eq!(index.bkd_tree.num_leaves, 32); + + // Validate tree structure + validate_bkd_tree(&index) + .await + .expect("BKD tree validation failed for many small leaves test"); + + // 32 leaves / 3 batches_per_file = 11 files (ceil) + let file_ids: std::collections::HashSet = index + 
.bkd_tree + .nodes + .iter() + .filter_map(|n| n.as_leaf()) + .map(|l| l.file_id) + .collect(); + + assert_eq!(file_ids.len(), 11); + + // Verify row offsets are cumulative within each file + let leaves: Vec<_> = index + .bkd_tree + .nodes + .iter() + .filter_map(|n| n.as_leaf()) + .collect(); + + for file_id in 0..11 { + let leaves_in_file: Vec<_> = leaves.iter().filter(|l| l.file_id == file_id).collect(); + + let mut expected_offset = 0u64; + for leaf in leaves_in_file { + assert_eq!( + leaf.row_offset, expected_offset, + "Incorrect offset in file {}", + file_id + ); + expected_offset += leaf.num_rows; + } + } + } + + #[tokio::test] + async fn test_geo_index_data_integrity_after_serialization() { + // Verify every point written can be read back exactly + use crate::metrics::NoOpMetricsCollector; + + let configs = vec![ + (10, 1), // Small leaves, one per file + (50, 3), // Medium leaves, few per file + (100, 10), // Large leaves, many per file + ]; + + for (max_points_per_leaf, batches_per_file) in configs { + let test_store = create_test_store(); + + let params = BkdTreeBuilderParams { + max_points_per_leaf, + batches_per_file, + }; + + let mut builder = + BkdTreeBuilder::try_new(params, DataType::Struct(Vec::::new().into())) + .unwrap(); + + // Create test points with known values + let mut original_points = Vec::new(); + for i in 0..200 { + let x = (i % 50) as f64 + 0.123; // Non-integer to test precision + let y = (i / 50) as f64 + 0.456; + let row_id = i as u64; + original_points.push((x, y, row_id)); + builder.points.push((x, y, row_id)); + } + + // Write index + builder.write_index(test_store.as_ref()).await.unwrap(); + + // Load index + let index = BkdTree::load(test_store.clone(), None, &LanceCache::no_cache()) + .await + .unwrap(); + + // Read back all points by querying each leaf + let metrics = NoOpMetricsCollector {}; + let mut recovered_points = std::collections::HashMap::new(); + + for leaf in index.bkd_tree.nodes.iter().filter_map(|n| n.as_leaf()) 
{ + // Load the leaf data + let leaf_data = index.load_leaf(leaf, &metrics).await.unwrap(); + + // Extract points from this leaf + let x_array = leaf_data + .column(0) + .as_primitive::(); + let y_array = leaf_data + .column(1) + .as_primitive::(); + let row_id_array = leaf_data + .column(2) + .as_primitive::(); + + for i in 0..leaf_data.num_rows() { + let x = x_array.value(i); + let y = y_array.value(i); + let row_id = row_id_array.value(i); + recovered_points.insert(row_id, (x, y)); + } + } + + // Verify all points were recovered + assert_eq!( + recovered_points.len(), + original_points.len(), + "Lost points with config max_points_per_leaf={}, batches_per_file={}", + max_points_per_leaf, + batches_per_file + ); + + // Verify each point matches exactly + for (original_x, original_y, row_id) in &original_points { + let (recovered_x, recovered_y) = recovered_points + .get(row_id) + .expect(&format!("Missing row_id {} in recovered data", row_id)); + + assert_eq!( + *recovered_x, *original_x, + "X coordinate mismatch for row_id {}: expected {}, got {}", + row_id, original_x, recovered_x + ); + assert_eq!( + *recovered_y, *original_y, + "Y coordinate mismatch for row_id {}: expected {}, got {}", + row_id, original_y, recovered_y + ); + } + } + } + + #[tokio::test] + async fn test_geo_index_no_duplicate_points() { + // Ensure no points are duplicated during write/read + let test_store = create_test_store(); + + let params = BkdTreeBuilderParams { + max_points_per_leaf: 7, // Odd number to test edge cases + batches_per_file: 3, + }; + + let mut builder = + BkdTreeBuilder::try_new(params, DataType::Struct(Vec::::new().into())).unwrap(); + + // Create 100 unique points + for i in 0..100 { + builder.points.push((i as f64, i as f64, i as u64)); + } + + builder.write_index(test_store.as_ref()).await.unwrap(); + let index = BkdTree::load(test_store.clone(), None, &LanceCache::no_cache()) + .await + .unwrap(); + + // Read all row_ids from all leaves + use 
crate::metrics::NoOpMetricsCollector; + let metrics = NoOpMetricsCollector {}; + let mut all_row_ids = Vec::new(); + + for leaf in index.bkd_tree.nodes.iter().filter_map(|n| n.as_leaf()) { + let leaf_data = index.load_leaf(leaf, &metrics).await.unwrap(); + let row_id_array = leaf_data + .column(2) + .as_primitive::(); + + for i in 0..leaf_data.num_rows() { + all_row_ids.push(row_id_array.value(i)); + } + } + + // Should have exactly 100 row_ids + assert_eq!(all_row_ids.len(), 100, "Wrong number of points recovered"); + + // All should be unique + let unique_row_ids: std::collections::HashSet = all_row_ids.iter().copied().collect(); + assert_eq!(unique_row_ids.len(), 100, "Found duplicate points!"); + + // Should be exactly 0..99 + for i in 0..100 { + assert!(unique_row_ids.contains(&(i as u64)), "Missing row_id {}", i); + } + } + + #[tokio::test] + async fn test_geo_intersects_lazy_loading() { + use crate::metrics::NoOpMetricsCollector; + + // Test that leaves are loaded lazily (not all at once) + let test_store = create_test_store(); + + let params = BkdTreeBuilderParams { + max_points_per_leaf: 10, + batches_per_file: 5, + }; + + let mut builder = + BkdTreeBuilder::try_new(params, DataType::Struct(Vec::::new().into())).unwrap(); + + // Create 100 points in a grid + for x in 0..10 { + for y in 0..10 { + builder + .points + .push((x as f64 * 10.0, y as f64 * 10.0, (x * 10 + y) as u64)); + } + } + + builder.write_index(test_store.as_ref()).await.unwrap(); + let index = BkdTree::load(test_store.clone(), None, &LanceCache::no_cache()) + .await + .unwrap(); + + // Query a small region - should only load relevant leaves, not all of them + let query = GeoQuery::Intersects(0.0, 0.0, 20.0, 20.0); + let metrics = NoOpMetricsCollector {}; + + // Find which leaves would be touched + let query_bbox = [0.0, 0.0, 20.0, 20.0]; + let intersecting_leaves = index.bkd_tree.find_intersecting_leaves(query_bbox).unwrap(); + + // Verify we're not loading ALL leaves for a small query + 
assert!( + intersecting_leaves.len() < index.bkd_tree.num_leaves as usize, + "Lazy loading test: small query should not touch all leaves! \ + Touched {}/{} leaves", + intersecting_leaves.len(), + index.bkd_tree.num_leaves + ); + + // Execute the query + let result = index.search(&query, &metrics).await.unwrap(); + + // Manually verify correctness: which points SHOULD be in bbox [0, 0, 20, 20]? + let mut expected_row_ids = std::collections::HashSet::new(); + for x in 0..10 { + for y in 0..10 { + let point_x = x as f64 * 10.0; + let point_y = y as f64 * 10.0; + // st_intersects: point is inside bbox if min_x <= x <= max_x && min_y <= y <= max_y + if point_x >= 0.0 && point_x <= 20.0 && point_y >= 0.0 && point_y <= 20.0 { + expected_row_ids.insert((x * 10 + y) as u64); + } + } + } + + // Verify results match our manual calculation + match result { + crate::scalar::SearchResult::Exact(row_ids) => { + let actual_count = row_ids.len().unwrap_or(0) as usize; + + assert_eq!( + actual_count, + expected_row_ids.len(), + "Expected {} points, got {}", + expected_row_ids.len(), + actual_count + ); + + // Verify each returned row_id is in our expected set + if let Some(iter) = row_ids.row_ids() { + for row_addr in iter { + let row_id = u64::from(row_addr); + assert!( + expected_row_ids.contains(&row_id), + "Unexpected row_id {} in results", + row_id + ); + } + } + + // Verify we didn't miss any expected points + if let Some(iter) = row_ids.row_ids() { + let found_ids: std::collections::HashSet = + iter.map(|addr| u64::from(addr)).collect(); + for expected_id in &expected_row_ids { + assert!( + found_ids.contains(expected_id), + "Missing expected row_id {} in results", + expected_id + ); + } + } + } + _ => panic!("Expected Exact search result"), + } + } + + #[tokio::test] + async fn test_geo_intersects_invalid_coordinates() { + // Test that NaN and infinity coordinates are rejected (like Lucene does) + let test_store = create_test_store(); + + let params = BkdTreeBuilderParams { 
+ max_points_per_leaf: 10, + batches_per_file: 5, + }; + + // Test NaN in X coordinate + let mut builder = + BkdTreeBuilder::try_new(params.clone(), DataType::Struct(Vec::::new().into())) + .unwrap(); + builder.points.push((10.0, 10.0, 1)); + builder.points.push((f64::NAN, 20.0, 2)); + let result = builder.write_index(test_store.as_ref()).await; + assert!(result.is_err()); + assert!(result.unwrap_err().to_string().contains("NaN")); + + // Test NaN in Y coordinate + let mut builder = + BkdTreeBuilder::try_new(params.clone(), DataType::Struct(Vec::::new().into())) + .unwrap(); + builder.points.push((10.0, 10.0, 1)); + builder.points.push((20.0, f64::NAN, 2)); + let result = builder.write_index(test_store.as_ref()).await; + assert!(result.is_err()); + assert!(result.unwrap_err().to_string().contains("NaN")); + + // Test Infinity in X coordinate + let mut builder = + BkdTreeBuilder::try_new(params.clone(), DataType::Struct(Vec::::new().into())) + .unwrap(); + builder.points.push((10.0, 10.0, 1)); + builder.points.push((f64::INFINITY, 20.0, 2)); + let result = builder.write_index(test_store.as_ref()).await; + assert!(result.is_err()); + assert!(result.unwrap_err().to_string().contains("Infinite")); + + // Test Negative Infinity in Y coordinate + let mut builder = + BkdTreeBuilder::try_new(params.clone(), DataType::Struct(Vec::::new().into())) + .unwrap(); + builder.points.push((10.0, 10.0, 1)); + builder.points.push((20.0, f64::NEG_INFINITY, 2)); + let result = builder.write_index(test_store.as_ref()).await; + assert!(result.is_err()); + assert!(result.unwrap_err().to_string().contains("Infinite")); + } + + #[tokio::test] + async fn test_geo_intersects_out_of_bounds_queries() { + use crate::metrics::NoOpMetricsCollector; + + // Test queries completely outside the data bounds + let test_store = create_test_store(); + + let params = BkdTreeBuilderParams { + max_points_per_leaf: 10, + batches_per_file: 5, + }; + + let mut builder = + BkdTreeBuilder::try_new(params, 
DataType::Struct(Vec::::new().into())).unwrap(); + + // Add points in region 0-100, 0-100 + for i in 0..50 { + builder.points.push((i as f64, i as f64, i as u64)); + } + + builder.write_index(test_store.as_ref()).await.unwrap(); + + let index = BkdTree::load(test_store.clone(), None, &LanceCache::no_cache()) + .await + .unwrap(); + + let metrics = NoOpMetricsCollector {}; + + // Query completely out of bounds (far away) + let query = GeoQuery::Intersects(1000.0, 1000.0, 2000.0, 2000.0); + let result = index.search(&query, &metrics).await.unwrap(); + match result { + crate::scalar::SearchResult::Exact(row_ids) => { + let count = row_ids.len().unwrap_or(0) as usize; + assert_eq!(count, 0); + } + _ => panic!("Expected Exact search result"), + } + + // Query in negative space (if data is all positive) + let query = GeoQuery::Intersects(-500.0, -500.0, -100.0, -100.0); + let result = index.search(&query, &metrics).await.unwrap(); + match result { + crate::scalar::SearchResult::Exact(row_ids) => { + let count = row_ids.len().unwrap_or(0) as usize; + assert_eq!(count, 0); + } + _ => panic!("Expected Exact search result"), + } + } + + #[tokio::test] + async fn test_geo_intersects_extreme_coordinates() { + use crate::metrics::NoOpMetricsCollector; + + // Test with extremely large and small (but valid) coordinates + let test_store = create_test_store(); + + let params = BkdTreeBuilderParams { + max_points_per_leaf: 10, + batches_per_file: 5, + }; + + let mut builder = + BkdTreeBuilder::try_new(params, DataType::Struct(Vec::::new().into())).unwrap(); + + // Add points with extreme coordinates + builder.points.push((1e10, 1e10, 1)); + builder.points.push((-1e10, -1e10, 2)); + builder.points.push((1e-10, 1e-10, 3)); + builder.points.push((-1e-10, -1e-10, 4)); + builder.points.push((f64::MAX / 2.0, f64::MAX / 2.0, 5)); + builder.points.push((f64::MIN / 2.0, f64::MIN / 2.0, 6)); + + builder.write_index(test_store.as_ref()).await.unwrap(); + + let index = 
BkdTree::load(test_store.clone(), None, &LanceCache::no_cache()) + .await + .unwrap(); + + let metrics = NoOpMetricsCollector {}; + + // Query for large positive values + let query = GeoQuery::Intersects(1e10 - 1.0, 1e10 - 1.0, 1e10 + 1.0, 1e10 + 1.0); + let result = index.search(&query, &metrics).await.unwrap(); + match result { + crate::scalar::SearchResult::Exact(row_ids) => { + let actual_ids: Vec = row_ids + .row_ids() + .map(|iter| iter.map(|addr| u64::from(addr)).collect()) + .unwrap_or_default(); + assert!(actual_ids.contains(&1)); + } + _ => panic!("Expected Exact search result"), + } + + // Query for large negative values + let query = GeoQuery::Intersects(-1e10 - 1.0, -1e10 - 1.0, -1e10 + 1.0, -1e10 + 1.0); + let result = index.search(&query, &metrics).await.unwrap(); + match result { + crate::scalar::SearchResult::Exact(row_ids) => { + let actual_ids: Vec = row_ids + .row_ids() + .map(|iter| iter.map(|addr| u64::from(addr)).collect()) + .unwrap_or_default(); + assert!(actual_ids.contains(&2)); + } + _ => panic!("Expected Exact search result"), + } + } + + #[tokio::test] + async fn test_geo_intersects_huge_query_bbox() { + use crate::metrics::NoOpMetricsCollector; + + // Test query larger than all data bounds + let test_store = create_test_store(); + + let params = BkdTreeBuilderParams { + max_points_per_leaf: 10, + batches_per_file: 5, + }; + + let mut builder = + BkdTreeBuilder::try_new(params, DataType::Struct(Vec::::new().into())).unwrap(); + + // Add points in small region + for i in 0..20 { + builder.points.push((i as f64, i as f64, i as u64)); + } + + builder.write_index(test_store.as_ref()).await.unwrap(); + + let index = BkdTree::load(test_store.clone(), None, &LanceCache::no_cache()) + .await + .unwrap(); + + let metrics = NoOpMetricsCollector {}; + + // Query that encompasses ALL data + let query = GeoQuery::Intersects(-1000.0, -1000.0, 1000.0, 1000.0); + let result = index.search(&query, &metrics).await.unwrap(); + match result { + 
crate::scalar::SearchResult::Exact(row_ids) => { + let count = row_ids.len().unwrap_or(0) as usize; + // Should find all 20 points + assert_eq!(count, 20); + } + _ => panic!("Expected Exact search result"), + } + } + + #[tokio::test] + async fn test_geo_intersects_zero_size_query() { + use crate::metrics::NoOpMetricsCollector; + + // Test query box with zero width/height (point query) + let test_store = create_test_store(); + + let params = BkdTreeBuilderParams { + max_points_per_leaf: 10, + batches_per_file: 5, + }; + + let mut builder = + BkdTreeBuilder::try_new(params, DataType::Struct(Vec::::new().into())).unwrap(); + + // Add some points + builder.points.push((10.0, 10.0, 1)); + builder.points.push((10.0, 10.0, 2)); // Duplicate at same location + builder.points.push((20.0, 20.0, 3)); + builder.points.push((30.0, 30.0, 4)); + + builder.write_index(test_store.as_ref()).await.unwrap(); + + let index = BkdTree::load(test_store.clone(), None, &LanceCache::no_cache()) + .await + .unwrap(); + + let metrics = NoOpMetricsCollector {}; + + // Zero-width query (line) + let query = GeoQuery::Intersects(10.0, 10.0, 10.0, 20.0); + let result = index.search(&query, &metrics).await.unwrap(); + match result { + crate::scalar::SearchResult::Exact(row_ids) => { + let actual_ids: Vec = row_ids + .row_ids() + .map(|iter| iter.map(|addr| u64::from(addr)).collect()) + .unwrap_or_default(); + // Should find points at (10, 10) + assert!(actual_ids.contains(&1)); + assert!(actual_ids.contains(&2)); + } + _ => panic!("Expected Exact search result"), + } + + // Zero-height query (line) + let query = GeoQuery::Intersects(10.0, 10.0, 20.0, 10.0); + let result = index.search(&query, &metrics).await.unwrap(); + match result { + crate::scalar::SearchResult::Exact(row_ids) => { + let actual_ids: Vec = row_ids + .row_ids() + .map(|iter| iter.map(|addr| u64::from(addr)).collect()) + .unwrap_or_default(); + // Should find points at (10, 10) + assert!(actual_ids.contains(&1)); + 
assert!(actual_ids.contains(&2)); + } + _ => panic!("Expected Exact search result"), + } + + // Point query (zero width and height) + let query = GeoQuery::Intersects(20.0, 20.0, 20.0, 20.0); + let result = index.search(&query, &metrics).await.unwrap(); + match result { + crate::scalar::SearchResult::Exact(row_ids) => { + let actual_ids: Vec = row_ids + .row_ids() + .map(|iter| iter.map(|addr| u64::from(addr)).collect()) + .unwrap_or_default(); + // Should find point at (20, 20) + assert_eq!(actual_ids.len(), 1); + assert!(actual_ids.contains(&3)); + } + _ => panic!("Expected Exact search result"), + } + } + + #[tokio::test] + async fn test_geo_intersects_large_scale_lazy_loading() { + use rand::rngs::StdRng; + use rand::{Rng, SeedableRng}; + use std::sync::atomic::{AtomicUsize, Ordering}; + + // Custom metrics collector to track I/O operations + struct LoadTracker { + part_loads: AtomicUsize, + } + impl crate::metrics::MetricsCollector for LoadTracker { + fn record_part_load(&self) { + self.part_loads.fetch_add(1, Ordering::Relaxed); + } + fn record_parts_loaded(&self, _num_parts: usize) {} + fn record_index_loads(&self, _num_indices: usize) {} + fn record_comparisons(&self, _num_comparisons: usize) {} + } + + let test_store = create_test_store(); + + let params = BkdTreeBuilderParams { + max_points_per_leaf: 1000, // Realistic leaf size + batches_per_file: 10, + }; + + let mut builder = + BkdTreeBuilder::try_new(params, DataType::Struct(Vec::::new().into())).unwrap(); + + // Create 1 million random points across a 1000x1000 grid + let mut rng = StdRng::seed_from_u64(42); + for i in 0..1_000_000 { + let x = rng.random_range(0.0..1000.0); + let y = rng.random_range(0.0..1000.0); + builder.points.push((x, y, i as u64)); + } + + builder.write_index(test_store.as_ref()).await.unwrap(); + + let index = BkdTree::load(test_store.clone(), None, &LanceCache::no_cache()) + .await + .unwrap(); + + // Run 100 random queries with load tracking + let metrics = LoadTracker { + 
part_loads: AtomicUsize::new(0), + }; + let mut total_leaves_touched = 0; + + for _i in 0..100 { + // Random query bbox with varying sizes (1x1 to 50x50 regions in 1000x1000 space) + let width = rng.random_range(1.0..50.0); + let height = rng.random_range(1.0..50.0); + let min_x = rng.random_range(0.0..(1000.0 - width)); + let min_y = rng.random_range(0.0..(1000.0 - height)); + let max_x = min_x + width; + let max_y = min_y + height; + + let query = GeoQuery::Intersects(min_x, min_y, max_x, max_y); + + // Count leaves touched + let query_bbox = [min_x, min_y, max_x, max_y]; + let intersecting_leaves = index.bkd_tree.find_intersecting_leaves(query_bbox).unwrap(); + total_leaves_touched += intersecting_leaves.len(); + + // Execute query + let _result = index.search(&query, &metrics).await.unwrap(); + } + + let total_io_ops = metrics.part_loads.load(Ordering::Relaxed); + + // CRITICAL: Verify lazy loading is working! + // We should NOT load all leaves - only the ones intersecting query bboxes + assert!( + total_io_ops < index.bkd_tree.num_leaves as usize, + "❌ LAZY LOADING FAILED: Loaded {} leaves but index only has {} leaves! 
\ + Should load much fewer than total.", + total_io_ops, + index.bkd_tree.num_leaves + ); + + // Verify lazy loading is effective (< 10% of leaves touched on average per query) + let avg_leaves_touched = total_leaves_touched as f64 / 100.0; + let total_leaves = index.bkd_tree.num_leaves as f64; + assert!( + avg_leaves_touched < total_leaves * 0.1, + "❌ Lazy loading ineffective: touching {:.1}% of leaves on average", + (avg_leaves_touched / total_leaves) * 100.0 + ); + } + + #[tokio::test] + async fn test_geo_index_points_in_correct_leaves() { + // Verify points are in leaves with correct bounding boxes + let test_store = create_test_store(); + + let params = BkdTreeBuilderParams { + max_points_per_leaf: 10, + batches_per_file: 5, + }; + + let mut builder = + BkdTreeBuilder::try_new(params, DataType::Struct(Vec::::new().into())).unwrap(); + + // Create a grid of points + for x in 0..20 { + for y in 0..20 { + builder + .points + .push((x as f64, y as f64, (x * 20 + y) as u64)); + } + } + + builder.write_index(test_store.as_ref()).await.unwrap(); + let index = BkdTree::load(test_store.clone(), None, &LanceCache::no_cache()) + .await + .unwrap(); + + // Verify each point is within its leaf's bounding box + use crate::metrics::NoOpMetricsCollector; + let metrics = NoOpMetricsCollector {}; + + for leaf in index.bkd_tree.nodes.iter().filter_map(|n| n.as_leaf()) { + let leaf_data = index.load_leaf(leaf, &metrics).await.unwrap(); + + let x_array = leaf_data + .column(0) + .as_primitive::(); + let y_array = leaf_data + .column(1) + .as_primitive::(); + + for i in 0..leaf_data.num_rows() { + let x = x_array.value(i); + let y = y_array.value(i); + + // Point must be within leaf's bounding box + assert!( + x >= leaf.bounds[0] && x <= leaf.bounds[2], + "Point x={} outside leaf bounds [{}, {}]", + x, + leaf.bounds[0], + leaf.bounds[2] + ); + assert!( + y >= leaf.bounds[1] && y <= leaf.bounds[3], + "Point y={} outside leaf bounds [{}, {}]", + y, + leaf.bounds[1], + leaf.bounds[3] 
+ ); + } + } + } + + #[tokio::test] + async fn test_geo_index_duplicate_coordinates_different_row_ids() { + // Test that duplicate coordinates with different row_ids are all stored correctly + use crate::metrics::NoOpMetricsCollector; + + let test_store = create_test_store(); + + let params = BkdTreeBuilderParams { + max_points_per_leaf: 10, + batches_per_file: 3, + }; + + let mut builder = + BkdTreeBuilder::try_new(params, DataType::Struct(Vec::::new().into())).unwrap(); + + // Create multiple points at the same coordinates with different row_ids + // This simulates real-world scenarios where multiple records have the same location + for i in 0..20 { + // 5 points at location (10.0, 10.0) with different row_ids + builder.points.push((10.0, 10.0, i as u64)); + } + + // Add some other points at different locations + for i in 20..50 { + builder.points.push((i as f64, i as f64, i as u64)); + } + + builder.write_index(test_store.as_ref()).await.unwrap(); + let index = BkdTree::load(test_store.clone(), None, &LanceCache::no_cache()) + .await + .unwrap(); + + // Query the region containing the duplicates + let query = GeoQuery::Intersects(9.0, 9.0, 11.0, 11.0); + let metrics = NoOpMetricsCollector {}; + let result = index.search(&query, &metrics).await.unwrap(); + + match result { + crate::scalar::SearchResult::Exact(row_ids) => { + // Should find all 20 points at (10.0, 10.0) + assert_eq!( + row_ids.len().unwrap_or(0), + 20, + "Expected 20 duplicate coordinate entries" + ); + + // Verify all row_ids 0..19 are present + for i in 0..20 { + assert!( + row_ids.contains(i as u64), + "Missing row_id {} for duplicate coordinate", + i + ); + } + } + _ => panic!("Expected Exact search result"), + } + + // Verify data integrity: all 50 points should be stored + let mut all_row_ids = Vec::new(); + for leaf in index.bkd_tree.nodes.iter().filter_map(|n| n.as_leaf()) { + let leaf_data = index.load_leaf(leaf, &metrics).await.unwrap(); + let row_id_array = leaf_data + .column(2) + 
.as_primitive::(); + + for i in 0..leaf_data.num_rows() { + all_row_ids.push(row_id_array.value(i)); + } + } + + assert_eq!( + all_row_ids.len(), + 50, + "Should store all 50 points including duplicates" + ); + + let unique_row_ids: std::collections::HashSet = all_row_ids.iter().copied().collect(); + assert_eq!(unique_row_ids.len(), 50, "All 50 row_ids should be unique"); + } + + #[tokio::test] + async fn test_geo_index_many_duplicates_at_same_location() { + // Test with many duplicate coordinates to ensure they don't cause issues + use crate::metrics::NoOpMetricsCollector; + + let test_store = create_test_store(); + + let params = BkdTreeBuilderParams { + max_points_per_leaf: 20, + batches_per_file: 5, + }; + + let mut builder = + BkdTreeBuilder::try_new(params, DataType::Struct(Vec::::new().into())).unwrap(); + + // Create 100 points all at the exact same location + for i in 0..100 { + builder.points.push((50.0, 50.0, i as u64)); + } + + builder.write_index(test_store.as_ref()).await.unwrap(); + let index = BkdTree::load(test_store.clone(), None, &LanceCache::no_cache()) + .await + .unwrap(); + + // Validate tree structure + validate_bkd_tree(&index) + .await + .expect("BKD tree validation failed with many duplicates"); + + // Query should return all 100 points + let query = GeoQuery::Intersects(49.0, 49.0, 51.0, 51.0); + let metrics = NoOpMetricsCollector {}; + let result = index.search(&query, &metrics).await.unwrap(); + + match result { + crate::scalar::SearchResult::Exact(row_ids) => { + assert_eq!( + row_ids.len().unwrap_or(0), + 100, + "Expected all 100 duplicate coordinate entries" + ); + + // Verify all row_ids are present + for i in 0..100 { + assert!( + row_ids.contains(i as u64), + "Missing row_id {} in duplicate set", + i + ); + } + } + _ => panic!("Expected Exact search result"), + } + } + + #[tokio::test] + async fn test_geo_index_duplicates_across_multiple_locations() { + // Test duplicates at multiple different locations + use 
crate::metrics::NoOpMetricsCollector; + + let test_store = create_test_store(); + + let params = BkdTreeBuilderParams { + max_points_per_leaf: 15, + batches_per_file: 3, + }; + + let mut builder = + BkdTreeBuilder::try_new(params, DataType::Struct(Vec::::new().into())).unwrap(); + + // Create clusters of duplicates at different locations + let locations = vec![(10.0, 10.0), (20.0, 20.0), (30.0, 30.0), (40.0, 40.0)]; + + let mut row_id = 0u64; + for (x, y) in &locations { + // 10 duplicates at each location + for _ in 0..10 { + builder.points.push((*x, *y, row_id)); + row_id += 1; + } + } + + // Add some unique points + for i in 0..20 { + builder.points.push((i as f64, i as f64 + 50.0, row_id)); + row_id += 1; + } + + let total_points = row_id; + + builder.write_index(test_store.as_ref()).await.unwrap(); + let index = BkdTree::load(test_store.clone(), None, &LanceCache::no_cache()) + .await + .unwrap(); + + // Validate tree structure + validate_bkd_tree(&index) + .await + .expect("BKD tree validation failed with duplicates at multiple locations"); + + let metrics = NoOpMetricsCollector {}; + + // Query each location cluster + for (i, (x, y)) in locations.iter().enumerate() { + let query = GeoQuery::Intersects(x - 1.0, y - 1.0, x + 1.0, y + 1.0); + let result = index.search(&query, &metrics).await.unwrap(); + + match result { + crate::scalar::SearchResult::Exact(row_ids) => { + assert_eq!( + row_ids.len().unwrap_or(0), + 10, + "Expected 10 duplicates at location {}", + i + ); + + // Verify the correct row_ids for this cluster + let expected_start = (i * 10) as u64; + let expected_end = expected_start + 10; + for expected_id in expected_start..expected_end { + assert!( + row_ids.contains(expected_id), + "Missing row_id {} in cluster at ({}, {})", + expected_id, + x, + y + ); + } + } + _ => panic!("Expected Exact search result"), + } + } + + // Verify total count + let mut all_row_ids = Vec::new(); + for leaf in index.bkd_tree.nodes.iter().filter_map(|n| n.as_leaf()) { 
+ let leaf_data = index.load_leaf(leaf, &metrics).await.unwrap(); + let row_id_array = leaf_data + .column(2) + .as_primitive::(); + + for i in 0..leaf_data.num_rows() { + all_row_ids.push(row_id_array.value(i)); + } + } + + assert_eq!( + all_row_ids.len(), + total_points as usize, + "Should store all {} points including duplicates", + total_points + ); + } + + #[tokio::test] + async fn test_geo_index_duplicate_handling_with_leaf_splits() { + // Test that duplicates are handled correctly when they cause leaf splits + use crate::metrics::NoOpMetricsCollector; + + let test_store = create_test_store(); + + let params = BkdTreeBuilderParams { + max_points_per_leaf: 10, // Small leaf size to force splits + batches_per_file: 3, + }; + + let mut builder = + BkdTreeBuilder::try_new(params, DataType::Struct(Vec::::new().into())).unwrap(); + + // Create 50 points at the same location - should span multiple leaves + for i in 0..50 { + builder.points.push((25.0, 25.0, i as u64)); + } + + builder.write_index(test_store.as_ref()).await.unwrap(); + let index = BkdTree::load(test_store.clone(), None, &LanceCache::no_cache()) + .await + .unwrap(); + + // Validate tree structure + validate_bkd_tree(&index) + .await + .expect("BKD tree validation failed with duplicate-induced splits"); + + // Query should return all 50 points + let query = GeoQuery::Intersects(24.0, 24.0, 26.0, 26.0); + let metrics = NoOpMetricsCollector {}; + let result = index.search(&query, &metrics).await.unwrap(); + + match result { + crate::scalar::SearchResult::Exact(row_ids) => { + assert_eq!( + row_ids.len().unwrap_or(0), + 50, + "Expected all 50 duplicate entries after leaf splits" + ); + + // Verify all row_ids are present + for i in 0..50 { + assert!( + row_ids.contains(i as u64), + "Missing row_id {} after leaf split", + i + ); + } + } + _ => panic!("Expected Exact search result"), + } + } + + #[tokio::test] + async fn test_geo_index_mixed_duplicates_and_unique_points() { + // Realistic test: mix of 
unique points and duplicates + use crate::metrics::NoOpMetricsCollector; + use rand::rngs::StdRng; + use rand::{Rng, SeedableRng}; + + let test_store = create_test_store(); + + let params = BkdTreeBuilderParams { + max_points_per_leaf: 20, + batches_per_file: 5, + }; + + let mut builder = + BkdTreeBuilder::try_new(params, DataType::Struct(Vec::::new().into())).unwrap(); + + let mut rng = StdRng::seed_from_u64(123); + let mut row_id = 0u64; + + // Add 100 random unique points + for _ in 0..100 { + let x = rng.random_range(0.0..100.0); + let y = rng.random_range(0.0..100.0); + builder.points.push((x, y, row_id)); + row_id += 1; + } + + // Add 20 duplicates at a specific location + for _ in 0..20 { + builder.points.push((50.0, 50.0, row_id)); + row_id += 1; + } + + // Add 50 more random unique points + for _ in 0..50 { + let x = rng.random_range(0.0..100.0); + let y = rng.random_range(0.0..100.0); + builder.points.push((x, y, row_id)); + row_id += 1; + } + + let total_points = 170; + + builder.write_index(test_store.as_ref()).await.unwrap(); + let index = BkdTree::load(test_store.clone(), None, &LanceCache::no_cache()) + .await + .unwrap(); + + // Validate tree structure + validate_bkd_tree(&index) + .await + .expect("BKD tree validation failed with mixed duplicates and unique points"); + + let metrics = NoOpMetricsCollector {}; + + // Query the duplicate cluster + let query = GeoQuery::Intersects(49.0, 49.0, 51.0, 51.0); + let result = index.search(&query, &metrics).await.unwrap(); + + match result { + crate::scalar::SearchResult::Exact(row_ids) => { + assert_eq!( + row_ids.len().unwrap_or(0), + 20, + "Expected 20 duplicate entries at (50, 50)" + ); + + // Verify the duplicate row_ids (100..119) + for expected_id in 100..120 { + assert!( + row_ids.contains(expected_id as u64), + "Missing row_id {} in duplicate cluster", + expected_id + ); + } + } + _ => panic!("Expected Exact search result"), + } + + // Verify total count + let mut all_row_ids = Vec::new(); + for 
leaf in index.bkd_tree.nodes.iter().filter_map(|n| n.as_leaf()) { + let leaf_data = index.load_leaf(leaf, &metrics).await.unwrap(); + let row_id_array = leaf_data + .column(2) + .as_primitive::(); + + for i in 0..leaf_data.num_rows() { + all_row_ids.push(row_id_array.value(i)); + } + } + + assert_eq!( + all_row_ids.len(), + total_points, + "Should store all {} points", + total_points + ); + + // Verify all row_ids are unique (no double-counting) + let unique_row_ids: std::collections::HashSet = all_row_ids.iter().copied().collect(); + assert_eq!( + unique_row_ids.len(), + total_points, + "All row_ids should be unique" + ); + } +} diff --git a/rust/lance-index/src/scalar/geo/mod.rs b/rust/lance-index/src/scalar/geo/mod.rs new file mode 100644 index 00000000000..a998c156784 --- /dev/null +++ b/rust/lance-index/src/scalar/geo/mod.rs @@ -0,0 +1,14 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright The Lance Authors + +//! Geographic indexing module +//! +//! This module contains implementations for spatial/geographic indexing: +//! - BKD: Block K-Dimensional tree for efficient spatial partitioning (core data structure) +//! 
- BkdTree: Geographic index built on top of BKD trees for GeoArrow data + +pub mod bkd; +pub mod bkdtree; + +pub use bkd::*; +pub use bkdtree::*; diff --git a/rust/lance-index/src/scalar/registry.rs b/rust/lance-index/src/scalar/registry.rs index a36e221f6a0..b083f7fb6f9 100644 --- a/rust/lance-index/src/scalar/registry.rs +++ b/rust/lance-index/src/scalar/registry.rs @@ -15,9 +15,9 @@ use crate::{ frag_reuse::FragReuseIndex, scalar::{ bitmap::BitmapIndexPlugin, bloomfilter::BloomFilterIndexPlugin, btree::BTreeIndexPlugin, - expression::ScalarQueryParser, inverted::InvertedIndexPlugin, json::JsonIndexPlugin, - label_list::LabelListIndexPlugin, ngram::NGramIndexPlugin, zonemap::ZoneMapIndexPlugin, - CreatedIndex, IndexStore, ScalarIndex, + expression::ScalarQueryParser, geo::BkdTreePlugin, inverted::InvertedIndexPlugin, + json::JsonIndexPlugin, label_list::LabelListIndexPlugin, ngram::NGramIndexPlugin, + zonemap::ZoneMapIndexPlugin, CreatedIndex, IndexStore, ScalarIndex, }, }; @@ -201,6 +201,7 @@ impl ScalarIndexPluginRegistry { registry.add_plugin::(); registry.add_plugin::(); registry.add_plugin::(); + registry.add_plugin::(); registry.add_plugin::(); registry.add_plugin::(); registry.add_plugin::(); diff --git a/rust/lance/src/dataset/sql.rs b/rust/lance/src/dataset/sql.rs index 4c58375619e..898042f514c 100644 --- a/rust/lance/src/dataset/sql.rs +++ b/rust/lance/src/dataset/sql.rs @@ -65,6 +65,8 @@ impl SqlQueryBuilder { pub async fn build(self) -> lance_core::Result { let ctx = SessionContext::new(); + // Register Lance UDFs + lance_datafusion::udf::register_functions(&ctx); let row_id = self.with_row_id; let row_addr = self.with_row_addr; ctx.register_table( diff --git a/rust/lance/src/index/create.rs b/rust/lance/src/index/create.rs index 76f1fba3d34..06d94e201d8 100644 --- a/rust/lance/src/index/create.rs +++ b/rust/lance/src/index/create.rs @@ -151,7 +151,8 @@ impl<'a> CreateIndexBuilder<'a> { | IndexType::NGram | IndexType::ZoneMap | IndexType::BloomFilter 
- | IndexType::LabelList, + | IndexType::LabelList + | IndexType::BkdTree, LANCE_SCALAR_INDEX, ) => { let base_params = ScalarIndexParams::for_builtin(self.index_type.try_into()?);