Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 47 additions & 1 deletion python/python/lance/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -720,7 +720,8 @@ def scanner(
"k": 10,
"minimum_nprobes": 1,
"maximum_nprobes": 50,
"refine_factor": 1
"refine_factor": 1,
"distance_range": (0.0, 1.0),
}

batch_size: int, default None
Expand Down Expand Up @@ -4690,7 +4691,44 @@ def nearest(
refine_factor: Optional[int] = None,
use_index: bool = True,
ef: Optional[int] = None,
distance_range: Optional[tuple[Optional[float], Optional[float]]] = None,
) -> ScannerBuilder:
"""Configure nearest neighbor search.

Parameters
----------
column: str
The name of the vector column to search.
q: QueryVectorLike
The query vector.
k: int, optional
The number of nearest neighbors to return.
metric: str, optional
The distance metric to use (e.g., "L2", "cosine", "dot", "hamming").
nprobes: int, optional
The number of partitions to search. Sets both minimum_nprobes and
maximum_nprobes to the same value.
minimum_nprobes: int, optional
The minimum number of partitions to search.
maximum_nprobes: int, optional
The maximum number of partitions to search.
refine_factor: int, optional
The refine factor for the search.
use_index: bool, default True
Whether to use the index for the search.
ef: int, optional
The ef parameter for HNSW search.
distance_range: tuple[Optional[float], Optional[float]], optional
A tuple of (lower_bound, upper_bound) to filter results by distance.
Both bounds are optional. The lower bound is inclusive and the upper
bound is exclusive, so (0.0, 1.0) keeps distances d where
0.0 <= d < 1.0, (None, 0.5) keeps d < 0.5, and (0.5, None) keeps d >= 0.5.

Returns
-------
ScannerBuilder
The scanner builder for method chaining.
"""
q, q_dim = _coerce_query_vector(q)

lance_field = self.ds._ds.lance_schema.field_case_insensitive(column)
Expand Down Expand Up @@ -4747,6 +4785,13 @@ def nearest(
# `ef` should be >= `k`, but `k` could be None so we can't check it here
# the rust code will check it
raise ValueError(f"ef must be > 0 but got {ef}")

if distance_range is not None:
if len(distance_range) != 2:
raise ValueError(
"distance_range must be a tuple of (lower_bound, upper_bound)"
)

self._nearest = {
"column": column,
"q": q,
Expand All @@ -4757,6 +4802,7 @@ def nearest(
"refine_factor": refine_factor,
"use_index": use_index,
"ef": ef,
"distance_range": distance_range,
}
return self

Expand Down
85 changes: 85 additions & 0 deletions python/python/tests/test_vector_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -1912,3 +1912,88 @@ def func(rs: pa.Table):
assert rs["vector"][0].as_py() == q

run(dataset, q=np.array(q), assert_func=func)


def test_vector_index_distance_range(tmp_path):
"""Ensure vector index honors distance_range."""
ndim = 128
rng = np.random.default_rng(seed=42)
base = rng.standard_normal((509, ndim)).astype(np.float32)
zero_vec = np.zeros((1, ndim), dtype=np.float32)
near_vec = np.full((1, ndim), 0.01, dtype=np.float32)
far_vec = np.full((1, ndim), 500.0, dtype=np.float32)
matrix = np.concatenate([zero_vec, near_vec, far_vec, base], axis=0)
tbl = vec_to_table(data=matrix).append_column(
"id", pa.array(np.arange(matrix.shape[0], dtype=np.int64))
)
dataset = lance.write_dataset(tbl, tmp_path / "vrange")
indexed = dataset.create_index("vector", index_type="IVF_FLAT", num_partitions=4)

q = zero_vec[0]
distance_range = (0.0, 0.5)
nprobes_all = 4

# Brute force baseline (exact):
# get full distance distribution and build expected in-range ids.
all_results = indexed.to_table(
columns=["id"],
nearest={
"column": "vector",
"q": q,
"k": matrix.shape[0],
"use_index": False,
},
)
all_distances = all_results["_distance"].to_numpy()
assert len(all_distances) == matrix.shape[0]
assert all_distances.min() == 0.0
assert (
all_distances.max() > distance_range[1]
) # ensure some values are out of range

in_range_mask = (all_distances >= distance_range[0]) & (
all_distances < distance_range[1]
)
expected_ids = set(all_results["id"].to_numpy()[in_range_mask].tolist())
assert len(expected_ids) > 0

# Compare distance_range results:
# brute-force vs index path should match exactly for IVF_FLAT
brute_results = indexed.to_table(
columns=["id"],
nearest={
"column": "vector",
"q": q,
"k": matrix.shape[0],
"distance_range": distance_range,
"use_index": False,
},
)

index_results = indexed.to_table(
columns=["id"],
nearest={
"column": "vector",
"q": q,
"k": matrix.shape[0],
"distance_range": distance_range,
"nprobes": nprobes_all,
},
)

brute_ids = brute_results["id"].to_numpy()
index_ids = index_results["id"].to_numpy()
brute_distances = brute_results["_distance"].to_numpy()
index_distances = index_results["_distance"].to_numpy()

assert set(brute_ids.tolist()).issubset(expected_ids)
assert set(index_ids.tolist()).issubset(expected_ids)
assert len(brute_ids) == len(index_ids)
assert np.array_equal(brute_ids, index_ids)
assert np.all(brute_distances >= distance_range[0]) and np.all(
brute_distances < distance_range[1]
)
assert np.all(index_distances >= distance_range[0]) and np.all(
index_distances < distance_range[1]
)
assert np.allclose(brute_distances, index_distances, rtol=0.0, atol=0.0)
34 changes: 34 additions & 0 deletions python/src/dataset.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1080,6 +1080,37 @@ impl Dataset {
}
_ => scanner.nearest(&column, &q, k),
};
let distance_range: Option<(Option<f32>, Option<f32>)> =
if let Some(dr) = nearest.get_item("distance_range")? {
if dr.is_none() {
None
} else {
let tuple = dr
.downcast::<PyTuple>()
.map_err(|err| PyValueError::new_err(err.to_string()))?;
if tuple.len() != 2 {
return Err(PyValueError::new_err(
"distance_range must be a tuple of (lower_bound, upper_bound)",
));
}
let lower_any = tuple.get_item(0)?;
let lower = if lower_any.is_none() {
None
} else {
Some(lower_any.extract()?)
};
let upper_any = tuple.get_item(1)?;
let upper = if upper_any.is_none() {
None
} else {
Some(upper_any.extract()?)
};
Some((lower, upper))
}
} else {
None
};

scanner
.map(|s| {
let mut s = s.minimum_nprobes(minimum_nprobes);
Expand All @@ -1096,6 +1127,9 @@ impl Dataset {
s = s.ef(ef);
}
s.use_index(use_index);
if let Some((lower, upper)) = distance_range {
s.distance_range(lower, upper);
}
s
})
.map_err(|err| PyValueError::new_err(err.to_string()))?;
Expand Down