From 6a7ea2d76b150fbb4136da8aa258f86317cb3b4b Mon Sep 17 00:00:00 2001 From: jinglun Date: Tue, 16 Dec 2025 15:33:39 +0800 Subject: [PATCH] feat(python): expose search_filter in scanner --- python/python/lance/dataset.py | 378 +++++++++++++++-------- python/python/lance/lance/__init__.pyi | 2 + python/python/tests/test_scalar_index.py | 119 +++++++ python/python/tests/test_vector_index.py | 126 ++++++++ python/src/dataset.rs | 319 ++++++++++++------- python/src/lib.rs | 3 +- 6 files changed, 718 insertions(+), 229 deletions(-) diff --git a/python/python/lance/dataset.py b/python/python/lance/dataset.py index 97a9a0c6775..e0344b64ca3 100644 --- a/python/python/lance/dataset.py +++ b/python/python/lance/dataset.py @@ -55,6 +55,7 @@ DatasetBasePath, IOStats, LanceSchema, + PySearchFilter, ScanStatistics, _Dataset, _MergeInsertBuilder, @@ -679,7 +680,9 @@ def _apply_default_scan_options(self, builder: ScannerBuilder): def scanner( self, columns: Optional[Union[List[str], Dict[str, str]]] = None, - filter: Optional[Union[str, pa.compute.Expression]] = None, + filter: Optional[ + Union[str, pa.compute.Expression, FullTextQuery, VectorSearchQuery, dict] + ] = None, limit: Optional[int] = None, offset: Optional[int] = None, nearest: Optional[dict] = None, @@ -715,10 +718,50 @@ def scanner( List of column names to be fetched. Or a dictionary of column names to SQL expressions. All columns are fetched if None or unspecified. - filter: pa.compute.Expression or str - Expression or str that is a valid SQL where clause. See - `Lance filter pushdown `_ - for valid SQL expressions. + filter: pa.compute.Expression, str, VectorSearchQuery, FullTextQuery or dict + Lance supports 2 kinds of filters: expression filter and search filter. + + - Expression filter is pa.compute.Expression or str that is a valid SQL + where clause. See `Lance filter pushdown + `_ + for valid SQL expressions. Expression filter is applied to filtered scan, + full text search and vector search. + + - VectorSearchQuery is a vector search that can only be applied to full + text search. Example: + .. code-block:: python + + filter=VectorSearchQuery( + "vector", + np.array([12, 17, 300, 10], dtype=np.float32), + 5, + 20, + True, + ) + + - FullTextQuery is a full text search that can only be applied to vector + search. Example: + .. code-block:: python + + filter=PhraseQuery("hello world", "col") + + - Dictionary is a combined filter containing both expression filter with + key `expr_filter` and search filter with key `search_filter`. Example: + .. code-block:: python + + scanner = ds.scanner( + nearest={ + "column": "vector", + "q": np.array([12, 17, 300, 10], dtype=np.float32), + "k": 5, + "minimum_nprobes": 20, + "use_index": True, + }, + filter={ + "expr_filter": "category='geography'", + "search_filter": PhraseQuery("hello world", "col"), + }, + ) limit: int, default None Fetch up to this many rows. All rows if None or unspecified. offset: int, default None @@ -4622,6 +4665,7 @@ def __init__(self, ds: LanceDataset): self.ds = ds self._limit = None self._filter = None + self._search_filter = None self._substrait_filter = None self._prefilter = False self._late_materialization = None @@ -4747,8 +4791,27 @@ def columns( ) return self - def filter(self, filter: Union[str, pa.compute.Expression]) -> ScannerBuilder: - if isinstance(filter, pa.compute.Expression): + def filter( + self, + filter: Union[ + str, pa.compute.Expression, FullTextQuery, VectorSearchQuery, dict + ], + ) -> ScannerBuilder: + """ + Add a filter to the scanner. + + :param filter: The filter to apply. This can be a string, a pyarrow compute + expression, a FullTextQuery, a VectorSearchQuery, or a dictionary. + + :return: The scanner builder. + """ + if isinstance(filter, FullTextQuery): + self._search_filter = PySearchFilter.from_full_text_query(filter.inner) + elif isinstance(filter, VectorSearchQuery): + self._search_filter = PySearchFilter.from_vector_search_query(filter.inner) + elif isinstance(filter, str): + self._filter = filter + elif isinstance(filter, pa.compute.Expression): try: from pyarrow.substrait import serialize_expressions @@ -4769,8 +4832,9 @@ def filter(self, filter: Union[str, pa.compute.Expression]) -> ScannerBuilder: ) else: fields_without_lists.append(field) - # Serialize the pyarrow compute expression toSubstrait and use - # that as a filter. + # Serialize the pyarrow compute expression toSubstrait and use + # that as a filter. + counter += 1 scalar_schema = pa.schema(fields_without_lists) substrait_filter = serialize_expressions( [filter], ["my_filter"], scalar_schema @@ -4790,7 +4854,14 @@ def filter(self, filter: Union[str, pa.compute.Expression]) -> ScannerBuilder: # stringifying the expression if pyarrow is too old self._filter = str(filter) else: - self._filter = filter + expr_filter = filter.get("expr_filter") + if expr_filter is not None: + self.filter(expr_filter) + + search_filter = filter.get("search_filter") + if search_filter is not None: + self.filter(search_filter) + return self def prefilter(self, prefilter: bool) -> ScannerBuilder: @@ -4888,117 +4959,20 @@ def nearest( ef: Optional[int] = None, distance_range: Optional[tuple[Optional[float], Optional[float]]] = None, ) -> ScannerBuilder: - """Configure nearest neighbor search. - - Parameters - ---------- - column: str - The name of the vector column to search. - q: QueryVectorLike - The query vector. - k: int, optional - The number of nearest neighbors to return. - metric: str, optional - The distance metric to use (e.g., "L2", "cosine", "dot", "hamming"). - nprobes: int, optional - The number of partitions to search. Sets both minimum_nprobes and - maximum_nprobes to the same value. - minimum_nprobes: int, optional - The minimum number of partitions to search. - maximum_nprobes: int, optional - The maximum number of partitions to search. - refine_factor: int, optional - The refine factor for the search. - use_index: bool, default True - Whether to use the index for the search. - ef: int, optional - The ef parameter for HNSW search. - distance_range: tuple[Optional[float], Optional[float]], optional - A tuple of (lower_bound, upper_bound) to filter results by distance. - Both bounds are optional. The lower bound is inclusive and the upper - bound is exclusive, so (0.0, 1.0) keeps distances d where - 0.0 <= d < 1.0, (None, 0.5) keeps d < 0.5, and (0.5, None) keeps d >= 0.5. - - Returns - ------- - ScannerBuilder - The scanner builder for method chaining. - """ - q, q_dim = _coerce_query_vector(q) - - lance_field = self.ds._ds.lance_schema.field_case_insensitive(column) - if lance_field is None: - raise ValueError(f"Embedding column {column} is not in the dataset") - - column_field = lance_field.to_arrow() - column_type = column_field.type - if hasattr(column_type, "storage_type"): - column_type = column_type.storage_type - if pa.types.is_fixed_size_list(column_type): - dim = column_type.list_size - elif pa.types.is_list(column_type) and pa.types.is_fixed_size_list( - column_type.value_type - ): - dim = column_type.value_type.list_size - else: - raise TypeError( - f"Query column {column} must be a vector. Got {column_field.type}." - ) - - if q_dim != dim: - raise ValueError( - f"Query vector size {len(q)} does not match index column size {dim}" - ) - - if k is not None and int(k) <= 0: - raise ValueError(f"Nearest-K must be > 0 but got {k}") - if nprobes is not None and int(nprobes) <= 0: - raise ValueError(f"Nprobes must be > 0 but got {nprobes}") - if minimum_nprobes is not None and int(minimum_nprobes) < 0: - raise ValueError(f"Minimum nprobes must be >= 0 but got {minimum_nprobes}") - if maximum_nprobes is not None and int(maximum_nprobes) < 0: - raise ValueError(f"Maximum nprobes must be >= 0 but got {maximum_nprobes}") - - if nprobes is not None: - if minimum_nprobes is not None or maximum_nprobes is not None: - raise ValueError( - "nprobes cannot be set in combination with minimum_nprobes or " - "maximum_nprobes" - ) - else: - minimum_nprobes = nprobes - maximum_nprobes = nprobes - if ( - minimum_nprobes is not None - and maximum_nprobes is not None - and minimum_nprobes > maximum_nprobes - ): - raise ValueError("minimum_nprobes must be <= maximum_nprobes") - if refine_factor is not None and int(refine_factor) < 1: - raise ValueError(f"Refine factor must be 1 or more got {refine_factor}") - if ef is not None and int(ef) <= 0: - # `ef` should be >= `k`, but `k` could be None so we can't check it here - # the rust code will check it - raise ValueError(f"ef must be > 0 but got {ef}") - - if distance_range is not None: - if len(distance_range) != 2: - raise ValueError( - "distance_range must be a tuple of (lower_bound, upper_bound)" - ) - - self._nearest = { - "column": column, - "q": q, - "k": k, - "metric": metric, - "minimum_nprobes": minimum_nprobes, - "maximum_nprobes": maximum_nprobes, - "refine_factor": refine_factor, - "use_index": use_index, - "ef": ef, - "distance_range": distance_range, - } + self._nearest = _build_vector_search_query( + column, + q, + dataset=self.ds, + k=k, + metric=metric, + nprobes=nprobes, + minimum_nprobes=minimum_nprobes, + maximum_nprobes=maximum_nprobes, + refine_factor=refine_factor, + use_index=use_index, + ef=ef, + distance_range=distance_range, + ) return self def fast_search(self, flag: bool) -> ScannerBuilder: @@ -5095,6 +5069,7 @@ def to_scanner(self) -> LanceScanner: self._columns, self._columns_with_transform, self._filter, + self._search_filter, self._prefilter, self._limit, self._offset, @@ -5942,6 +5917,134 @@ def _coerce_query_vector(query: QueryVectorLike) -> tuple[pa.Array, int]: return (query, len(query)) +def _build_vector_search_query( + column: str, + q, + *, + dataset: Optional["LanceDataset"] = None, + k: Optional[int] = None, + metric: Optional[str] = None, + nprobes: Optional[int] = None, + minimum_nprobes: Optional[int] = None, + maximum_nprobes: Optional[int] = None, + refine_factor: Optional[int] = None, + use_index: bool = True, + ef: Optional[int] = None, + distance_range: Optional[tuple[Optional[float], Optional[float]]] = None, +) -> dict: + """Configure nearest neighbor search. + + Parameters + ---------- + column: str + The name of the vector column to search. + q: QueryVectorLike + The query vector. + k: int, optional + The number of nearest neighbors to return. + metric: str, optional + The distance metric to use (e.g., "L2", "cosine", "dot", "hamming"). + nprobes: int, optional + The number of partitions to search. Sets both minimum_nprobes and + maximum_nprobes to the same value. + minimum_nprobes: int, optional + The minimum number of partitions to search. + maximum_nprobes: int, optional + The maximum number of partitions to search. + refine_factor: int, optional + The refine factor for the search. + use_index: bool, default True + Whether to use the index for the search. + ef: int, optional + The ef parameter for HNSW search. + distance_range: tuple[Optional[float], Optional[float]], optional + A tuple of (lower_bound, upper_bound) to filter results by distance. + Both bounds are optional. The lower bound is inclusive and the upper + bound is exclusive, so (0.0, 1.0) keeps distances d where + 0.0 <= d < 1.0, (None, 0.5) keeps d < 0.5, and (0.5, None) keeps d >= 0.5. + + Returns + ------- + ScannerBuilder + The scanner builder for method chaining. + """ + q, q_dim = _coerce_query_vector(q) + + lance_field = dataset._ds.lance_schema.field_case_insensitive(column) + if lance_field is None: + raise ValueError(f"Embedding column {column} is not in the dataset") + + column_field = lance_field.to_arrow() + column_type = column_field.type + if hasattr(column_type, "storage_type"): + column_type = column_type.storage_type + if pa.types.is_fixed_size_list(column_type): + dim = column_type.list_size + elif pa.types.is_list(column_type) and pa.types.is_fixed_size_list( + column_type.value_type + ): + dim = column_type.value_type.list_size + else: + raise TypeError( + f"Query column {column} must be a vector. Got {column_field.type}." + ) + + if q_dim != dim: + raise ValueError( + f"Query vector size {len(q)} does not match index column size {dim}" + ) + + if k is not None and int(k) <= 0: + raise ValueError(f"Nearest-K must be > 0 but got {k}") + if nprobes is not None and int(nprobes) <= 0: + raise ValueError(f"Nprobes must be > 0 but got {nprobes}") + if minimum_nprobes is not None and int(minimum_nprobes) < 0: + raise ValueError(f"Minimum nprobes must be >= 0 but got {minimum_nprobes}") + if maximum_nprobes is not None and int(maximum_nprobes) < 0: + raise ValueError(f"Maximum nprobes must be >= 0 but got {maximum_nprobes}") + + if nprobes is not None: + if minimum_nprobes is not None or maximum_nprobes is not None: + raise ValueError( + "nprobes cannot be set in combination with minimum_nprobes or " + "maximum_nprobes" + ) + else: + minimum_nprobes = nprobes + maximum_nprobes = nprobes + if ( + minimum_nprobes is not None + and maximum_nprobes is not None + and minimum_nprobes > maximum_nprobes + ): + raise ValueError("minimum_nprobes must be <= maximum_nprobes") + if refine_factor is not None and int(refine_factor) < 1: + raise ValueError(f"Refine factor must be 1 or more got {refine_factor}") + if ef is not None and int(ef) <= 0: + # `ef` should be >= `k`, but `k` could be None so we can't check it here + # the rust code will check it + raise ValueError(f"ef must be > 0 but got {ef}") + + if distance_range is not None: + if len(distance_range) != 2: + raise ValueError( + "distance_range must be a tuple of (lower_bound, upper_bound)" + ) + + return { + "column": column, + "q": q, + "k": k, + "metric": metric, + "minimum_nprobes": minimum_nprobes, + "maximum_nprobes": maximum_nprobes, + "refine_factor": refine_factor, + "use_index": use_index, + "ef": ef, + "distance_range": distance_range, + } + + def _validate_schema(schema: pa.Schema): """ Make sure the metadata is valid utf8 @@ -6093,3 +6196,36 @@ def read_partition( return self.dataset._ds.read_index_partition( self.index_name, partition_id, with_vector ).read_all() + + +class VectorSearchQuery: + _inner: dict + + def __init__( + self, + column: str, + q: QueryVectorLike, + k: Optional[int] = None, + metric: Optional[str] = None, + nprobes: Optional[int] = None, + minimum_nprobes: Optional[int] = None, + maximum_nprobes: Optional[int] = None, + refine_factor: Optional[int] = None, + use_index: bool = True, + ef: Optional[int] = None, + ): + self._inner = _build_vector_search_query( + column, + q, + k=k, + metric=metric, + nprobes=nprobes, + minimum_nprobes=minimum_nprobes, + maximum_nprobes=maximum_nprobes, + refine_factor=refine_factor, + use_index=use_index, + ef=ef, + ) + + def inner(self): + return self._inner diff --git a/python/python/lance/lance/__init__.pyi b/python/python/lance/lance/__init__.pyi index d976bea8cf4..ebaf791d944 100644 --- a/python/python/lance/lance/__init__.pyi +++ b/python/python/lance/lance/__init__.pyi @@ -61,6 +61,7 @@ from .fragment import ( RowIdMeta as RowIdMeta, ) from .indices import IndexDescription as IndexDescription +from .lance import PySearchFilter from .optimize import ( Compaction as Compaction, ) @@ -223,6 +224,7 @@ class _Dataset: columns: Optional[List[str]] = None, columns_with_transform: Optional[List[Tuple[str, str]]] = None, filter: Optional[str] = None, + search_filter: Optional[PySearchFilter] = None, prefilter: Optional[bool] = None, limit: Optional[int] = None, offset: Optional[int] = None, diff --git a/python/python/tests/test_scalar_index.py b/python/python/tests/test_scalar_index.py index 75ec01d9a82..3763f9a8cb2 100644 --- a/python/python/tests/test_scalar_index.py +++ b/python/python/tests/test_scalar_index.py @@ -4273,3 +4273,122 @@ def test_describe_indices(tmp_path): indices = ds.describe_indices() for index in indices: assert index.num_rows_indexed == 50 + + +def test_vector_filter_fts_search(tmp_path): + # Create test data + ids = list(range(1, 301)) + vectors = [[float(i)] * 4 for i in ids] + + # Create text data: + # "text " for ids 1-255, 299, 300, + # "noop " for 256-298, + texts = [] + for i in ids: + if i <= 255: + texts.append(f"text {i}") + elif i <= 298: + texts.append(f"noop {i}") + else: + texts.append(f"text {i}") + + categories = [] + for i in ids: + if i % 3 == 1: + categories.append("literature") + elif i % 3 == 2: + categories.append("science") + else: + categories.append("geography") + + table = pa.table( + { + "id": ids, + "vector": pa.array(vectors, type=pa.list_(pa.float32(), 4)), + "text": texts, + "category": categories, + } + ) + + # Write dataset and create indices + ds = lance.write_dataset(table, tmp_path) + + ds = ds.create_index( + "vector", + index_type="IVF_PQ", + num_partitions=2, + num_sub_vectors=4, + ) + ds.create_scalar_index("text", index_type="INVERTED", with_position=True) + + # Create vector_query + vector_query = { + "column": "vector", + "q": np.array([300, 300, 300, 300], dtype=np.float32), + "k": 5, + "minimum_nprobes": 20, + "use_index": True, + } + + # Case 1: search with prefilter=true, query_filter=vector([300,300,300,300]) + scanner = ds.scanner( + prefilter=False, nearest=vector_query, filter=MatchQuery("text", "text") + ) + result = scanner.to_table() + assert [300, 299] == result["id"].to_pylist() + + # Case 2: search with prefilter=true, search_filter=match("text"), + # filter="category='geography'" + scanner = ds.scanner( + prefilter=True, + nearest=vector_query, + filter={ + "expr_filter": "category='geography'", + "search_filter": MatchQuery("text", "text"), + }, + ) + result = scanner.to_table() + assert [300, 255, 252, 249, 246] == result["id"].to_pylist() + + # Case 3: search with prefilter=false, search_filter=match("text") + scanner = ds.scanner( + prefilter=False, + nearest=vector_query, + filter=MatchQuery("text", "text"), + ) + result = scanner.to_table() + assert [300, 299] == result["id"].to_pylist() + + # Case 4: search with prefilter=false, search_filter=match("text"), + # filter="category='geography'" + scanner = ds.scanner( + prefilter=False, + nearest=vector_query, + filter={ + "expr_filter": "category='geography'", + "search_filter": MatchQuery("text", "text"), + }, + ) + result = scanner.to_table() + assert [300] == result["id"].to_pylist() + + # Case 5: search with prefilter=false, search_filter=phrase("text") + scanner = ds.scanner( + prefilter=False, + nearest=vector_query, + filter=PhraseQuery("text", "text"), + ) + result = scanner.to_table() + assert [299, 300] == result["id"].to_pylist() + + # Case 6: search with prefilter=false, search_filter=phrase("text") + scanner = ds.scanner( + prefilter=False, + nearest=vector_query, + filter={ + "expr_filter": "category='geography'", + "search_filter": PhraseQuery("text", "text"), + }, + ) + result = scanner.to_table() + assert [300] == result["id"].to_pylist() diff --git a/python/python/tests/test_vector_index.py b/python/python/tests/test_vector_index.py index 71850135ae3..087338a9a21 100644 --- a/python/python/tests/test_vector_index.py +++ b/python/python/tests/test_vector_index.py @@ -21,6 +21,7 @@ from lance import LanceDataset, LanceFragment from lance.dataset import Index, VectorIndexReader from lance.indices import IndexFileVersion, IndicesBuilder +from lance.query import MatchQuery, PhraseQuery from lance.util import validate_vector_index # noqa: E402 from lance.vector import vec_to_table # noqa: E402 @@ -2747,3 +2748,128 @@ def collect_ids_and_distances(ds_with_index): assert ids_12 == ids_21 for a, b in zip(dists_12, dists_21): assert np.allclose(a, b, atol=1e-6) + + +def test_fts_filter_vector_search(tmp_path): + # Create dataset with vector and text columns + ids = list(range(1, 301)) + vectors = [[float(i)] * 4 for i in ids] + + # Create text data: + # "text " for ids 1-255, 299, 300, + # "noop " for 256-298, + texts = [] + for i in ids: + if i <= 255: + texts.append(f"text {i}") + elif i <= 298: + texts.append(f"noop {i}") + else: + texts.append(f"text {i}") + + categories = [] + for i in ids: + if i % 3 == 1: + categories.append("literature") + elif i % 3 == 2: + categories.append("science") + else: + categories.append("geography") + + table = pa.table( + { + "id": ids, + "vector": pa.array(vectors, type=pa.list_(pa.float32(), 4)), + "text": texts, + "category": categories, + } + ) + + # Write dataset and create indices + dataset = lance.write_dataset(table, tmp_path) + dataset = dataset.create_index( + "vector", + index_type="IVF_PQ", + num_partitions=2, + num_sub_vectors=4, + ) + dataset.create_scalar_index("text", index_type="INVERTED", with_position=True) + + query_vector = [300.0, 300.0, 300.0, 300.0] + + # Case 1: search with prefilter=true, query_filter=match("text") + scanner = dataset.scanner( + filter=MatchQuery("text", "text"), + nearest={"column": "vector", "q": query_vector, "k": 5}, + prefilter=True, + ) + + result = scanner.to_table() + ids_result = result["id"].to_pylist() + assert [300, 299, 255, 254, 253] == ids_result + + # Case 2: search with prefilter=true, search_filter=match("text"), + # filter="category='geography'" + scanner = dataset.scanner( + nearest={"column": "vector", "q": query_vector, "k": 5}, + prefilter=True, + filter={ + "expr_filter": "category='geography'", + "search_filter": MatchQuery("text", "text"), + }, + ) + + result = scanner.to_table() + ids_result = result["id"].to_pylist() + assert [300, 255, 252, 249, 246] == ids_result + + # Case 3: search with prefilter=false, search_filter=match("text") + scanner = dataset.scanner( + filter=MatchQuery("text", "text"), + nearest={"column": "vector", "q": query_vector, "k": 5}, + prefilter=False, + ) + + result = scanner.to_table() + ids_result = result["id"].to_pylist() + assert [300, 299] == ids_result + + # Case 4: search with prefilter=false, search_filter=match("text"), + # filter="category='geography'" + scanner = dataset.scanner( + nearest={"column": "vector", "q": query_vector, "k": 5}, + prefilter=False, + filter={ + "expr_filter": "category='geography'", + "search_filter": MatchQuery("text", "text"), + }, + ) + + result = scanner.to_table() + ids_result = result["id"].to_pylist() + assert [300] == ids_result + + # Case 5: search with prefilter=false, search_filter=phrase("text") + scanner = dataset.scanner( + nearest={"column": "vector", "q": query_vector, "k": 5}, + prefilter=False, + filter=PhraseQuery("text", "text"), + ) + + result = scanner.to_table() + ids_result = result["id"].to_pylist() + assert [299, 300] == ids_result + + # Case 6: search with prefilter=false, search_filter=phrase("text") + scanner = dataset.scanner( + nearest={"column": "vector", "q": query_vector, "k": 5}, + prefilter=False, + filter={ + "expr_filter": "category='geography'", + "search_filter": PhraseQuery("text", "text"), + }, + ) + + result = scanner.to_table() + ids_result = result["id"].to_pylist() + assert [300] == ids_result diff --git a/python/src/dataset.rs b/python/src/dataset.rs index b4b4d8af050..588a9508322 100644 --- a/python/src/dataset.rs +++ b/python/src/dataset.rs @@ -37,6 +37,7 @@ use lance::dataset::cleanup::CleanupPolicyBuilder; use lance::dataset::refs::{Ref, TagContents}; use lance::dataset::scanner::{ ColumnOrdering, DatasetRecordBatchStream, ExecutionStatsCallback, MaterializationStyle, + QueryFilter, }; use lance::dataset::statistics::{DataStatistics, DatasetStatisticsExt}; use lance::dataset::AutoCleanupParams; @@ -74,7 +75,7 @@ use lance_index::{ scalar::{FullTextSearchQuery, InvertedIndexParams, ScalarIndexParams}, vector::{ hnsw::builder::HnswBuildParams, ivf::IvfBuildParams, pq::PQBuildParams, - sq::builder::SQBuildParams, + sq::builder::SQBuildParams, Query as VectorQuery, }, DatasetIndexExt, IndexParams, IndexType, }; @@ -764,12 +765,13 @@ impl Dataset { } #[allow(clippy::too_many_arguments)] - #[pyo3(signature=(columns=None, columns_with_transform=None, filter=None, prefilter=None, limit=None, offset=None, nearest=None, batch_size=None, io_buffer_size=None, batch_readahead=None, fragment_readahead=None, scan_in_order=None, fragments=None, with_row_id=None, with_row_address=None, use_stats=None, substrait_filter=None, fast_search=None, full_text_query=None, late_materialization=None, blob_handling=None, use_scalar_index=None, include_deleted_rows=None, scan_stats_callback=None, strict_batch_size=None, order_by=None, disable_scoring_autoprojection=None))] + #[pyo3(signature=(columns=None, columns_with_transform=None, filter=None, search_filter=None, prefilter=None, limit=None, offset=None, nearest=None, batch_size=None, io_buffer_size=None, batch_readahead=None, fragment_readahead=None, scan_in_order=None, fragments=None, with_row_id=None, with_row_address=None, use_stats=None, substrait_filter=None, fast_search=None, full_text_query=None, late_materialization=None, blob_handling=None, use_scalar_index=None, include_deleted_rows=None, scan_stats_callback=None, strict_batch_size=None, order_by=None, disable_scoring_autoprojection=None))] fn scanner( self_: PyRef<'_, Self>, columns: Option>, columns_with_transform: Option>, filter: Option, + search_filter: Option, prefilter: Option, limit: Option, offset: Option, @@ -837,6 +839,11 @@ impl Dataset { .filter(f.as_str()) .map_err(|err| PyValueError::new_err(err.to_string()))?; } + if let Some(qf) = search_filter { + scanner + .filter_query(qf.inner) + .map_err(|err| PyValueError::new_err(err.to_string()))?; + } if let Some(full_text_query) = full_text_query { let fts_query = if let Ok(full_text_query) = full_text_query.downcast::() { let mut query = full_text_query @@ -1001,111 +1008,18 @@ impl Dataset { } if let Some(nearest) = nearest { - let column = nearest - .get_item("column")? - .ok_or_else(|| PyKeyError::new_err("Need column for nearest"))? - .to_string(); - - let qval = nearest - .get_item("q")? - .ok_or_else(|| PyKeyError::new_err("Need q for nearest"))?; - let data = ArrayData::from_pyarrow_bound(&qval)?; - let q = make_array(data); - - let k: usize = if let Some(k) = nearest.get_item("k")? { - if k.is_none() { - // Use limit if k is not specified, default to 10. - limit.unwrap_or(10) as usize - } else { - k.extract()? - } - } else { - 10 - }; - - let mut minimum_nprobes = DEFAULT_NPROBES; - let mut maximum_nprobes = None; - - if let Some(nprobes) = nearest.get_item("nprobes")? { - if !nprobes.is_none() { - let extracted: usize = nprobes.extract()?; - minimum_nprobes = extracted; - maximum_nprobes = Some(extracted); - } - } - - if let Some(min_nprobes) = nearest.get_item("minimum_nprobes")? { - if !min_nprobes.is_none() { - minimum_nprobes = min_nprobes.extract()?; - } - } - - if let Some(max_nprobes) = nearest.get_item("maximum_nprobes")? { - if !max_nprobes.is_none() { - maximum_nprobes = Some(max_nprobes.extract()?); - } - } - - if let Some(maximum_nprobes) = maximum_nprobes { - if minimum_nprobes > maximum_nprobes { - return Err(PyValueError::new_err( - "minimum_nprobes must be <= maximum_nprobes", - )); - } - } - - if minimum_nprobes < 1 { - return Err(PyValueError::new_err("minimum_nprobes must be >= 1")); - } - - if let Some(maximum_nprobes) = maximum_nprobes { - if maximum_nprobes < 1 { - return Err(PyValueError::new_err("maximum_nprobes must be >= 1")); - } - } - - let metric_type: Option = - if let Some(metric) = nearest.get_item("metric")? { - if metric.is_none() { - None - } else { - Some( - MetricType::try_from(metric.to_string().to_lowercase().as_str()) - .map_err(|err| PyValueError::new_err(err.to_string()))?, - ) - } - } else { - None - }; - - // When refine factor is specified, a final Refine stage will be added to the I/O plan, - // and use Flat index over the raw vectors to refine the results. - // By default, `refine_factor` is None to not involve extra I/O exec node and random access. - let refine_factor: Option = if let Some(rf) = nearest.get_item("refine_factor")? { - if rf.is_none() { - None - } else { - rf.extract()? - } - } else { - None - }; - - let use_index: bool = if let Some(idx) = nearest.get_item("use_index")? { - idx.extract()? - } else { - true - }; - - let ef: Option = if let Some(ef) = nearest.get_item("ef")? { - if ef.is_none() { - None - } else { - ef.extract()? - } - } else { - None - }; + let default_k: usize = limit.unwrap_or(10) as usize; + let ( + column, + q, + k, + minimum_nprobes, + maximum_nprobes, + metric_type, + refine_factor, + use_index, + ef, + ) = vector_query_params_from_dict(nearest, default_k)?; let (_, element_type) = get_vector_type(self_.ds.schema(), &column) .map_err(|e| PyValueError::new_err(e.to_string()))?; @@ -3648,3 +3562,194 @@ impl PyFullTextQuery { }) } } + +type VectorQueryParams = ( + String, + arrow_array::ArrayRef, + usize, + usize, + Option, + Option, + Option, + bool, + Option, +); + +fn vector_query_params_from_dict( + dict: &Bound<'_, PyDict>, + default_k: usize, +) -> PyResult { + let column = dict + .get_item("column")? + .ok_or_else(|| PyKeyError::new_err("Need column for nearest"))? + .to_string(); + + let qval = dict + .get_item("q")? + .ok_or_else(|| PyKeyError::new_err("Need q for nearest"))?; + let data = ArrayData::from_pyarrow_bound(&qval)?; + let key = make_array(data); + + let k: usize = if let Some(k) = dict.get_item("k")? { + if k.is_none() { + // Use limit if k is not specified, default to 10. + default_k + } else { + k.extract()? + } + } else { + default_k + }; + + let mut minimum_nprobes = DEFAULT_NPROBES; + let mut maximum_nprobes: Option = None; + + if let Some(nprobes) = dict.get_item("nprobes")? { + if !nprobes.is_none() { + let extracted: usize = nprobes.extract()?; + minimum_nprobes = extracted; + maximum_nprobes = Some(extracted); + } + } + + if let Some(min_nprobes) = dict.get_item("minimum_nprobes")? { + if !min_nprobes.is_none() { + minimum_nprobes = min_nprobes.extract()?; + } + } + + if let Some(max_nprobes) = dict.get_item("maximum_nprobes")? { + if !max_nprobes.is_none() { + maximum_nprobes = Some(max_nprobes.extract()?); + } + } + + if let Some(maximum_nprobes_val) = maximum_nprobes { + if minimum_nprobes > maximum_nprobes_val { + return Err(PyValueError::new_err( + "minimum_nprobes must be <= maximum_nprobes", + )); + } + } + + if minimum_nprobes < 1 { + return Err(PyValueError::new_err("minimum_nprobes must be >= 1")); + } + + if let Some(maximum_nprobes_val) = maximum_nprobes { + if maximum_nprobes_val < 1 { + return Err(PyValueError::new_err("maximum_nprobes must be >= 1")); + } + } + + let metric_type: Option = if let Some(metric) = dict.get_item("metric")? { + if metric.is_none() { + None + } else { + Some( + MetricType::try_from(metric.to_string().to_lowercase().as_str()) + .map_err(|err| PyValueError::new_err(err.to_string()))?, + ) + } + } else { + None + }; + + // When refine factor is specified, a final Refine stage will be added to the I/O plan, + // and use Flat index over the raw vectors to refine the results. + // By default, `refine_factor` is None to not involve extra I/O exec node and random access. + let refine_factor: Option = if let Some(rf) = dict.get_item("refine_factor")? { + if rf.is_none() { + None + } else { + rf.extract()? + } + } else { + None + }; + + let use_index: bool = if let Some(idx) = dict.get_item("use_index")? { + idx.extract()? + } else { + true + }; + + let ef: Option = if let Some(ef_obj) = dict.get_item("ef")? { + if ef_obj.is_none() { + None + } else { + ef_obj.extract()? + } + } else { + None + }; + + Ok(( + column, + key, + k, + minimum_nprobes, + maximum_nprobes, + metric_type, + refine_factor, + use_index, + ef, + )) +} + +#[pyclass(name = "PySearchFilter")] +#[derive(Debug, Clone)] +pub struct PySearchFilter { + pub(crate) inner: QueryFilter, +} + +#[pymethods] +impl PySearchFilter { + /// Create a search filter from a full text query. + #[staticmethod] + #[pyo3(signature = (query))] + fn from_full_text_query(query: PyFullTextQuery) -> PyResult { + Ok(Self { + inner: QueryFilter::Fts(FullTextSearchQuery::new_query(query.inner.clone())), + }) + } + + /// Create a query filter from a vector search query dict. + #[staticmethod] + #[pyo3(signature = (query))] + fn from_vector_search_query(query: &Bound<'_, PyDict>) -> PyResult { + let default_k = 10; + let ( + column, + key, + k, + minimum_nprobes, + maximum_nprobes, + metric_type_opt, + refine_factor, + use_index, + ef, + ) = vector_query_params_from_dict(query, default_k)?; + + let metric_type = Some(metric_type_opt.unwrap_or(MetricType::L2)); + + let vector_query = VectorQuery { + column, + key, + k, + lower_bound: None, + upper_bound: None, + minimum_nprobes, + maximum_nprobes, + ef, + refine_factor, + metric_type, + use_index, + dist_q_c: 0.0, + }; + + Ok(Self { + inner: QueryFilter::Vector(vector_query), + }) + } +} diff --git a/python/src/lib.rs b/python/src/lib.rs index 5243e0fddd4..b1b8ce66599 100644 --- a/python/src/lib.rs +++ b/python/src/lib.rs @@ -43,7 +43,7 @@ use dataset::io_stats::IoStats; use dataset::optimize::{ PyCompaction, PyCompactionMetrics, PyCompactionPlan, PyCompactionTask, PyRewriteResult, }; -use dataset::{DatasetBasePath, MergeInsertBuilder, PyFullTextQuery}; +use dataset::{DatasetBasePath, MergeInsertBuilder, PyFullTextQuery, PySearchFilter}; use env_logger::{Builder, Env}; use file::{ stable_version, LanceBufferDescriptor, LanceColumnMetadata, LanceFileMetadata, LanceFileReader, @@ -276,6 +276,7 @@ fn lance(py: Python, m: &Bound<'_, PyModule>) -> PyResult<()> { m.add_class::()?; m.add_class::()?; m.add_class::()?; + m.add_class::()?; m.add_class::()?; m.add_class::()?; m.add_class::()?;