diff --git a/python/python/lance/dataset.py b/python/python/lance/dataset.py index 50a7c9a9fb2..d9a96a86c85 100644 --- a/python/python/lance/dataset.py +++ b/python/python/lance/dataset.py @@ -34,7 +34,6 @@ import pyarrow as pa import pyarrow.dataset -from lance_namespace import DescribeTableRequest, LanceNamespace from pyarrow import RecordBatch, Schema from lance.log import LOGGER @@ -71,9 +70,11 @@ from .util import _target_partition_size_to_num_partitions, td_to_micros if TYPE_CHECKING: + from lance_namespace import LanceNamespace from pyarrow._compute import Expression from .commit import CommitLock + from .io import StorageOptionsProvider from .progress import FragmentWriteProgress from .types import ReaderLike @@ -3062,6 +3063,7 @@ def commit( read_version: Optional[int] = None, commit_lock: Optional[CommitLock] = None, storage_options: Optional[Dict[str, str]] = None, + storage_options_provider: Optional["StorageOptionsProvider"] = None, enable_v2_manifest_paths: Optional[bool] = None, detached: Optional[bool] = False, max_retries: int = 20, @@ -3106,6 +3108,8 @@ def commit( storage_options : optional, dict Extra options that make sense for a particular storage connection. This is used to store connection parameters like credentials, endpoint, etc. + storage_options_provider : StorageOptionsProvider, optional + A provider for dynamic storage options with automatic credential refresh. enable_v2_manifest_paths : bool, optional If True, and this is a new dataset, uses the new V2 manifest paths. These paths provide more efficient opening of datasets with many @@ -3191,6 +3195,7 @@ def commit( operation, commit_lock, storage_options=storage_options, + storage_options_provider=storage_options_provider, enable_v2_manifest_paths=enable_v2_manifest_paths, detached=detached, max_retries=max_retries, @@ -3202,6 +3207,7 @@ def commit( read_version, commit_lock, storage_options=storage_options, + storage_options_provider=storage_options_provider, enable_v2_manifest_paths=enable_v2_manifest_paths, detached=detached, max_retries=max_retries, @@ -3227,6 +3233,7 @@ def commit_batch( transactions: Sequence[Transaction], commit_lock: Optional[CommitLock] = None, storage_options: Optional[Dict[str, str]] = None, + storage_options_provider: Optional["StorageOptionsProvider"] = None, enable_v2_manifest_paths: Optional[bool] = None, detached: Optional[bool] = False, max_retries: int = 20, @@ -3255,6 +3262,8 @@ def commit_batch( storage_options : optional, dict Extra options that make sense for a particular storage connection. This is used to store connection parameters like credentials, endpoint, etc. + storage_options_provider : StorageOptionsProvider, optional + A provider for dynamic storage options with automatic credential refresh. enable_v2_manifest_paths : bool, optional If True, and this is a new dataset, uses the new V2 manifest paths. These paths provide more efficient opening of datasets with many @@ -3301,6 +3310,7 @@ def commit_batch( transactions, commit_lock, storage_options=storage_options, + storage_options_provider=storage_options_provider, enable_v2_manifest_paths=enable_v2_manifest_paths, detached=detached, max_retries=max_retries, @@ -5097,6 +5107,7 @@ def write_dataset( target_bases: Optional[List[str]] = None, namespace: Optional[LanceNamespace] = None, table_id: Optional[List[str]] = None, + ignore_namespace_table_storage_options: bool = False, ) -> LanceDataset: """Write a given data_obj to the given uri @@ -5198,15 +5209,22 @@ def write_dataset( table_id : optional, List[str] The table identifier when using a namespace (e.g., ["my_table"]). Must be provided together with `namespace`. Cannot be used with `uri`. + ignore_namespace_table_storage_options : bool, default False + If True, ignore the storage options returned by the namespace and only use + the provided `storage_options` parameter. The storage options provider will + not be created, so credentials will not be automatically refreshed. + This is useful when you want to use your own credentials instead of the + namespace-provided credentials. Notes ----- When using `namespace` and `table_id`: - The `uri` parameter is optional and will be fetched from the namespace - A `LanceNamespaceStorageOptionsProvider` will be created automatically for - storage options refresh + storage options refresh (unless `ignore_namespace_table_storage_options=True`) - Initial storage options from describe_table() will be merged with - any provided `storage_options` + any provided `storage_options` (unless + `ignore_namespace_table_storage_options=True`) """ # Validate that user provides either uri OR (namespace + table_id), not both has_uri = uri is not None @@ -5229,23 +5247,62 @@ def write_dataset( "Both 'namespace' and 'table_id' must be provided together." ) - request = DescribeTableRequest(id=table_id, version=None) - response = namespace.describe_table(request) + # Implement write_into_namespace logic in Python + # This follows the same pattern as the Rust implementation: + # - CREATE mode: calls namespace.create_empty_table() + # - APPEND/OVERWRITE mode: calls namespace.describe_table() + # - Both modes: create storage options provider and merge storage options + + from lance_namespace import CreateEmptyTableRequest, DescribeTableRequest + + from .namespace import LanceNamespaceStorageOptionsProvider + + # Determine which namespace method to call based on mode + if mode == "create": + request = CreateEmptyTableRequest( + id=table_id, location=None, properties=None + ) + response = namespace.create_empty_table(request) + elif mode in ("append", "overwrite"): + request = DescribeTableRequest(id=table_id, version=None) + response = namespace.describe_table(request) + else: + raise ValueError(f"Invalid mode: {mode}") + + # Get table location from response uri = response.location if not uri: - raise ValueError("Namespace did not return a table location") + raise ValueError( + f"Namespace did not return a table location in {mode} response" + ) + + # Check if we should ignore namespace storage options + if ignore_namespace_table_storage_options: + namespace_storage_options = None + else: + namespace_storage_options = response.storage_options - namespace_storage_options = response.storage_options + # Set up storage options and provider if namespace_storage_options: - # TODO: support dynamic storage options provider + # Create the storage options provider for automatic refresh + storage_options_provider = LanceNamespaceStorageOptionsProvider( + namespace=namespace, table_id=table_id + ) + + # Merge namespace storage options with any existing options + # Namespace options take precedence (same as Rust implementation) if storage_options is None: - storage_options = namespace_storage_options + storage_options = dict(namespace_storage_options) else: merged_options = dict(storage_options) merged_options.update(namespace_storage_options) storage_options = merged_options + else: + storage_options_provider = None elif table_id is not None: raise ValueError("Both 'namespace' and 'table_id' must be provided together.") + else: + storage_options_provider = None if use_legacy_format is not None: warnings.warn( @@ -5282,6 +5339,10 @@ def write_dataset( "target_bases": target_bases, } + # Add storage_options_provider if created from namespace + if storage_options_provider is not None: + params["storage_options_provider"] = storage_options_provider + if commit_lock: if not callable(commit_lock): raise TypeError(f"commit_lock must be a function, got {type(commit_lock)}") diff --git a/python/python/lance/file.py b/python/python/lance/file.py index bbd414f2180..73d308380f9 100644 --- a/python/python/lance/file.py +++ b/python/python/lance/file.py @@ -299,6 +299,7 @@ def __init__( data_cache_bytes: Optional[int] = None, version: Optional[str] = None, storage_options: Optional[Dict[str, str]] = None, + storage_options_provider=None, max_page_bytes: Optional[int] = None, _inner_writer: Optional[_LanceFileWriter] = None, **kwargs, @@ -325,6 +326,10 @@ def __init__( storage_options : optional, dict Extra options to be used for a particular storage connection. This is used to store connection parameters like credentials, endpoint, etc. + storage_options_provider : optional, StorageOptionsProvider + A storage options provider that can fetch and refresh storage options + dynamically. This is useful for credentials that expire and need to be + refreshed automatically. max_page_bytes : optional, int The maximum size of a page in bytes, if a single array would create a page larger than this then it will be split into multiple pages. The @@ -341,6 +346,7 @@ def __init__( data_cache_bytes=data_cache_bytes, version=version, storage_options=storage_options, + storage_options_provider=storage_options_provider, max_page_bytes=max_page_bytes, **kwargs, ) diff --git a/python/python/lance/fragment.py b/python/python/lance/fragment.py index d08167fcb52..99fde8fa5f8 100644 --- a/python/python/lance/fragment.py +++ b/python/python/lance/fragment.py @@ -864,6 +864,7 @@ def write_fragments( data_storage_version: Optional[str] = None, use_legacy_format: Optional[bool] = None, storage_options: Optional[Dict[str, str]] = None, + storage_options_provider=None, enable_stable_row_ids: bool = False, ) -> Transaction: ... @@ -882,6 +883,7 @@ def write_fragments( data_storage_version: Optional[str] = None, use_legacy_format: Optional[bool] = None, storage_options: Optional[Dict[str, str]] = None, + storage_options_provider=None, enable_stable_row_ids: bool = False, ) -> List[FragmentMetadata]: ... @@ -900,6 +902,7 @@ def write_fragments( data_storage_version: Optional[str] = None, use_legacy_format: Optional[bool] = None, storage_options: Optional[Dict[str, str]] = None, + storage_options_provider=None, enable_stable_row_ids: bool = False, ) -> List[FragmentMetadata] | Transaction: """ @@ -949,6 +952,10 @@ def write_fragments( storage_options : Optional[Dict[str, str]] Extra options that make sense for a particular storage connection. This is used to store connection parameters like credentials, endpoint, etc. + storage_options_provider : Optional[StorageOptionsProvider] + A storage options provider that can fetch and refresh storage options + dynamically. This is useful for credentials that expire and need to be + refreshed automatically. enable_stable_row_ids: bool Experimental: if set to true, the writer will use stable row ids. These row ids are stable after compaction operations, but not after updates. @@ -1001,6 +1008,7 @@ def write_fragments( progress=progress, data_storage_version=data_storage_version, storage_options=storage_options, + storage_options_provider=storage_options_provider, enable_stable_row_ids=enable_stable_row_ids, ) diff --git a/python/python/lance/lance/__init__.pyi b/python/python/lance/lance/__init__.pyi index 670b6357f73..2fd57e307e3 100644 --- a/python/python/lance/lance/__init__.pyi +++ b/python/python/lance/lance/__init__.pyi @@ -46,6 +46,7 @@ from ..fragment import ( DataFile, FragmentMetadata, ) +from ..io import StorageOptionsProvider from ..progress import FragmentWriteProgress as FragmentWriteProgress from ..types import ReaderLike as ReaderLike from ..udf import BatchUDF as BatchUDF @@ -99,6 +100,7 @@ class LanceFileWriter: data_cache_bytes: Optional[int], version: Optional[str], storage_options: Optional[Dict[str, str]], + storage_options_provider: Optional[StorageOptionsProvider], keep_original_array: Optional[bool], max_page_bytes: Optional[int], ): ... @@ -345,6 +347,7 @@ class _Dataset: read_version: Optional[int] = None, commit_lock: Optional[CommitLock] = None, storage_options: Optional[Dict[str, str]] = None, + storage_options_provider: Optional[StorageOptionsProvider] = None, enable_v2_manifest_paths: Optional[bool] = None, detached: Optional[bool] = None, max_retries: Optional[int] = None, @@ -356,6 +359,7 @@ class _Dataset: transactions: Sequence[Transaction], commit_lock: Optional[CommitLock] = None, storage_options: Optional[Dict[str, str]] = None, + storage_options_provider: Optional[StorageOptionsProvider] = None, enable_v2_manifest_paths: Optional[bool] = None, detached: Optional[bool] = None, max_retries: Optional[int] = None, diff --git a/python/python/tests/test_namespace_integration.py b/python/python/tests/test_namespace_integration.py index 1ab33535330..190221d4708 100644 --- a/python/python/tests/test_namespace_integration.py +++ b/python/python/tests/test_namespace_integration.py @@ -19,7 +19,13 @@ import lance import pyarrow as pa import pytest -from lance_namespace import DescribeTableResponse, LanceNamespace +from lance_namespace import ( + CreateEmptyTableRequest, + CreateEmptyTableResponse, + DescribeTableRequest, + DescribeTableResponse, + LanceNamespace, +) # These are all keys that are accepted by storage_options CONFIG = { @@ -68,15 +74,8 @@ def delete_bucket(s3, bucket_name): pass -class MockLanceNamespace(LanceNamespace): - """ - Mock namespace implementation that tracks credential refresh calls. - - Similar to the Rust MockStorageOptionsProvider, this implementation: - - Returns incrementing credentials on each describe_table call - - Tracks the number of times describe_table has been called - - Returns credentials with short expiration times for testing refresh - """ +class TrackingNamespace(LanceNamespace): + """Mock namespace that wraps DirectoryNamespace and tracks API calls.""" def __init__( self, @@ -84,247 +83,532 @@ def __init__( storage_options: Dict[str, str], credential_expires_in_seconds: int = 60, ): - """ - Initialize the mock namespace. - - Parameters - ---------- - bucket_name : str - The S3 bucket name where tables are stored - storage_options : Dict[str, str] - Base storage options (aws_endpoint, aws_region, etc.) - credential_expires_in_seconds : int - How long credentials should be valid (for testing refresh) - """ + from lance.namespace import DirectoryNamespace + self.bucket_name = bucket_name self.base_storage_options = storage_options self.credential_expires_in_seconds = credential_expires_in_seconds - self.call_count = 0 + self.describe_call_count = 0 + self.create_call_count = 0 self.lock = Lock() - self.tables: Dict[str, str] = {} # table_id -> location mapping - def register_table(self, table_id: list, location: str): - """Register a table in the mock namespace.""" - table_key = "/".join(table_id) - self.tables[table_key] = location + # Create underlying DirectoryNamespace with storage options + dir_props = {f"storage.{k}": v for k, v in storage_options.items()} + + if bucket_name.startswith("/") or bucket_name.startswith("file://"): + dir_props["root"] = f"{bucket_name}/namespace_root" + else: + dir_props["root"] = f"s3://{bucket_name}/namespace_root" + + self.inner = DirectoryNamespace(**dir_props) - def get_call_count(self) -> int: - """Get the number of times describe_table has been called.""" + def get_describe_call_count(self) -> int: with self.lock: - return self.call_count + return self.describe_call_count - def namespace_id(self) -> str: - """Return a unique identifier for this namespace instance.""" - return "MockLanceNamespace { }" + def get_create_call_count(self) -> int: + with self.lock: + return self.create_call_count - def describe_table(self, request) -> DescribeTableResponse: - """ - Describe a table and return storage options with incrementing credentials. + def namespace_id(self) -> str: + return f"TrackingNamespace {{ inner: {self.inner.namespace_id()} }}" - This simulates a namespace server that returns temporary AWS credentials - that expire after a short time. Each call increments the credential counter. + def _modify_storage_options( + self, storage_options: Dict[str, str], count: int + ) -> Dict[str, str]: + """Add incrementing credentials with expiration timestamp.""" + modified = copy.deepcopy(storage_options) if storage_options else {} - Parameters - ---------- - request : DescribeTableRequest - The describe table request. + modified["aws_access_key_id"] = f"AKID_{count}" + modified["aws_secret_access_key"] = f"SECRET_{count}" + modified["aws_session_token"] = f"TOKEN_{count}" + expires_at_millis = int( + (time.time() + self.credential_expires_in_seconds) * 1000 + ) + modified["expires_at_millis"] = str(expires_at_millis) - Returns - ------- - DescribeTableResponse - Response with location and storage_options - """ - table_id = request.id + return modified + def create_empty_table( + self, request: CreateEmptyTableRequest + ) -> CreateEmptyTableResponse: with self.lock: - self.call_count += 1 - count = self.call_count + self.create_call_count += 1 + count = self.create_call_count - table_key = "/".join(table_id) - if table_key not in self.tables: - raise ValueError(f"Table not found: {table_key}") - - location = self.tables[table_key] + response = self.inner.create_empty_table(request) + response.storage_options = self._modify_storage_options( + response.storage_options, count + ) - # Create storage options with incrementing credentials - storage_options = copy.deepcopy(self.base_storage_options) + return response - # Add incrementing credentials (similar to Rust MockStorageOptionsProvider) - storage_options["aws_access_key_id"] = f"AKID_{count}" - storage_options["aws_secret_access_key"] = f"SECRET_{count}" - storage_options["aws_session_token"] = f"TOKEN_{count}" + def describe_table(self, request: DescribeTableRequest) -> DescribeTableResponse: + with self.lock: + self.describe_call_count += 1 + count = self.describe_call_count - # Add expiration timestamp (current time + expires_in_seconds) - expires_at_millis = int( - (time.time() + self.credential_expires_in_seconds) * 1000 + response = self.inner.describe_table(request) + response.storage_options = self._modify_storage_options( + response.storage_options, count ) - storage_options["expires_at_millis"] = str(expires_at_millis) - return DescribeTableResponse( - location=location, - storage_options=storage_options, - ) + return response @pytest.mark.integration def test_namespace_open_dataset(s3_bucket: str): - """ - Test opening a dataset through a namespace with credential tracking. - - This test verifies that: - 1. We can create a dataset and register it with a namespace - 2. We can open the dataset through the namespace - 3. The namespace's describe_table method is called to fetch credentials - """ + """Test creating and opening datasets through namespace with credential tracking.""" storage_options = copy.deepcopy(CONFIG) - # Create a test dataset directly on S3 + namespace = TrackingNamespace( + bucket_name=s3_bucket, + storage_options=storage_options, + credential_expires_in_seconds=3600, + ) + table1 = pa.Table.from_pylist([{"a": 1, "b": 2}, {"a": 10, "b": 20}]) table_name = uuid.uuid4().hex - table_uri = f"s3://{s3_bucket}/{table_name}.lance" + table_id = ["test_ns", table_name] - # Write dataset directly to S3 - ds = lance.write_dataset(table1, table_uri, storage_options=storage_options) - assert len(ds.versions()) == 1 - assert ds.count_rows() == 2 + assert namespace.get_create_call_count() == 0 + assert namespace.get_describe_call_count() == 0 - # Create mock namespace and register the table - namespace = MockLanceNamespace( - bucket_name=s3_bucket, - storage_options=storage_options, - credential_expires_in_seconds=60, + ds = lance.write_dataset( + table1, namespace=namespace, table_id=table_id, mode="create" ) - namespace.register_table([table_name], table_uri) - - # Open dataset through namespace (ignoring storage options from namespace) - # This should call describe_table once - assert namespace.get_call_count() == 0 + assert len(ds.versions()) == 1 + assert ds.count_rows() == 2 + assert namespace.get_create_call_count() == 1 ds_from_namespace = lance.dataset( namespace=namespace, - table_id=[table_name], + table_id=table_id, + storage_options=storage_options, ignore_namespace_table_storage_options=True, ) - # Verify describe_table was called once during open - assert namespace.get_call_count() == 1 - - # Verify we can read the data + assert namespace.get_describe_call_count() == 1 assert ds_from_namespace.count_rows() == 2 result = ds_from_namespace.to_table() assert result == table1 + # Test credential caching + call_count_before_reads = namespace.get_describe_call_count() + for _ in range(3): + assert ds_from_namespace.count_rows() == 2 + assert namespace.get_describe_call_count() == call_count_before_reads + @pytest.mark.integration def test_namespace_with_refresh(s3_bucket: str): + """Test credential refresh when credentials expire.""" storage_options = copy.deepcopy(CONFIG) - # Create a test dataset + namespace = TrackingNamespace( + bucket_name=s3_bucket, + storage_options=storage_options, + credential_expires_in_seconds=3, + ) + table1 = pa.Table.from_pylist([{"a": 1, "b": 2}, {"a": 10, "b": 20}]) table_name = uuid.uuid4().hex - table_uri = f"s3://{s3_bucket}/{table_name}.lance" + table_id = ["test_ns", table_name] - ds = lance.write_dataset(table1, table_uri, storage_options=storage_options) - assert ds.count_rows() == 2 + assert namespace.get_create_call_count() == 0 + assert namespace.get_describe_call_count() == 0 - # Create mock namespace with very short expiration (2 seconds) - # to simulate credentials that need frequent refresh - namespace = MockLanceNamespace( - bucket_name=s3_bucket, - storage_options=storage_options, - credential_expires_in_seconds=2, # Short expiration for testing + ds = lance.write_dataset( + table1, namespace=namespace, table_id=table_id, mode="create" ) - namespace.register_table([table_name], table_uri) - - assert namespace.get_call_count() == 0 + assert ds.count_rows() == 2 + assert namespace.get_create_call_count() == 1 - # Open dataset with short refresh offset - # Storage options from namespace are used by default ds_from_namespace = lance.dataset( namespace=namespace, - table_id=[table_name], + table_id=table_id, s3_credentials_refresh_offset_seconds=1, ) - initial_call_count = namespace.get_call_count() + initial_call_count = namespace.get_describe_call_count() assert initial_call_count == 1 - - # Verify we can read the data assert ds_from_namespace.count_rows() == 2 result = ds_from_namespace.to_table() assert result == table1 - # Record call count after initial reads - call_count_after_initial_reads = namespace.get_call_count() + call_count_after_initial_reads = namespace.get_describe_call_count() - # Wait for credentials to expire - time.sleep(3) + time.sleep(5) - # Perform another read operation after expiration - # This should trigger a credential refresh since credentials have expired assert ds_from_namespace.count_rows() == 2 result2 = ds_from_namespace.to_table() assert result2 == table1 - final_call_count = namespace.get_call_count() + final_call_count = namespace.get_describe_call_count() assert final_call_count == call_count_after_initial_reads + 1 @pytest.mark.integration def test_namespace_append_through_namespace(s3_bucket: str): - """ - Test appending to a dataset opened through a namespace. - - This verifies that write operations work correctly with namespace-managed - credentials. - """ + """Test appending to dataset through namespace.""" storage_options = copy.deepcopy(CONFIG) - # Create initial dataset + namespace = TrackingNamespace( + bucket_name=s3_bucket, + storage_options=storage_options, + credential_expires_in_seconds=3600, + ) + table1 = pa.Table.from_pylist([{"a": 1, "b": 2}]) table_name = uuid.uuid4().hex - table_uri = f"s3://{s3_bucket}/{table_name}.lance" + table_id = ["test_ns", table_name] - ds = lance.write_dataset(table1, table_uri, storage_options=storage_options) + assert namespace.get_create_call_count() == 0 + assert namespace.get_describe_call_count() == 0 + + ds = lance.write_dataset( + table1, namespace=namespace, table_id=table_id, mode="create" + ) assert ds.count_rows() == 1 assert len(ds.versions()) == 1 + assert namespace.get_create_call_count() == 1 + initial_describe_count = namespace.get_describe_call_count() - # Create namespace and open dataset through it - namespace = MockLanceNamespace( - bucket_name=s3_bucket, - storage_options=storage_options, - credential_expires_in_seconds=60, + table2 = pa.Table.from_pylist([{"a": 10, "b": 20}]) + ds = lance.write_dataset( + table2, namespace=namespace, table_id=table_id, mode="append" ) - namespace.register_table([table_name], table_uri) + assert ds.count_rows() == 2 + assert len(ds.versions()) == 2 + assert namespace.get_create_call_count() == 1 + assert namespace.get_describe_call_count() == initial_describe_count + 1 - # Open through namespace ds_from_namespace = lance.dataset( namespace=namespace, - table_id=[table_name], + table_id=table_id, + storage_options=storage_options, ignore_namespace_table_storage_options=True, ) - assert ds_from_namespace.count_rows() == 1 - initial_call_count = namespace.get_call_count() - assert initial_call_count == 1 + assert ds_from_namespace.count_rows() == 2 + assert len(ds_from_namespace.versions()) == 2 + assert namespace.get_describe_call_count() == initial_describe_count + 2 + + +@pytest.mark.integration +def test_namespace_write_create_mode(s3_bucket: str): + """Test writing dataset through namespace in CREATE mode.""" + storage_options = copy.deepcopy(CONFIG) + + namespace = TrackingNamespace( + bucket_name=s3_bucket, + storage_options=storage_options, + credential_expires_in_seconds=3600, + ) + + table1 = pa.Table.from_pylist([{"a": 1, "b": 2}, {"a": 10, "b": 20}]) + table_name = uuid.uuid4().hex + + assert namespace.get_create_call_count() == 0 + assert namespace.get_describe_call_count() == 0 + + ds = lance.write_dataset( + table1, + namespace=namespace, + table_id=["test_ns", table_name], + mode="create", + ) + + assert namespace.get_create_call_count() == 1 + assert ds.count_rows() == 2 + assert len(ds.versions()) == 1 + result = ds.to_table() + assert result == table1 + + +@pytest.mark.integration +def test_namespace_write_append_mode(s3_bucket: str): + """Test writing dataset through namespace in APPEND mode.""" + storage_options = copy.deepcopy(CONFIG) + + namespace = TrackingNamespace( + bucket_name=s3_bucket, + storage_options=storage_options, + credential_expires_in_seconds=3600, + ) + + table1 = pa.Table.from_pylist([{"a": 1, "b": 2}]) + table_name = uuid.uuid4().hex + table_id = ["test_ns", table_name] + + assert namespace.get_create_call_count() == 0 + assert namespace.get_describe_call_count() == 0 + + ds = lance.write_dataset( + table1, namespace=namespace, table_id=table_id, mode="create" + ) + assert ds.count_rows() == 1 + assert namespace.get_create_call_count() == 1 + assert namespace.get_describe_call_count() == 0 - # Append more data using the URI directly (not through namespace) table2 = pa.Table.from_pylist([{"a": 10, "b": 20}]) + ds = lance.write_dataset( - table2, table_uri, mode="append", storage_options=storage_options + table2, + namespace=namespace, + table_id=table_id, + mode="append", ) + + assert namespace.get_create_call_count() == 1 + describe_count_after_append = namespace.get_describe_call_count() + assert describe_count_after_append == 1 assert ds.count_rows() == 2 assert len(ds.versions()) == 2 - # Re-open through namespace to see updated data + call_count_before_reads = namespace.get_describe_call_count() + for _ in range(3): + assert ds.count_rows() == 2 + assert namespace.get_describe_call_count() == call_count_before_reads + + +@pytest.mark.integration +def test_namespace_write_overwrite_mode(s3_bucket: str): + """Test writing dataset through namespace in OVERWRITE mode.""" + storage_options = copy.deepcopy(CONFIG) + + namespace = TrackingNamespace( + bucket_name=s3_bucket, + storage_options=storage_options, + credential_expires_in_seconds=3600, + ) + + table1 = pa.Table.from_pylist([{"a": 1, "b": 2}]) + table_name = uuid.uuid4().hex + table_id = ["test_ns", table_name] + + assert namespace.get_create_call_count() == 0 + assert namespace.get_describe_call_count() == 0 + + ds = lance.write_dataset( + table1, namespace=namespace, table_id=table_id, mode="create" + ) + assert ds.count_rows() == 1 + assert namespace.get_create_call_count() == 1 + assert namespace.get_describe_call_count() == 0 + + table2 = pa.Table.from_pylist([{"a": 10, "b": 20}, {"a": 100, "b": 200}]) + + ds = lance.write_dataset( + table2, + namespace=namespace, + table_id=table_id, + mode="overwrite", + ) + + assert namespace.get_create_call_count() == 1 + describe_count_after_overwrite = namespace.get_describe_call_count() + assert describe_count_after_overwrite == 1 + assert ds.count_rows() == 2 + assert len(ds.versions()) == 2 + result = ds.to_table() + assert result == table2 + + call_count_before_reads = namespace.get_describe_call_count() + for _ in range(3): + assert ds.count_rows() == 2 + assert namespace.get_describe_call_count() == call_count_before_reads + + +@pytest.mark.integration +def test_namespace_distributed_write(s3_bucket: str): + """Test distributed write pattern through namespace.""" + storage_options = copy.deepcopy(CONFIG) + + namespace = TrackingNamespace( + bucket_name=s3_bucket, + storage_options=storage_options, + credential_expires_in_seconds=3600, + ) + + table_name = uuid.uuid4().hex + table_id = ["test_ns", table_name] + + from lance_namespace import CreateEmptyTableRequest + + request = CreateEmptyTableRequest(id=table_id, location=None, properties=None) + response = namespace.create_empty_table(request) + + assert namespace.get_create_call_count() == 1 + assert namespace.get_describe_call_count() == 0 + + table_uri = response.location + assert table_uri is not None + + from lance.namespace import LanceNamespaceStorageOptionsProvider + + namespace_storage_options = response.storage_options + assert namespace_storage_options is not None + + storage_options_provider = LanceNamespaceStorageOptionsProvider( + namespace=namespace, table_id=table_id + ) + + merged_options = dict(storage_options) + merged_options.update(namespace_storage_options) + + from lance.fragment import write_fragments + + fragment1_data = pa.Table.from_pylist([{"a": 1, "b": 2}, {"a": 3, "b": 4}]) + fragment1 = write_fragments( + fragment1_data, + table_uri, + storage_options=merged_options, + storage_options_provider=storage_options_provider, + ) + + fragment2_data = pa.Table.from_pylist([{"a": 10, "b": 20}, {"a": 30, "b": 40}]) + fragment2 = write_fragments( + fragment2_data, + table_uri, + storage_options=merged_options, + storage_options_provider=storage_options_provider, + ) + + fragment3_data = pa.Table.from_pylist([{"a": 100, "b": 200}]) + fragment3 = write_fragments( + fragment3_data, + table_uri, + storage_options=merged_options, + storage_options_provider=storage_options_provider, + ) + + all_fragments = fragment1 + fragment2 + fragment3 + + operation = lance.LanceOperation.Overwrite(fragment1_data.schema, all_fragments) + + ds = lance.LanceDataset.commit( + table_uri, + operation, + storage_options=merged_options, + storage_options_provider=storage_options_provider, + ) + + assert ds.count_rows() == 5 + assert len(ds.versions()) == 1 + + result = ds.to_table().sort_by("a") + expected = pa.Table.from_pylist( + [ + {"a": 1, "b": 2}, + {"a": 3, "b": 4}, + {"a": 10, "b": 20}, + {"a": 30, "b": 40}, + {"a": 100, "b": 200}, + ] + ) + assert result == expected + ds_from_namespace = lance.dataset( namespace=namespace, - table_id=[table_name], - ignore_namespace_table_storage_options=True, + table_id=table_id, ) + assert ds_from_namespace.count_rows() == 5 - assert ds_from_namespace.count_rows() == 2 - assert len(ds_from_namespace.versions()) == 2 - # Describe_table should have been called again - assert namespace.get_call_count() == initial_call_count + 1 +@pytest.mark.integration +def test_file_writer_with_storage_options_provider(s3_bucket: str): + """Test LanceFileWriter with storage_options_provider and credential refresh.""" + from lance import LanceNamespaceStorageOptionsProvider + from lance.file import LanceFileReader, LanceFileWriter + + storage_options = copy.deepcopy(CONFIG) + + namespace = TrackingNamespace( + bucket_name=s3_bucket, + storage_options=storage_options, + credential_expires_in_seconds=3, + ) + + table1 = pa.Table.from_pylist([{"a": 1, "b": 2}, {"a": 10, "b": 20}]) + table_name = uuid.uuid4().hex + table_id = ["test_ns", table_name] + + assert namespace.get_create_call_count() == 0 + assert namespace.get_describe_call_count() == 0 + + ds = lance.write_dataset( + table1, namespace=namespace, table_id=table_id, mode="create" + ) + assert ds.count_rows() == 2 + assert namespace.get_create_call_count() == 1 + + describe_response = namespace.describe_table( + DescribeTableRequest(id=table_id, version=None) + ) + namespace_storage_options = describe_response.storage_options + + provider = LanceNamespaceStorageOptionsProvider( + namespace=namespace, table_id=table_id + ) + + initial_describe_count = namespace.get_describe_call_count() + + file_uri = f"s3://{s3_bucket}/{table_name}_file_test.lance" + schema = pa.schema([pa.field("x", pa.int64()), pa.field("y", pa.int64())]) + + writer = LanceFileWriter( + file_uri, + schema=schema, + storage_options=namespace_storage_options, + storage_options_provider=provider, + ) + + batch = pa.RecordBatch.from_pydict({"x": [1, 2, 3], "y": [4, 5, 6]}, schema=schema) + writer.write_batch(batch) + + batch2 = pa.RecordBatch.from_pydict( + {"x": [7, 8, 9], "y": [10, 11, 12]}, schema=schema + ) + writer.write_batch(batch2) + writer.close() + + describe_count_after_write = namespace.get_describe_call_count() + assert describe_count_after_write == initial_describe_count + + reader = LanceFileReader(file_uri, storage_options=namespace_storage_options) + result = reader.read_all(batch_size=1024) + result_table = result.to_table() + assert result_table.num_rows == 6 + assert result_table.schema == schema + + expected_table = pa.table( + {"x": [1, 2, 3, 7, 8, 9], "y": [4, 5, 6, 10, 11, 12]}, schema=schema + ) + assert result_table == expected_table + + time.sleep(5) + + file_uri2 = f"s3://{s3_bucket}/{table_name}_file_test2.lance" + writer2 = LanceFileWriter( + file_uri2, + schema=schema, + storage_options=namespace_storage_options, + storage_options_provider=provider, + ) + + batch3 = pa.RecordBatch.from_pydict( + {"x": [100, 200], "y": [300, 400]}, schema=schema + ) + writer2.write_batch(batch3) + writer2.close() + + final_describe_count = namespace.get_describe_call_count() + assert final_describe_count == describe_count_after_write + 1 + + reader2 = LanceFileReader(file_uri2, storage_options=namespace_storage_options) + result2 = reader2.read_all(batch_size=1024) + result_table2 = result2.to_table() + assert result_table2.num_rows == 2 + expected_table2 = pa.table({"x": [100, 200], "y": [300, 400]}, schema=schema) + assert result_table2 == expected_table2 diff --git a/python/src/dataset.rs b/python/src/dataset.rs index 7f80766dfba..91108192c47 100644 --- a/python/src/dataset.rs +++ b/python/src/dataset.rs @@ -2078,13 +2078,14 @@ impl Dataset { #[allow(clippy::too_many_arguments)] #[staticmethod] - #[pyo3(signature = (dest, operation, read_version = None, commit_lock = None, storage_options = None, enable_v2_manifest_paths = None, detached = None, max_retries = None, commit_message = None))] + #[pyo3(signature = (dest, operation, read_version = None, commit_lock = None, storage_options = None, storage_options_provider = None, enable_v2_manifest_paths = None, detached = None, max_retries = None, commit_message = None))] fn commit( dest: PyWriteDest, operation: PyLance, read_version: Option, commit_lock: Option<&Bound<'_, PyAny>>, storage_options: Option>, + storage_options_provider: Option, enable_v2_manifest_paths: Option, detached: Option, max_retries: Option, @@ -2104,6 +2105,7 @@ impl Dataset { PyLance(transaction), commit_lock, storage_options, + storage_options_provider, enable_v2_manifest_paths, detached, max_retries, @@ -2112,23 +2114,36 @@ impl Dataset { #[allow(clippy::too_many_arguments)] #[staticmethod] - #[pyo3(signature = (dest, transaction, commit_lock = None, storage_options = None, enable_v2_manifest_paths = None, detached = None, max_retries = None))] + #[pyo3(signature = (dest, transaction, commit_lock = None, storage_options = None, storage_options_provider = None, enable_v2_manifest_paths = None, detached = None, max_retries = None))] fn commit_transaction( dest: PyWriteDest, transaction: PyLance, commit_lock: Option<&Bound<'_, PyAny>>, storage_options: Option>, + storage_options_provider: Option, enable_v2_manifest_paths: Option, detached: Option, max_retries: Option, ) -> PyResult { - let object_store_params = - storage_options - .as_ref() - .map(|storage_options| ObjectStoreParams { - storage_options: Some(storage_options.clone()), - ..Default::default() - }); + let provider = storage_options_provider.and_then(|py_obj| { + crate::storage_options::PyStorageOptionsProvider::new(py_obj) + .ok() + .map(|py_provider| { + Arc::new( + crate::storage_options::PyStorageOptionsProviderWrapper::new(py_provider), + ) as Arc + }) + }); + + let object_store_params = if storage_options.is_some() || provider.is_some() { + Some(ObjectStoreParams { + storage_options: storage_options.clone(), + storage_options_provider: provider, + ..Default::default() + }) + } else { + None + }; let commit_handler = commit_lock .as_ref() @@ -2166,24 +2181,38 @@ impl Dataset { }) } + #[allow(clippy::too_many_arguments)] #[staticmethod] - #[pyo3(signature = (dest, transactions, commit_lock = None, storage_options = None, enable_v2_manifest_paths = None, detached = None, max_retries = None))] + #[pyo3(signature = (dest, transactions, commit_lock = None, storage_options = None, storage_options_provider = None, enable_v2_manifest_paths = None, detached = None, max_retries = None))] fn commit_batch( dest: PyWriteDest, transactions: Vec>, commit_lock: Option<&Bound<'_, PyAny>>, storage_options: Option>, + storage_options_provider: Option, enable_v2_manifest_paths: Option, detached: Option, max_retries: Option, ) -> PyResult<(Self, PyLance)> { - let object_store_params = - storage_options - .as_ref() - .map(|storage_options| ObjectStoreParams { - storage_options: Some(storage_options.clone()), - ..Default::default() - }); + let provider = storage_options_provider.and_then(|py_obj| { + crate::storage_options::PyStorageOptionsProvider::new(py_obj) + .ok() + .map(|py_provider| { + Arc::new( + crate::storage_options::PyStorageOptionsProviderWrapper::new(py_provider), + ) as Arc + }) + }); + + let object_store_params = if storage_options.is_some() || provider.is_some() { + Some(ObjectStoreParams { + storage_options: storage_options.clone(), + storage_options_provider: provider, + ..Default::default() + }) + } else { + None + }; let commit_handler = commit_lock .map(|commit_lock| { @@ -2891,11 +2920,25 @@ pub fn get_write_params(options: &Bound<'_, PyDict>) -> PyResult>(options, "storage_options")? - { + let storage_options = get_dict_opt::>(options, "storage_options")?; + let storage_options_provider = + get_dict_opt::(options, "storage_options_provider")?.and_then(|py_obj| { + crate::storage_options::PyStorageOptionsProvider::new(py_obj) + .ok() + .map(|py_provider| { + Arc::new( + crate::storage_options::PyStorageOptionsProviderWrapper::new( + py_provider, + ), + ) + as Arc + }) + }); + + if storage_options.is_some() || storage_options_provider.is_some() { p.store_params = Some(ObjectStoreParams { - storage_options: Some(storage_options), + storage_options, + storage_options_provider, ..Default::default() }); } diff --git a/python/src/file.rs b/python/src/file.rs index 4dc596166a6..2dc2a31489f 100644 --- a/python/src/file.rs +++ b/python/src/file.rs @@ -232,17 +232,23 @@ pub struct LanceFileWriter { } impl LanceFileWriter { + #[allow(clippy::too_many_arguments)] async fn open( uri_or_path: String, schema: Option>, data_cache_bytes: Option, version: Option, storage_options: Option>, + storage_options_provider: Option>, keep_original_array: Option, max_page_bytes: Option, ) -> PyResult { - let (object_store, path) = - object_store_from_uri_or_path(uri_or_path, storage_options).await?; + let (object_store, path) = object_store_from_uri_or_path_with_provider( + uri_or_path, + storage_options, + storage_options_provider, + ) + .await?; Self::open_with_store( object_store, path, @@ -290,16 +296,23 @@ impl LanceFileWriter { #[pymethods] impl LanceFileWriter { #[new] - #[pyo3(signature=(path, schema=None, data_cache_bytes=None, version=None, storage_options=None, keep_original_array=None, max_page_bytes=None))] + #[pyo3(signature=(path, schema=None, data_cache_bytes=None, version=None, storage_options=None, storage_options_provider=None, keep_original_array=None, max_page_bytes=None))] + #[allow(clippy::too_many_arguments)] pub fn new( path: String, schema: Option>, data_cache_bytes: Option, version: Option, storage_options: Option>, + storage_options_provider: Option, keep_original_array: Option, max_page_bytes: Option, ) -> PyResult { + // Convert Python StorageOptionsProvider to Rust trait object + let provider = storage_options_provider + .map(crate::storage_options::py_object_to_storage_options_provider) + .transpose()?; + rt().block_on( None, Self::open( @@ -308,6 +321,7 @@ impl LanceFileWriter { data_cache_bytes, version, storage_options, + provider, keep_original_array, max_page_bytes, ), @@ -379,6 +393,14 @@ pub async fn object_store_from_uri_or_path_no_options( pub async fn object_store_from_uri_or_path( uri_or_path: impl AsRef, storage_options: Option>, +) -> PyResult<(Arc, Path)> { + object_store_from_uri_or_path_with_provider(uri_or_path, storage_options, None).await +} + +pub async fn object_store_from_uri_or_path_with_provider( + uri_or_path: impl AsRef, + storage_options: Option>, + storage_options_provider: Option>, ) -> PyResult<(Arc, Path)> { if let Ok(mut url) = Url::parse(uri_or_path.as_ref()) { if url.scheme().len() > 1 { @@ -390,12 +412,15 @@ pub async fn object_store_from_uri_or_path( let object_store_registry = Arc::new(lance::io::ObjectStoreRegistry::default()); let object_store_params = - storage_options - .as_ref() - .map(|storage_options| ObjectStoreParams { - storage_options: Some(storage_options.clone()), + if storage_options.is_some() || storage_options_provider.is_some() { + Some(ObjectStoreParams { + storage_options: storage_options.clone(), + storage_options_provider, ..Default::default() - }); + }) + } else { + None + }; let (object_store, dir_path) = ObjectStore::from_uri_and_params( object_store_registry, diff --git a/rust/lance-namespace-impls/src/dir.rs b/rust/lance-namespace-impls/src/dir.rs index 3b9e1807172..fd5a63a0848 100644 --- a/rust/lance-namespace-impls/src/dir.rs +++ b/rust/lance-namespace-impls/src/dir.rs @@ -2639,4 +2639,115 @@ mod tests { let err_msg = result.unwrap_err().to_string(); assert!(err_msg.contains("Path traversal is not allowed")); } + + #[tokio::test] + async fn test_namespace_write() { + use arrow::array::Int32Array; + use arrow::datatypes::{DataType, Field as ArrowField, Schema as ArrowSchema}; + use arrow::record_batch::{RecordBatch, RecordBatchIterator}; + use lance::dataset::{Dataset, WriteMode, WriteParams}; + use lance_namespace::LanceNamespace; + + let (namespace, _temp_dir) = create_test_namespace().await; + let namespace = Arc::new(namespace) as Arc; + + // Use child namespace instead of root + let table_id = vec!["test_ns".to_string(), "test_table".to_string()]; + let schema = Arc::new(ArrowSchema::new(vec![ + ArrowField::new("a", DataType::Int32, false), + ArrowField::new("b", DataType::Int32, false), + ])); + + // Test 1: CREATE mode + let data1 = RecordBatch::try_new( + schema.clone(), + vec![ + Arc::new(Int32Array::from(vec![1, 2, 3])), + Arc::new(Int32Array::from(vec![10, 20, 30])), + ], + ) + .unwrap(); + + let reader1 = RecordBatchIterator::new(vec![data1].into_iter().map(Ok), schema.clone()); + let dataset = Dataset::write_into_namespace( + reader1, + namespace.clone(), + table_id.clone(), + None, + false, + ) + .await + .unwrap(); + + assert_eq!(dataset.count_rows(None).await.unwrap(), 3); + assert_eq!(dataset.version().version, 1); + + // Test 2: APPEND mode + let data2 = RecordBatch::try_new( + schema.clone(), + vec![ + Arc::new(Int32Array::from(vec![4, 5])), + Arc::new(Int32Array::from(vec![40, 50])), + ], + ) + .unwrap(); + + let params_append = WriteParams { + mode: WriteMode::Append, + ..Default::default() + }; + + let reader2 = RecordBatchIterator::new(vec![data2].into_iter().map(Ok), schema.clone()); + let dataset = Dataset::write_into_namespace( + reader2, + namespace.clone(), + table_id.clone(), + Some(params_append), + false, + ) + .await + .unwrap(); + + assert_eq!(dataset.count_rows(None).await.unwrap(), 5); + assert_eq!(dataset.version().version, 2); + + // Test 3: OVERWRITE mode + let data3 = RecordBatch::try_new( + schema.clone(), + vec![ + Arc::new(Int32Array::from(vec![100, 200])), + Arc::new(Int32Array::from(vec![1000, 2000])), + ], + ) + .unwrap(); + + let params_overwrite = WriteParams { + mode: WriteMode::Overwrite, + ..Default::default() + }; + + let reader3 = RecordBatchIterator::new(vec![data3].into_iter().map(Ok), schema.clone()); + let dataset = Dataset::write_into_namespace( + reader3, + namespace.clone(), + table_id.clone(), + Some(params_overwrite), + false, + ) + .await + .unwrap(); + + assert_eq!(dataset.count_rows(None).await.unwrap(), 2); + assert_eq!(dataset.version().version, 3); + + // Verify old data was replaced + let result = dataset.scan().try_into_batch().await.unwrap(); + let a_col = result + .column_by_name("a") + .unwrap() + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(a_col.values(), &[100, 200]); + } } diff --git a/rust/lance-namespace-impls/src/dir/manifest.rs b/rust/lance-namespace-impls/src/dir/manifest.rs index 0944c629d61..d95e8118f6f 100644 --- a/rust/lance-namespace-impls/src/dir/manifest.rs +++ b/rust/lance-namespace-impls/src/dir/manifest.rs @@ -13,7 +13,7 @@ use async_trait::async_trait; use bytes::Bytes; use futures::stream::StreamExt; use lance::dataset::optimize::{compact_files, CompactionOptions}; -use lance::dataset::WriteParams; +use lance::dataset::{builder::DatasetBuilder, WriteParams}; use lance::session::Session; use lance::{dataset::scanner::Scanner, Dataset}; use lance_core::{box_error, Error, Result}; @@ -21,7 +21,7 @@ use lance_index::optimize::OptimizeOptions; use lance_index::scalar::{BuiltinIndexType, ScalarIndexParams}; use lance_index::traits::DatasetIndexExt; use lance_index::IndexType; -use lance_io::object_store::ObjectStore; +use lance_io::object_store::{ObjectStore, ObjectStoreParams}; use lance_namespace::models::{ CreateEmptyTableRequest, CreateEmptyTableResponse, CreateNamespaceRequest, CreateNamespaceResponse, CreateTableRequest, CreateTableResponse, DeregisterTableRequest, @@ -256,7 +256,7 @@ impl ManifestNamespace { inline_optimization_enabled: bool, ) -> Result { let manifest_dataset = - Self::create_or_get_manifest(&root, object_store.clone(), session.clone()).await?; + Self::create_or_get_manifest(&root, &storage_options, session.clone()).await?; Ok(Self { root, @@ -932,11 +932,21 @@ impl ManifestNamespace { /// Create or get the manifest dataset async fn create_or_get_manifest( root: &str, - _object_store: Arc, - _session: Option>, + storage_options: &Option>, + session: Option>, ) -> Result { let manifest_path = format!("{}/{}", root, MANIFEST_TABLE_NAME); - let dataset_result = Dataset::open(&manifest_path).await; + let mut builder = DatasetBuilder::from_uri(&manifest_path); + + if let Some(sess) = session.clone() { + builder = builder.with_session(sess); + } + + if let Some(opts) = storage_options { + builder = builder.with_storage_options(opts.clone()); + } + + let dataset_result = builder.load().await; if let Ok(dataset) = dataset_result { Ok(DatasetConsistencyWrapper::new(dataset)) @@ -945,7 +955,16 @@ impl ManifestNamespace { let schema = Self::manifest_schema(); let empty_batch = RecordBatch::new_empty(schema.clone()); let reader = RecordBatchIterator::new(vec![Ok(empty_batch)], schema.clone()); - let write_params = WriteParams::default(); + + let write_params = WriteParams { + session, + store_params: storage_options.as_ref().map(|opts| ObjectStoreParams { + storage_options: Some(opts.clone()), + ..Default::default() + }), + ..Default::default() + }; + let dataset = Dataset::write(Box::new(reader), &manifest_path, Some(write_params)) .await .map_err(|e| Error::IO { diff --git a/rust/lance-namespace-impls/src/rest.rs b/rust/lance-namespace-impls/src/rest.rs index ab2fddef8cb..1f7ee341d26 100644 --- a/rust/lance-namespace-impls/src/rest.rs +++ b/rust/lance-namespace-impls/src/rest.rs @@ -294,6 +294,7 @@ fn convert_api_error(err: lance_namespace::apis::Error) - /// # Ok(()) /// # } /// ``` +#[derive(Clone)] pub struct RestNamespace { delimiter: String, reqwest_config: Configuration, diff --git a/rust/lance-namespace-impls/src/rest_adapter.rs b/rust/lance-namespace-impls/src/rest_adapter.rs index 7aec0b912bd..2f454d8cd2a 100644 --- a/rust/lance-namespace-impls/src/rest_adapter.rs +++ b/rust/lance-namespace-impls/src/rest_adapter.rs @@ -1813,5 +1813,116 @@ mod tests { describe_response.location ); } + + #[tokio::test] + async fn test_namespace_write() { + use arrow::array::Int32Array; + use arrow::datatypes::{DataType, Field as ArrowField, Schema as ArrowSchema}; + use arrow::record_batch::{RecordBatch, RecordBatchIterator}; + use lance::dataset::{Dataset, WriteMode, WriteParams}; + use lance_namespace::LanceNamespace; + + let fixture = RestServerFixture::new(4024).await; + let namespace = Arc::new(fixture.namespace.clone()) as Arc; + + // Use child namespace instead of root + let table_id = vec!["test_ns".to_string(), "test_table".to_string()]; + let schema = Arc::new(ArrowSchema::new(vec![ + ArrowField::new("a", DataType::Int32, false), + ArrowField::new("b", DataType::Int32, false), + ])); + + // Test 1: CREATE mode + let data1 = RecordBatch::try_new( + schema.clone(), + vec![ + Arc::new(Int32Array::from(vec![1, 2, 3])), + Arc::new(Int32Array::from(vec![10, 20, 30])), + ], + ) + .unwrap(); + + let reader1 = RecordBatchIterator::new(vec![data1].into_iter().map(Ok), schema.clone()); + let dataset = Dataset::write_into_namespace( + reader1, + namespace.clone(), + table_id.clone(), + None, + false, + ) + .await + .unwrap(); + + assert_eq!(dataset.count_rows(None).await.unwrap(), 3); + assert_eq!(dataset.version().version, 1); + + // Test 2: APPEND mode + let data2 = RecordBatch::try_new( + schema.clone(), + vec![ + Arc::new(Int32Array::from(vec![4, 5])), + Arc::new(Int32Array::from(vec![40, 50])), + ], + ) + .unwrap(); + + let params_append = WriteParams { + mode: WriteMode::Append, + ..Default::default() + }; + + let reader2 = RecordBatchIterator::new(vec![data2].into_iter().map(Ok), schema.clone()); + let dataset = Dataset::write_into_namespace( + reader2, + namespace.clone(), + table_id.clone(), + Some(params_append), + false, + ) + .await + .unwrap(); + + assert_eq!(dataset.count_rows(None).await.unwrap(), 5); + assert_eq!(dataset.version().version, 2); + + // Test 3: OVERWRITE mode + let data3 = RecordBatch::try_new( + schema.clone(), + vec![ + Arc::new(Int32Array::from(vec![100, 200])), + Arc::new(Int32Array::from(vec![1000, 2000])), + ], + ) + .unwrap(); + + let params_overwrite = WriteParams { + mode: WriteMode::Overwrite, + ..Default::default() + }; + + let reader3 = RecordBatchIterator::new(vec![data3].into_iter().map(Ok), schema.clone()); + let dataset = Dataset::write_into_namespace( + reader3, + namespace.clone(), + table_id.clone(), + Some(params_overwrite), + false, + ) + .await + .unwrap(); + + assert_eq!(dataset.count_rows(None).await.unwrap(), 2); + assert_eq!(dataset.version().version, 3); + + // Verify old data was replaced + let result = dataset.scan().try_into_batch().await.unwrap(); + let a_col = result + .column_by_name("a") + .unwrap() + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(a_col.values(), &[100, 200]); + } } } diff --git a/rust/lance/src/dataset.rs b/rust/lance/src/dataset.rs index 4329ffdf003..89e9af18ce7 100644 --- a/rust/lance/src/dataset.rs +++ b/rust/lance/src/dataset.rs @@ -31,8 +31,11 @@ use lance_file::datatypes::populate_schema_dictionary; use lance_file::reader::FileReaderOptions; use lance_file::version::LanceFileVersion; use lance_index::DatasetIndexExt; -use lance_io::object_store::{ObjectStore, ObjectStoreParams}; +use lance_io::object_store::{ + LanceNamespaceStorageOptionsProvider, ObjectStore, ObjectStoreParams, +}; use lance_io::utils::{read_last_block, read_message, read_metadata_offset, read_struct}; +use lance_namespace::LanceNamespace; use lance_table::format::{ pb, DataFile, DataStorageFormat, DeletionFile, Fragment, IndexMetadata, Manifest, }; @@ -104,6 +107,7 @@ pub use blob::BlobFile; use hash_joiner::HashJoiner; use lance_core::box_error; pub use lance_core::ROW_ID; +use lance_namespace::models::{CreateEmptyTableRequest, DescribeTableRequest}; use lance_table::feature_flags::{apply_feature_flags, can_read_dataset}; pub use schema_evolution::{ BatchInfo, BatchUDF, ColumnAlteration, NewColumnTransform, UDFCheckpointStore, @@ -792,6 +796,143 @@ impl Dataset { .await } + /// Write into a namespace-managed table with automatic credential vending. + /// + /// For CREATE mode, calls create_empty_table() to initialize the table. + /// For other modes, calls describe_table() and opens dataset with namespace credentials. + /// + /// # Arguments + /// + /// * `batches` - The record batches to write + /// * `namespace` - The namespace to use for table management + /// * `table_id` - The table identifier + /// * `params` - Write parameters + /// * `ignore_namespace_table_storage_options` - If true, ignore storage options returned + /// by the namespace and only use the storage options in params. The storage options + /// provider will not be created, so credentials will not be automatically refreshed. + pub async fn write_into_namespace( + batches: impl RecordBatchReader + Send + 'static, + namespace: Arc, + table_id: Vec, + mut params: Option, + ignore_namespace_table_storage_options: bool, + ) -> Result { + let mut write_params = params.take().unwrap_or_default(); + + match write_params.mode { + WriteMode::Create => { + let request = CreateEmptyTableRequest { + id: Some(table_id.clone()), + location: None, + properties: None, + }; + let response = + namespace + .create_empty_table(request) + .await + .map_err(|e| Error::Namespace { + source: Box::new(e), + location: location!(), + })?; + + let uri = response.location.ok_or_else(|| Error::Namespace { + source: Box::new(std::io::Error::other( + "Table location not found in create_empty_table response", + )), + location: location!(), + })?; + + // Set initial credentials and provider unless ignored + if !ignore_namespace_table_storage_options { + if let Some(namespace_storage_options) = response.storage_options { + let provider = Arc::new(LanceNamespaceStorageOptionsProvider::new( + namespace, table_id, + )); + + // Merge namespace storage options with any existing options + let mut merged_options = write_params + .store_params + .as_ref() + .and_then(|p| p.storage_options.clone()) + .unwrap_or_default(); + merged_options.extend(namespace_storage_options); + + let existing_params = write_params.store_params.take().unwrap_or_default(); + write_params.store_params = Some(ObjectStoreParams { + storage_options: Some(merged_options), + storage_options_provider: Some(provider), + ..existing_params + }); + } + } + + Self::write(batches, uri.as_str(), Some(write_params)).await + } + WriteMode::Append | WriteMode::Overwrite => { + let request = DescribeTableRequest { + id: Some(table_id.clone()), + version: None, + }; + let response = + namespace + .describe_table(request) + .await + .map_err(|e| Error::Namespace { + source: Box::new(e), + location: location!(), + })?; + + let uri = response.location.ok_or_else(|| Error::Namespace { + source: Box::new(std::io::Error::other( + "Table location not found in describe_table response", + )), + location: location!(), + })?; + + // Set initial credentials and provider unless ignored + if !ignore_namespace_table_storage_options { + if let Some(namespace_storage_options) = response.storage_options { + let provider = Arc::new(LanceNamespaceStorageOptionsProvider::new( + namespace.clone(), + table_id.clone(), + )); + + // Merge namespace storage options with any existing options + let mut merged_options = write_params + .store_params + .as_ref() + .and_then(|p| p.storage_options.clone()) + .unwrap_or_default(); + merged_options.extend(namespace_storage_options); + + let existing_params = write_params.store_params.take().unwrap_or_default(); + write_params.store_params = Some(ObjectStoreParams { + storage_options: Some(merged_options), + storage_options_provider: Some(provider), + ..existing_params + }); + } + } + + // For APPEND/OVERWRITE modes, we must open the existing dataset first + // and pass it to InsertBuilder. If we pass just the URI, InsertBuilder + // assumes no dataset exists and converts the mode to CREATE. + let mut builder = DatasetBuilder::from_uri(uri.as_str()); + if let Some(ref store_params) = write_params.store_params { + if let Some(ref storage_options) = store_params.storage_options { + builder = builder.with_storage_options(storage_options.clone()); + } + if let Some(ref provider) = store_params.storage_options_provider { + builder = builder.with_storage_options_provider(provider.clone()); + } + } + let dataset = Arc::new(builder.load().await?); + + Self::write(batches, dataset, Some(write_params)).await + } + } + } + /// Append to existing [Dataset] with a stream of [RecordBatch]s /// /// Returns void result or Returns [Error] @@ -2612,6 +2753,7 @@ mod tests { use lance_arrow::FixedSizeListArrayExt; use mock_instant::thread_local::MockClock; + use crate::dataset::write::{CommitBuilder, InsertBuilder, WriteMode, WriteParams}; use arrow::array::{as_struct_array, AsArray, GenericListBuilder, GenericStringBuilder}; use arrow::compute::concat_batches; use arrow::datatypes::UInt64Type;