diff --git a/python/python/lance/dataset.py b/python/python/lance/dataset.py index 836477c2ce3..0a3253bc2dd 100644 --- a/python/python/lance/dataset.py +++ b/python/python/lance/dataset.py @@ -167,6 +167,7 @@ def __init__( ): uri = os.fspath(uri) if isinstance(uri, Path) else uri self._uri = uri + self._storage_options = storage_options self._ds = _Dataset( uri, version, @@ -183,6 +184,7 @@ def __init__( def __deserialize__( cls, uri: str, + storage_options: Optional[Dict[str, str]], version: int, manifest: bytes, default_scan_options: Optional[Dict[str, Any]], @@ -190,6 +192,7 @@ def __deserialize__( return cls( uri, version, + storage_options=storage_options, serialized_manifest=manifest, default_scan_options=default_scan_options, ) @@ -197,6 +200,7 @@ def __deserialize__( def __reduce__(self): return type(self).__deserialize__, ( self.uri, + self._storage_options, self._ds.version(), self._ds.serialized_manifest(), self._default_scan_options, @@ -205,16 +209,20 @@ def __reduce__(self): def __getstate__(self): return ( self.uri, + self._storage_options, self._ds.version(), self._ds.serialized_manifest(), self._default_scan_options, ) def __setstate__(self, state): - self._uri, version, manifest, default_scan_options = state + self._uri, self._storage_options, version, manifest, default_scan_options = ( + state + ) self._ds = _Dataset( self._uri, version, + storage_options=self._storage_options, manifest=manifest, default_scan_options=default_scan_options, ) @@ -222,6 +230,7 @@ def __setstate__(self, state): def __copy__(self): ds = LanceDataset.__new__(LanceDataset) ds._uri = self._uri + ds._storage_options = self._storage_options ds._ds = copy.copy(self._ds) ds._default_scan_options = self._default_scan_options return ds @@ -2208,6 +2217,7 @@ def commit( max_retries=max_retries, ) ds = LanceDataset.__new__(LanceDataset) + ds._storage_options = storage_options ds._ds = new_ds ds._uri = new_ds.uri ds._default_scan_options = None @@ -3495,6 +3505,7 @@ def write_dataset( inner_ds = _write_dataset(reader, uri, params) ds = LanceDataset.__new__(LanceDataset) + ds._storage_options = storage_options ds._ds = inner_ds ds._uri = inner_ds.uri ds._default_scan_options = None diff --git a/python/python/tests/test_s3_ddb.py b/python/python/tests/test_s3_ddb.py index c2073aee74e..adabc740e45 100644 --- a/python/python/tests/test_s3_ddb.py +++ b/python/python/tests/test_s3_ddb.py @@ -289,6 +289,18 @@ def test_file_writer_reader(s3_bucket: str): ) +@pytest.mark.filterwarnings("ignore::DeprecationWarning") +@pytest.mark.integration +@pytest.mark.skipif(not _RAY_AVAILABLE, reason="ray is not available") +def test_ray_read_lance(s3_bucket: str): + storage_options = copy.deepcopy(CONFIG) + table = pa.table({"a": [1, 2], "b": ["a", "b"]}) + path = f"s3://{s3_bucket}/test_ray_read.lance" + lance.write_dataset(table, path, storage_options=storage_options) + ds = ray.data.read_lance(path, storage_options=storage_options, concurrency=1) + ds.take(1) + + @pytest.mark.integration def test_append_fragment(s3_bucket: str): storage_options = copy.deepcopy(CONFIG)