Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 10 additions & 3 deletions python/python/lance/indices/ivf.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright The Lance Authors

from typing import Dict, Optional

import pyarrow as pa

from lance.file import LanceFileReader, LanceFileWriter
Expand All @@ -24,7 +26,7 @@ def num_partitions(self) -> int:
"""
return len(self.centroids)

def save(self, uri: str):
def save(self, uri: str, *, storage_options: Optional[Dict[str, str]] = None):
"""
Save the IVF model to a lance file.

Expand All @@ -34,19 +36,22 @@ def save(self, uri: str):
uri: str
The URI to save the model to. The URI can be a local file path or a
cloud storage path.
storage_options : dict, optional
Extra options for the storage backend (e.g. S3 credentials).
"""
with LanceFileWriter(
uri,
pa.schema(
[pa.field("centroids", self.centroids.type)],
metadata={b"distance_type": self.distance_type.encode()},
),
storage_options=storage_options,
) as writer:
batch = pa.table([self.centroids], names=["centroids"])
writer.write_batch(batch)

@classmethod
def load(cls, uri: str):
def load(cls, uri: str, *, storage_options: Optional[Dict[str, str]] = None):
"""
Load an IVF model from a lance file.

Expand All @@ -56,8 +61,10 @@ def load(cls, uri: str):
uri: str
The URI to load the model from. The URI can be a local file path or a
cloud storage path.
storage_options : dict, optional
Extra options for the storage backend (e.g. S3 credentials).
"""
reader = LanceFileReader(uri)
reader = LanceFileReader(uri, storage_options=storage_options)
num_rows = reader.metadata().num_rows
metadata = reader.metadata().schema.metadata
distance_type = metadata[b"distance_type"].decode()
Expand Down
13 changes: 10 additions & 3 deletions python/python/lance/indices/pq.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright The Lance Authors

from typing import Dict, Optional

import pyarrow as pa

from lance.file import LanceFileReader, LanceFileWriter
Expand All @@ -23,7 +25,7 @@ def dimension(self):
"""The dimension of the vectors this model was trained on"""
return self.codebook.type.list_size

def save(self, uri: str):
def save(self, uri: str, *, storage_options: Optional[Dict[str, str]] = None):
"""
Save the PQ model to a lance file.

Expand All @@ -33,19 +35,22 @@ def save(self, uri: str):
uri: str
The URI to save the model to. The URI can be a local file path or a
cloud storage path.
storage_options : dict, optional
Extra options for the storage backend (e.g. S3 credentials).
"""
with LanceFileWriter(
uri,
pa.schema(
[pa.field("codebook", self.codebook.type)],
metadata={b"num_subvectors": str(self.num_subvectors).encode()},
),
storage_options=storage_options,
) as writer:
batch = pa.table([self.codebook], names=["codebook"])
writer.write_batch(batch)

@classmethod
def load(cls, uri: str):
def load(cls, uri: str, *, storage_options: Optional[Dict[str, str]] = None):
"""
Load a PQ model from a lance file.

Expand All @@ -55,8 +60,10 @@ def load(cls, uri: str):
uri: str
The URI to load the model from. The URI can be a local file path or a
cloud storage path.
storage_options : dict, optional
Extra options for the storage backend (e.g. S3 credentials).
"""
reader = LanceFileReader(uri)
reader = LanceFileReader(uri, storage_options=storage_options)
num_rows = reader.metadata().num_rows
metadata = reader.metadata().schema.metadata
num_subvectors = int(metadata[b"num_subvectors"].decode())
Expand Down
Loading