From 7b0f14f21c74297253db63039e309bc7299fdb98 Mon Sep 17 00:00:00 2001 From: hushengquan <1390305506@qq.com> Date: Fri, 27 Mar 2026 19:30:17 +0800 Subject: [PATCH] feat(python): Add storage_options to IvfModel and PqModel save/load --- python/python/lance/indices/ivf.py | 13 ++++++++++--- python/python/lance/indices/pq.py | 13 ++++++++++--- 2 files changed, 20 insertions(+), 6 deletions(-) diff --git a/python/python/lance/indices/ivf.py b/python/python/lance/indices/ivf.py index fa92f744d55..fef19dde73a 100644 --- a/python/python/lance/indices/ivf.py +++ b/python/python/lance/indices/ivf.py @@ -1,6 +1,8 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright The Lance Authors +from typing import Dict, Optional + import pyarrow as pa from lance.file import LanceFileReader, LanceFileWriter @@ -24,7 +26,7 @@ def num_partitions(self) -> int: """ return len(self.centroids) - def save(self, uri: str): + def save(self, uri: str, *, storage_options: Optional[Dict[str, str]] = None): """ Save the IVF model to a lance file. @@ -34,6 +36,8 @@ def save(self, uri: str): uri: str The URI to save the model to. The URI can be a local file path or a cloud storage path. + storage_options : optional, dict + Extra options for the storage backend (e.g. S3 credentials). """ with LanceFileWriter( uri, @@ -41,12 +45,13 @@ def save(self, uri: str): [pa.field("centroids", self.centroids.type)], metadata={b"distance_type": self.distance_type.encode()}, ), + storage_options=storage_options, ) as writer: batch = pa.table([self.centroids], names=["centroids"]) writer.write_batch(batch) @classmethod - def load(cls, uri: str): + def load(cls, uri: str, *, storage_options: Optional[Dict[str, str]] = None): """ Load an IVF model from a lance file. @@ -56,8 +61,10 @@ def load(cls, uri: str): uri: str The URI to load the model from. The URI can be a local file path or a cloud storage path. + storage_options : optional, dict + Extra options for the storage backend (e.g. S3 credentials). """ - reader = LanceFileReader(uri) + reader = LanceFileReader(uri, storage_options=storage_options) num_rows = reader.metadata().num_rows metadata = reader.metadata().schema.metadata distance_type = metadata[b"distance_type"].decode() diff --git a/python/python/lance/indices/pq.py b/python/python/lance/indices/pq.py index 09f34f04dfe..b3aeb50bcbe 100644 --- a/python/python/lance/indices/pq.py +++ b/python/python/lance/indices/pq.py @@ -1,6 +1,8 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright The Lance Authors +from typing import Dict, Optional + import pyarrow as pa from lance.file import LanceFileReader, LanceFileWriter @@ -23,7 +25,7 @@ def dimension(self): """The dimension of the vectors this model was trained on""" return self.codebook.type.list_size - def save(self, uri: str): + def save(self, uri: str, *, storage_options: Optional[Dict[str, str]] = None): """ Save the PQ model to a lance file. @@ -33,6 +35,8 @@ def save(self, uri: str): uri: str The URI to save the model to. The URI can be a local file path or a cloud storage path. + storage_options : optional, dict + Extra options for the storage backend (e.g. S3 credentials). """ with LanceFileWriter( uri, @@ -40,12 +44,13 @@ def save(self, uri: str): [pa.field("codebook", self.codebook.type)], metadata={b"num_subvectors": str(self.num_subvectors).encode()}, ), + storage_options=storage_options, ) as writer: batch = pa.table([self.codebook], names=["codebook"]) writer.write_batch(batch) @classmethod - def load(cls, uri: str): + def load(cls, uri: str, *, storage_options: Optional[Dict[str, str]] = None): """ Load a PQ model from a lance file. @@ -55,8 +60,10 @@ def load(cls, uri: str): uri: str The URI to load the model from. The URI can be a local file path or a cloud storage path. + storage_options : optional, dict + Extra options for the storage backend (e.g. S3 credentials). """ - reader = LanceFileReader(uri) + reader = LanceFileReader(uri, storage_options=storage_options) num_rows = reader.metadata().num_rows metadata = reader.metadata().schema.metadata num_subvectors = int(metadata[b"num_subvectors"].decode())