From 8d541eefbcfe80c3e5b7d468e63d9beb14dd973a Mon Sep 17 00:00:00 2001 From: xloya Date: Tue, 16 Dec 2025 13:33:58 +0800 Subject: [PATCH] add code --- python/python/lance/dataset.py | 48 ++++++++++++- python/python/tests/test_vector_index.py | 85 ++++++++++++++++++++++++ python/src/dataset.rs | 34 ++++++++++ 3 files changed, 166 insertions(+), 1 deletion(-) diff --git a/python/python/lance/dataset.py b/python/python/lance/dataset.py index 87bbfefe422..a72954aa80a 100644 --- a/python/python/lance/dataset.py +++ b/python/python/lance/dataset.py @@ -741,7 +741,8 @@ def scanner( "k": 10, "minimum_nprobes": 1, "maximum_nprobes": 50, - "refine_factor": 1 + "refine_factor": 1, + "distance_range": (0.0, 1.0), } batch_size: int, default None @@ -4694,7 +4695,44 @@ def nearest( refine_factor: Optional[int] = None, use_index: bool = True, ef: Optional[int] = None, + distance_range: Optional[tuple[Optional[float], Optional[float]]] = None, ) -> ScannerBuilder: + """Configure nearest neighbor search. + + Parameters + ---------- + column: str + The name of the vector column to search. + q: QueryVectorLike + The query vector. + k: int, optional + The number of nearest neighbors to return. + metric: str, optional + The distance metric to use (e.g., "L2", "cosine", "dot", "hamming"). + nprobes: int, optional + The number of partitions to search. Sets both minimum_nprobes and + maximum_nprobes to the same value. + minimum_nprobes: int, optional + The minimum number of partitions to search. + maximum_nprobes: int, optional + The maximum number of partitions to search. + refine_factor: int, optional + The refine factor for the search. + use_index: bool, default True + Whether to use the index for the search. + ef: int, optional + The ef parameter for HNSW search. + distance_range: tuple[Optional[float], Optional[float]], optional + A tuple of (lower_bound, upper_bound) to filter results by distance. + Both bounds are optional. The lower bound is inclusive and the upper + bound is exclusive, so (0.0, 1.0) keeps distances d where + 0.0 <= d < 1.0, (None, 0.5) keeps d < 0.5, and (0.5, None) keeps d >= 0.5. + + Returns + ------- + ScannerBuilder + The scanner builder for method chaining. + """ q, q_dim = _coerce_query_vector(q) lance_field = self.ds._ds.lance_schema.field(column) @@ -4751,6 +4789,13 @@ def nearest( # `ef` should be >= `k`, but `k` could be None so we can't check it here # the rust code will check it raise ValueError(f"ef must be > 0 but got {ef}") + + if distance_range is not None: + if len(distance_range) != 2: + raise ValueError( + "distance_range must be a tuple of (lower_bound, upper_bound)" + ) + self._nearest = { "column": column, "q": q, @@ -4761,6 +4806,7 @@ def nearest( "refine_factor": refine_factor, "use_index": use_index, "ef": ef, + "distance_range": distance_range, } return self diff --git a/python/python/tests/test_vector_index.py b/python/python/tests/test_vector_index.py index d66f91f0831..9616cc8446d 100644 --- a/python/python/tests/test_vector_index.py +++ b/python/python/tests/test_vector_index.py @@ -1912,3 +1912,88 @@ def func(rs: pa.Table): assert rs["vector"][0].as_py() == q run(dataset, q=np.array(q), assert_func=func) + + +def test_vector_index_distance_range(tmp_path): + """Ensure vector index honors distance_range.""" + ndim = 128 + rng = np.random.default_rng(seed=42) + base = rng.standard_normal((509, ndim)).astype(np.float32) + zero_vec = np.zeros((1, ndim), dtype=np.float32) + near_vec = np.full((1, ndim), 0.01, dtype=np.float32) + far_vec = np.full((1, ndim), 500.0, dtype=np.float32) + matrix = np.concatenate([zero_vec, near_vec, far_vec, base], axis=0) + tbl = vec_to_table(data=matrix).append_column( + "id", pa.array(np.arange(matrix.shape[0], dtype=np.int64)) + ) + dataset = lance.write_dataset(tbl, tmp_path / "vrange") + indexed = dataset.create_index("vector", index_type="IVF_FLAT", num_partitions=4) + + q = zero_vec[0] + distance_range = (0.0, 0.5) + nprobes_all = 4 + + # Brute force baseline (exact): + # get full distance distribution and build expected in-range ids. + all_results = indexed.to_table( + columns=["id"], + nearest={ + "column": "vector", + "q": q, + "k": matrix.shape[0], + "use_index": False, + }, + ) + all_distances = all_results["_distance"].to_numpy() + assert len(all_distances) == matrix.shape[0] + assert all_distances.min() == 0.0 + assert ( + all_distances.max() > distance_range[1] + ) # ensure some values are out of range + + in_range_mask = (all_distances >= distance_range[0]) & ( + all_distances < distance_range[1] + ) + expected_ids = set(all_results["id"].to_numpy()[in_range_mask].tolist()) + assert len(expected_ids) > 0 + + # Compare distance_range results: + # brute-force vs index path should match exactly for IVF_FLAT + brute_results = indexed.to_table( + columns=["id"], + nearest={ + "column": "vector", + "q": q, + "k": matrix.shape[0], + "distance_range": distance_range, + "use_index": False, + }, + ) + + index_results = indexed.to_table( + columns=["id"], + nearest={ + "column": "vector", + "q": q, + "k": matrix.shape[0], + "distance_range": distance_range, + "nprobes": nprobes_all, + }, + ) + + brute_ids = brute_results["id"].to_numpy() + index_ids = index_results["id"].to_numpy() + brute_distances = brute_results["_distance"].to_numpy() + index_distances = index_results["_distance"].to_numpy() + + assert set(brute_ids.tolist()).issubset(expected_ids) + assert set(index_ids.tolist()).issubset(expected_ids) + assert len(brute_ids) == len(index_ids) + assert np.array_equal(brute_ids, index_ids) + assert np.all(brute_distances >= distance_range[0]) and np.all( + brute_distances < distance_range[1] + ) + assert np.all(index_distances >= distance_range[0]) and np.all( + index_distances < distance_range[1] + ) + assert np.allclose(brute_distances, index_distances, rtol=0.0, atol=0.0) diff --git a/python/src/dataset.rs b/python/src/dataset.rs index ade2b4516ca..e92c5d34c9f 100644 --- a/python/src/dataset.rs +++ b/python/src/dataset.rs @@ -1079,6 +1079,37 @@ impl Dataset { } _ => scanner.nearest(&column, &q, k), }; + let distance_range: Option<(Option, Option)> = + if let Some(dr) = nearest.get_item("distance_range")? { + if dr.is_none() { + None + } else { + let tuple = dr + .downcast::() + .map_err(|err| PyValueError::new_err(err.to_string()))?; + if tuple.len() != 2 { + return Err(PyValueError::new_err( + "distance_range must be a tuple of (lower_bound, upper_bound)", + )); + } + let lower_any = tuple.get_item(0)?; + let lower = if lower_any.is_none() { + None + } else { + Some(lower_any.extract()?) + }; + let upper_any = tuple.get_item(1)?; + let upper = if upper_any.is_none() { + None + } else { + Some(upper_any.extract()?) + }; + Some((lower, upper)) + } + } else { + None + }; + scanner .map(|s| { let mut s = s.minimum_nprobes(minimum_nprobes); @@ -1095,6 +1126,9 @@ impl Dataset { s = s.ef(ef); } s.use_index(use_index); + if let Some((lower, upper)) = distance_range { + s.distance_range(lower, upper); + } s }) .map_err(|err| PyValueError::new_err(err.to_string()))?;