Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .github/workflows/run_tests/action.yml
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,9 @@ inputs:
runs:
using: "composite"
steps:
- name: Setup MSVC for torch.compile
if: runner.os == 'Windows'
uses: ilammy/msvc-dev-cmd@v1
- name: Install dependencies
working-directory: python
shell: bash
Expand Down
2 changes: 2 additions & 0 deletions .github/workflows/rust.yml
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,8 @@ jobs:
# Need up-to-date compilers for kernels
CC: clang
CXX: clang++
# Treat warnings as errors to catch issues early
RUSTFLAGS: "-D warnings"
steps:
- uses: actions/checkout@v4
# pin the toolchain version to avoid surprises
Expand Down
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,7 @@ target

python/venv
test_data/venv
.venv

**/*.profraw
*.lance
Expand Down
8 changes: 6 additions & 2 deletions python/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ tests = [
]
dev = ["ruff==0.4.1", "pyright"]
benchmarks = ["pytest-benchmark"]
torch = ["torch"]
torch = ["torch>=2.0"]
geo = [
"geoarrow-rust-core",
"geoarrow-rust-io",
Expand Down Expand Up @@ -115,9 +115,13 @@ filterwarnings = [
'ignore:.*datetime\.datetime\.utcnow\(\) is deprecated.*:DeprecationWarning',
# Pandas 2.2 on Python 3.12
'ignore:.*datetime\.datetime\.utcfromtimestamp\(\) is deprecated.*:DeprecationWarning',
# Pytorch 2.2 on Python 2.12
# Pytorch 2.2 on Python 3.12
'ignore:.*is deprecated and will be removed in Python 3\.14.*:DeprecationWarning',
'ignore:.*The distutils package is deprecated.*:DeprecationWarning',
# Pytorch inductor uses deprecated load_module() in its code cache
'ignore:.*the load_module\(\) method is deprecated.*:DeprecationWarning',
# Pytorch uses deprecated jit.script_method internally (torch/utils/mkldnn.py)
'ignore:.*torch\.jit\.script_method.*is deprecated.*:DeprecationWarning',
# TensorFlow/Keras import can emit NumPy deprecation FutureWarnings in some environments.
# Keep FutureWarnings as errors generally, but ignore this known-noisy import-time warning.
'ignore:.*np\.object.*:FutureWarning',
Expand Down
43 changes: 38 additions & 5 deletions python/python/lance/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -693,6 +693,9 @@ def scanner(
fast_search: Optional[bool] = None,
io_buffer_size: Optional[int] = None,
late_materialization: Optional[bool | List[str]] = None,
blob_handling: Optional[
Literal["all_binary", "blobs_descriptions", "all_descriptions"]
] = None,
use_scalar_index: Optional[bool] = None,
include_deleted_rows: Optional[bool] = None,
scan_stats_callback: Optional[Callable[[ScanStatistics], None]] = None,
Expand Down Expand Up @@ -780,6 +783,12 @@ def scanner(
of the rows. If your filter is more selective (e.g. find by id) you may
want to set this to True. If your filter is not very selective (e.g.
matches 20% of the rows) you may want to set this to False.
blob_handling: str, default None
Controls how blob columns are returned.

- "all_binary": read blob columns as binary / large_binary values
- "blobs_descriptions": read blob columns as descriptions (default)
- "all_descriptions": read all binary columns as descriptions
full_text_query: str or dict, optional
query string to search for, the results will be ranked by BM25.
e.g. "hello world", would match documents containing "hello" or "world".
Expand Down Expand Up @@ -870,6 +879,7 @@ def setopt(opt, val):
setopt(builder.scan_in_order, scan_in_order)
setopt(builder.with_fragments, fragments)
setopt(builder.late_materialization, late_materialization)
setopt(builder.blob_handling, blob_handling)
setopt(builder.with_row_id, with_row_id)
setopt(builder.with_row_address, with_row_address)
setopt(builder.use_stats, use_stats)
Expand Down Expand Up @@ -958,6 +968,7 @@ def to_table(
full_text_query: Optional[Union[str, dict, FullTextQuery]] = None,
io_buffer_size: Optional[int] = None,
late_materialization: Optional[bool | List[str]] = None,
blob_handling: Optional[str] = None,
use_scalar_index: Optional[bool] = None,
include_deleted_rows: Optional[bool] = None,
order_by: Optional[List[ColumnOrdering]] = None,
Expand Down Expand Up @@ -1012,6 +1023,9 @@ def to_table(
late_materialization: bool or List[str], default None
Allows custom control over late materialization. See
``ScannerBuilder.late_materialization`` for more information.
blob_handling: str, default None
Controls how blob columns are returned. See ``LanceDataset.scanner`` for
details.
use_scalar_index: bool, default True
Allows custom control over scalar index usage. See
``ScannerBuilder.use_scalar_index`` for more information.
Expand Down Expand Up @@ -1073,6 +1087,7 @@ def to_table(
batch_readahead=batch_readahead,
fragment_readahead=fragment_readahead,
late_materialization=late_materialization,
blob_handling=blob_handling,
use_scalar_index=use_scalar_index,
scan_in_order=scan_in_order,
prefilter=prefilter,
Expand Down Expand Up @@ -1455,6 +1470,7 @@ def to_batches(
full_text_query: Optional[Union[str, dict]] = None,
io_buffer_size: Optional[int] = None,
late_materialization: Optional[bool | List[str]] = None,
blob_handling: Optional[str] = None,
use_scalar_index: Optional[bool] = None,
strict_batch_size: Optional[bool] = None,
order_by: Optional[List[ColumnOrdering]] = None,
Expand Down Expand Up @@ -1483,6 +1499,7 @@ def to_batches(
batch_readahead=batch_readahead,
fragment_readahead=fragment_readahead,
late_materialization=late_materialization,
blob_handling=blob_handling,
use_scalar_index=use_scalar_index,
scan_in_order=scan_in_order,
prefilter=prefilter,
Expand Down Expand Up @@ -2138,11 +2155,11 @@ def merge_insert(
... .execute(new_table)
{'num_inserted_rows': 1, 'num_updated_rows': 2, 'num_deleted_rows': 0}
>>> dataset.to_table().sort_by("a").to_pandas()
a b c
0 1 a x
1 2 x y
2 3 y z
3 4 z None
a b c
0 1 a x
1 2 x y
2 3 y z
3 4 z NaN
"""
return MergeInsertBuilder(self._ds, on)

Expand Down Expand Up @@ -4581,6 +4598,7 @@ def __init__(self, ds: LanceDataset):
self._substrait_filter = None
self._prefilter = False
self._late_materialization = None
self._blob_handling = None
self._offset = None
self._columns = None
self._columns_with_transform = None
Expand Down Expand Up @@ -4776,6 +4794,20 @@ def late_materialization(
self._late_materialization = late_materialization
return self

def blob_handling(self, blob_handling: Optional[str]) -> ScannerBuilder:
    """Configure how blob columns are materialized by the scan.

    Parameters
    ----------
    blob_handling : str, optional
        One of ``"all_binary"``, ``"blobs_descriptions"``, or
        ``"all_descriptions"``.  ``None`` clears any previously set value
        and falls back to the default behavior.

    Returns
    -------
    ScannerBuilder
        ``self``, to allow fluent chaining.

    Raises
    ------
    ValueError
        If ``blob_handling`` is a string outside the accepted set.
    """
    # Tuple is kept in sorted order so the error message matches
    # ", ".join over the valid options.
    valid = ("all_binary", "all_descriptions", "blobs_descriptions")
    if blob_handling is not None and blob_handling not in valid:
        raise ValueError(
            f"Invalid blob_handling: {blob_handling}. Expected one of: "
            f"{', '.join(valid)}"
        )
    self._blob_handling = blob_handling
    return self

def use_stats(self, use_stats: bool = True) -> ScannerBuilder:
"""
Enable use of statistics for query planning.
Expand Down Expand Up @@ -5053,6 +5085,7 @@ def to_scanner(self) -> LanceScanner:
self._fast_search,
self._full_text_query,
self._late_materialization,
self._blob_handling,
self._use_scalar_index,
self._include_deleted_rows,
self._scan_stats_callback,
Expand Down
1 change: 1 addition & 0 deletions python/python/lance/lance/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -240,6 +240,7 @@ class _Dataset:
fast_search: Optional[bool] = None,
full_text_query: Optional[dict] = None,
late_materialization: Optional[bool | List[str]] = None,
blob_handling: Optional[str] = None,
use_scalar_index: Optional[bool] = None,
include_deleted_rows: Optional[bool] = None,
) -> _Scanner: ...
Expand Down
12 changes: 6 additions & 6 deletions python/python/lance/torch/distance.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
]


@torch.jit.script
@torch.compile
def _pairwise_cosine(
x: torch.Tensor, y: torch.Tensor, y2: torch.Tensor
) -> torch.Tensor:
Expand Down Expand Up @@ -49,7 +49,7 @@ def pairwise_cosine(
return _pairwise_cosine(x, y, y2)


@torch.jit.script
@torch.compile
def _cosine_distance(
vectors: torch.Tensor, centroids: torch.Tensor, split_size: int
) -> Tuple[torch.Tensor, torch.Tensor]:
Expand Down Expand Up @@ -114,7 +114,7 @@ def cosine_distance(
raise RuntimeError("Cosine distance out of memory")


@torch.jit.script
@torch.compile
def argmin_l2(x: torch.Tensor, y: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
x = x.reshape(1, x.shape[0], -1)
y = y.reshape(1, y.shape[0], -1)
Expand All @@ -125,7 +125,7 @@ def argmin_l2(x: torch.Tensor, y: torch.Tensor) -> Tuple[torch.Tensor, torch.Ten
return min_dists.pow(2), idx


@torch.jit.script
@torch.compile
def pairwise_l2(
x: torch.Tensor, y: torch.Tensor, y2: Optional[torch.Tensor] = None
) -> torch.Tensor:
Expand Down Expand Up @@ -170,7 +170,7 @@ def pairwise_l2(
return dists.type(origin_dtype)


@torch.jit.script
@torch.compile
def _l2_distance(
x: torch.Tensor,
y: torch.Tensor,
Expand Down Expand Up @@ -237,7 +237,7 @@ def l2_distance(
raise RuntimeError("L2 distance out of memory")


@torch.jit.script
@torch.compile
def dot_distance(x: torch.Tensor, y: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
"""Pair-wise dot distance between two 2-D Tensors.

Expand Down
19 changes: 19 additions & 0 deletions python/python/tests/test_blob.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,25 @@ def test_blob_descriptions(tmp_path):
assert descriptions.field(1) == expected_sizes


def test_scan_blob_as_binary(tmp_path):
    """Scanning with blob_handling="all_binary" yields the raw blob bytes."""
    payload = [b"foo", b"bar", b"baz"]
    # Mark the column as a blob column via field metadata so Lance applies
    # blob encoding on write.
    blob_field = pa.field(
        "blobs", pa.large_binary(), metadata={"lance-encoding:blob": "true"}
    )
    table = pa.table(
        [pa.array(payload, pa.large_binary())], schema=pa.schema([blob_field])
    )
    ds = lance.write_dataset(table, tmp_path / "test_ds")

    scanned = ds.scanner(columns=["blobs"], blob_handling="all_binary").to_table()
    assert scanned.column("blobs").to_pylist() == payload


@pytest.fixture
def dataset_with_blobs(tmp_path):
values = pa.array([b"foo", b"bar", b"baz"], pa.large_binary())
Expand Down
4 changes: 2 additions & 2 deletions python/python/tests/test_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -299,12 +299,12 @@ def test_duckdb(tmp_path):
expected = expected[(expected.price > 20.0) & (expected.price <= 90)].reset_index(
drop=True
)
tm.assert_frame_equal(actual, expected)
tm.assert_frame_equal(actual, expected, check_dtype=False)

actual = duckdb.query("SELECT id, meta, price FROM ds WHERE meta=='aa'").to_df()
expected = duckdb.query("SELECT id, meta, price FROM ds").to_df()
expected = expected[expected.meta == "aa"].reset_index(drop=True)
tm.assert_frame_equal(actual, expected)
tm.assert_frame_equal(actual, expected, check_dtype=False)


def test_struct_field_order(tmp_path):
Expand Down
49 changes: 23 additions & 26 deletions python/python/tests/test_scalar_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -2403,10 +2403,23 @@ def compare_fts_results(
single_df = single_machine_results.to_pandas()
distributed_df = distributed_results.to_pandas()

# Sort both by row_id to ensure consistent ordering
if "_rowid" in single_df.columns:
single_df = single_df.sort_values("_rowid").reset_index(drop=True)
distributed_df = distributed_df.sort_values("_rowid").reset_index(drop=True)
# Normalize row ordering for comparisons.
#
# FTS search results do not guarantee a stable order for tied scores and
# different execution modes (single-machine vs distributed) may return rows
# in different (but equivalent) orders.
sort_cols = (
["_rowid"]
if "_rowid" in single_df.columns
else [c for c in single_df.columns if c != "_score"]
)
if sort_cols:
single_df = single_df.sort_values(sort_cols, kind="mergesort").reset_index(
drop=True
)
distributed_df = distributed_df.sort_values(
sort_cols, kind="mergesort"
).reset_index(drop=True)

# Compare row IDs (most important)
if "_rowid" in single_df.columns:
Expand All @@ -2418,8 +2431,8 @@ def compare_fts_results(

# Compare scores with tolerance
if "_score" in single_df.columns:
single_scores = single_df["_score"].values
distributed_scores = distributed_df["_score"].values
single_scores = single_df["_score"].to_numpy(dtype=float)
distributed_scores = distributed_df["_score"].to_numpy(dtype=float)
score_diff = np.abs(single_scores - distributed_scores)
max_diff = np.max(score_diff)
assert max_diff <= tolerance, (
Expand All @@ -2430,27 +2443,11 @@ def compare_fts_results(
# Compare other columns (exact match for non-score columns)
for col in single_df.columns:
if col not in ["_score"]: # Skip score column (already compared with tolerance)
single_values = (
set(single_df[col])
if single_df[col].dtype == "object"
else single_df[col].values
np.testing.assert_array_equal(
single_df[col].to_numpy(dtype=object),
distributed_df[col].to_numpy(dtype=object),
err_msg=f"Column {col} values don't match",
)
distributed_values = (
set(distributed_df[col])
if distributed_df[col].dtype == "object"
else distributed_df[col].values
)

if isinstance(single_values, set):
assert single_values == distributed_values, (
f"Column {col} content mismatch"
)
else:
np.testing.assert_array_equal(
single_values,
distributed_values,
err_msg=f"Column {col} values don't match",
)

return True

Expand Down
Loading
Loading