From f769101bbdd8062320faaab998c8319c5b75f418 Mon Sep 17 00:00:00 2001 From: HonahX Date: Sun, 10 Dec 2023 15:21:28 -0800 Subject: [PATCH 1/5] implement requirements validation --- pyiceberg/table/__init__.py | 58 +++++++++++++++++++- tests/table/test_init.py | 105 ++++++++++++++++++++++++++++++++++++ 2 files changed, 162 insertions(+), 1 deletion(-) diff --git a/pyiceberg/table/__init__.py b/pyiceberg/table/__init__.py index 4768706d1d..91f725ee02 100644 --- a/pyiceberg/table/__init__.py +++ b/pyiceberg/table/__init__.py @@ -540,18 +540,31 @@ def update_table_metadata(base_metadata: TableMetadata, updates: Tuple[TableUpda class TableRequirement(IcebergBaseModel): type: str + @abstractmethod + def validate(self, base_metadata: TableMetadata) -> None: + """Validate the requirement against the base metadata.""" + ... + class AssertCreate(TableRequirement): """The table must not already exist; used for create transactions.""" type: Literal["assert-create"] = Field(default="assert-create") + def validate(self, base_metadata: Optional[TableMetadata]) -> None: + if base_metadata is not None: + raise ValueError("Table already exists") + class AssertTableUUID(TableRequirement): """The table UUID must match the requirement's `uuid`.""" type: Literal["assert-table-uuid"] = Field(default="assert-table-uuid") - uuid: str + uuid: uuid.UUID + + def validate(self, base_metadata: TableMetadata) -> None: + if self.uuid != base_metadata.table_uuid: + raise ValueError(f"Table UUID does not match: {self.uuid} != {base_metadata.table_uuid}") class AssertRefSnapshotId(TableRequirement): @@ -564,6 +577,19 @@ class AssertRefSnapshotId(TableRequirement): ref: str snapshot_id: Optional[int] = Field(default=None, alias="snapshot-id") + def validate(self, base_metadata: TableMetadata) -> None: + snapshot_ref = base_metadata.refs.get(self.ref) + if snapshot_ref is not None: + ref_type = snapshot_ref.snapshot_ref_type + if self.snapshot_id is None: + raise ValueError(f"Requirement failed: {ref_type} {self.ref} was created concurrently") + elif self.snapshot_id != snapshot_ref.snapshot_id: + raise ValueError( + f"Requirement failed: {ref_type} {self.ref} has changed: expected id {self.snapshot_id}, found {snapshot_ref.snapshot_id}" + ) + elif self.snapshot_id is not None: + raise ValueError(f"Requirement failed: branch or tag {self.ref} is missing, expected {self.snapshot_id}") + class AssertLastAssignedFieldId(TableRequirement): """The table's last assigned column id must match the requirement's `last-assigned-field-id`.""" @@ -571,6 +597,12 @@ class AssertLastAssignedFieldId(TableRequirement): type: Literal["assert-last-assigned-field-id"] = Field(default="assert-last-assigned-field-id") last_assigned_field_id: int = Field(..., alias="last-assigned-field-id") + def validate(self, base_metadata: TableMetadata) -> None: + if base_metadata.last_column_id != self.last_assigned_field_id: + raise ValueError( + f"Requirement failed: last assigned field id has changed: expected {self.last_assigned_field_id}, found {base_metadata.last_column_id}" + ) + class AssertCurrentSchemaId(TableRequirement): """The table's current schema id must match the requirement's `current-schema-id`.""" @@ -578,6 +610,12 @@ class AssertCurrentSchemaId(TableRequirement): type: Literal["assert-current-schema-id"] = Field(default="assert-current-schema-id") current_schema_id: int = Field(..., alias="current-schema-id") + def validate(self, base_metadata: TableMetadata) -> None: + if self.current_schema_id != base_metadata.current_schema_id: + raise ValueError( + f"Requirement failed: current schema id has changed: expected {self.current_schema_id}, found {base_metadata.current_schema_id}" + ) + class AssertLastAssignedPartitionId(TableRequirement): """The table's last assigned partition id must match the requirement's `last-assigned-partition-id`.""" @@ -585,6 +623,12 @@ class AssertLastAssignedPartitionId(TableRequirement): type: Literal["assert-last-assigned-partition-id"] = Field(default="assert-last-assigned-partition-id") last_assigned_partition_id: int = Field(..., alias="last-assigned-partition-id") + def validate(self, base_metadata: TableMetadata) -> None: + if base_metadata.last_partition_id != self.last_assigned_partition_id: + raise ValueError( + f"Requirement failed: last assigned partition id has changed: expected {self.last_assigned_partition_id}, found {base_metadata.last_partition_id}" + ) + class AssertDefaultSpecId(TableRequirement): """The table's default spec id must match the requirement's `default-spec-id`.""" @@ -592,6 +636,12 @@ class AssertDefaultSpecId(TableRequirement): type: Literal["assert-default-spec-id"] = Field(default="assert-default-spec-id") default_spec_id: int = Field(..., alias="default-spec-id") + def validate(self, base_metadata: TableMetadata) -> None: + if self.default_spec_id != base_metadata.default_spec_id: + raise ValueError( + f"Requirement failed: default spec id has changed: expected {self.default_spec_id}, found {base_metadata.default_spec_id}" + ) + class AssertDefaultSortOrderId(TableRequirement): """The table's default sort order id must match the requirement's `default-sort-order-id`.""" @@ -599,6 +649,12 @@ class AssertDefaultSortOrderId(TableRequirement): type: Literal["assert-default-sort-order-id"] = Field(default="assert-default-sort-order-id") default_sort_order_id: int = Field(..., alias="default-sort-order-id") + def validate(self, base_metadata: TableMetadata) -> None: + if self.default_sort_order_id != base_metadata.default_sort_order_id: + raise ValueError( + f"Requirement failed: default sort order id has changed: expected {self.default_sort_order_id}, found {base_metadata.default_sort_order_id}" + ) + class Namespace(IcebergRootModel[List[str]]): """Reference to one or more levels of a namespace.""" diff --git a/tests/table/test_init.py b/tests/table/test_init.py index 8d13a82f3a..713784e8d7 100644 --- a/tests/table/test_init.py +++ b/tests/table/test_init.py @@ -15,6 +15,7 @@ # specific language governing permissions and limitations # under the License. # pylint:disable=redefined-outer-name +import uuid from copy import copy from typing import Dict @@ -39,6 +40,14 @@ from pyiceberg.schema import Schema from pyiceberg.table import ( AddSnapshotUpdate, + AssertCreate, + AssertCurrentSchemaId, + AssertDefaultSortOrderId, + AssertDefaultSpecId, + AssertLastAssignedFieldId, + AssertLastAssignedPartitionId, + AssertRefSnapshotId, + AssertTableUUID, SetPropertiesUpdate, SetSnapshotRefUpdate, SnapshotRef, @@ -721,3 +730,99 @@ def test_metadata_isolation_from_illegal_updates(table_v1: Table) -> None: def test_generate_snapshot_id(table_v2: Table) -> None: assert isinstance(_generate_snapshot_id(), int) assert isinstance(table_v2.new_snapshot_id(), int) + + +def test_assert_create(table_v2: Table) -> None: + AssertCreate().validate(None) + + with pytest.raises(ValueError, match="Table already exists"): + AssertCreate().validate(table_v2.metadata) + + +def test_assert_table_uuid(table_v2: Table) -> None: + base_metadata = table_v2.metadata + AssertTableUUID(uuid=base_metadata.table_uuid).validate(base_metadata) + + with pytest.raises( + ValueError, + match="Table UUID does not match: 9c12d441-03fe-4693-9a96-a0705ddf69c2 != 9c12d441-03fe-4693-9a96-a0705ddf69c1", + ): + AssertTableUUID(uuid=uuid.UUID("9c12d441-03fe-4693-9a96-a0705ddf69c2")).validate(base_metadata) + + +def test_assert_ref_snapshot_id(table_v2: Table) -> None: + base_metadata = table_v2.metadata + AssertRefSnapshotId(ref="main", snapshot_id=base_metadata.current_snapshot_id).validate(base_metadata) + + with pytest.raises( + ValueError, + match="Requirement failed: SnapshotRefType.BRANCH main was created concurrently", + ): + AssertRefSnapshotId(ref="main", snapshot_id=None).validate(base_metadata) + + with pytest.raises( + ValueError, + match="Requirement failed: SnapshotRefType.BRANCH main has changed: expected id 1, found 3055729675574597004", + ): + AssertRefSnapshotId(ref="main", snapshot_id=1).validate(base_metadata) + + with pytest.raises( + ValueError, + match="Requirement failed: branch or tag not_exist is missing, expected 1", + ): + AssertRefSnapshotId(ref="not_exist", snapshot_id=1).validate(base_metadata) + + +def test_assert_last_assigned_field_id(table_v2: Table) -> None: + base_metadata = table_v2.metadata + AssertLastAssignedFieldId(last_assigned_field_id=base_metadata.last_column_id).validate(base_metadata) + + with pytest.raises( + ValueError, + match="Requirement failed: last assigned field id has changed: expected 1, found 3", + ): + AssertLastAssignedFieldId(last_assigned_field_id=1).validate(base_metadata) + + +def test_assert_current_schema_id(table_v2: Table) -> None: + base_metadata = table_v2.metadata + AssertCurrentSchemaId(current_schema_id=base_metadata.current_schema_id).validate(base_metadata) + + with pytest.raises( + ValueError, + match="Requirement failed: current schema id has changed: expected 2, found 1", + ): + AssertCurrentSchemaId(current_schema_id=2).validate(base_metadata) + + +def test_last_assigned_partition_id(table_v2: Table) -> None: + base_metadata = table_v2.metadata + AssertLastAssignedPartitionId(last_assigned_partition_id=base_metadata.last_partition_id).validate(base_metadata) + + with pytest.raises( + ValueError, + match="Requirement failed: last assigned partition id has changed: expected 1, found 1000", + ): + AssertLastAssignedPartitionId(last_assigned_partition_id=1).validate(base_metadata) + + +def test_assert_default_spec_id(table_v2: Table) -> None: + base_metadata = table_v2.metadata + AssertDefaultSpecId(default_spec_id=base_metadata.default_spec_id).validate(base_metadata) + + with pytest.raises( + ValueError, + match="Requirement failed: default spec id has changed: expected 1, found 0", + ): + AssertDefaultSpecId(default_spec_id=1).validate(base_metadata) + + +def test_assert_default_sort_order_id(table_v2: Table) -> None: + base_metadata = table_v2.metadata + AssertDefaultSortOrderId(default_sort_order_id=base_metadata.default_sort_order_id).validate(base_metadata) + + with pytest.raises( + ValueError, + match="Requirement failed: default sort order id has changed: expected 1, found 3", + ): + AssertDefaultSortOrderId(default_sort_order_id=1).validate(base_metadata) From 4282d37abe734bb4a714574925cef133eb6663eb Mon Sep 17 00:00:00 2001 From: HonahX Date: Sun, 10 Dec 2023 15:27:40 -0800 Subject: [PATCH 2/5] change the exception to CommitFailedException --- pyiceberg/exceptions.py | 2 +- pyiceberg/table/__init__.py | 22 +++++++++++----------- tests/table/test_init.py | 21 +++++++++++---------- 3 files changed, 23 insertions(+), 22 deletions(-) diff --git a/pyiceberg/exceptions.py b/pyiceberg/exceptions.py index f555543723..64356b11a4 100644 --- a/pyiceberg/exceptions.py +++ b/pyiceberg/exceptions.py @@ -104,7 +104,7 @@ class GenericDynamoDbError(DynamoDbError): pass -class CommitFailedException(RESTError): +class CommitFailedException(Exception): """Commit failed, refresh and try again.""" diff --git a/pyiceberg/table/__init__.py b/pyiceberg/table/__init__.py index 91f725ee02..8aceb03b95 100644 --- a/pyiceberg/table/__init__.py +++ b/pyiceberg/table/__init__.py @@ -44,7 +44,7 @@ from sortedcontainers import SortedList from typing_extensions import Annotated -from pyiceberg.exceptions import ResolveError, ValidationError +from pyiceberg.exceptions import CommitFailedException, ResolveError, ValidationError from pyiceberg.expressions import ( AlwaysTrue, And, @@ -553,7 +553,7 @@ class AssertCreate(TableRequirement): def validate(self, base_metadata: Optional[TableMetadata]) -> None: if base_metadata is not None: - raise ValueError("Table already exists") + raise CommitFailedException("Table already exists") class AssertTableUUID(TableRequirement): @@ -564,7 +564,7 @@ class AssertTableUUID(TableRequirement): def validate(self, base_metadata: TableMetadata) -> None: if self.uuid != base_metadata.table_uuid: - raise ValueError(f"Table UUID does not match: {self.uuid} != {base_metadata.table_uuid}") + raise CommitFailedException(f"Table UUID does not match: {self.uuid} != {base_metadata.table_uuid}") class AssertRefSnapshotId(TableRequirement): @@ -582,13 +582,13 @@ def validate(self, base_metadata: TableMetadata) -> None: if snapshot_ref is not None: ref_type = snapshot_ref.snapshot_ref_type if self.snapshot_id is None: - raise ValueError(f"Requirement failed: {ref_type} {self.ref} was created concurrently") + raise CommitFailedException(f"Requirement failed: {ref_type} {self.ref} was created concurrently") elif self.snapshot_id != snapshot_ref.snapshot_id: - raise ValueError( + raise CommitFailedException( f"Requirement failed: {ref_type} {self.ref} has changed: expected id {self.snapshot_id}, found {snapshot_ref.snapshot_id}" ) elif self.snapshot_id is not None: - raise ValueError(f"Requirement failed: branch or tag {self.ref} is missing, expected {self.snapshot_id}") + raise CommitFailedException(f"Requirement failed: branch or tag {self.ref} is missing, expected {self.snapshot_id}") class AssertLastAssignedFieldId(TableRequirement): @@ -599,7 +599,7 @@ class AssertLastAssignedFieldId(TableRequirement): def validate(self, base_metadata: TableMetadata) -> None: if base_metadata.last_column_id != self.last_assigned_field_id: - raise ValueError( + raise CommitFailedException( f"Requirement failed: last assigned field id has changed: expected {self.last_assigned_field_id}, found {base_metadata.last_column_id}" ) @@ -612,7 +612,7 @@ class AssertCurrentSchemaId(TableRequirement): def validate(self, base_metadata: TableMetadata) -> None: if self.current_schema_id != base_metadata.current_schema_id: - raise ValueError( + raise CommitFailedException( f"Requirement failed: current schema id has changed: expected {self.current_schema_id}, found {base_metadata.current_schema_id}" ) @@ -625,7 +625,7 @@ class AssertLastAssignedPartitionId(TableRequirement): def validate(self, base_metadata: TableMetadata) -> None: if base_metadata.last_partition_id != self.last_assigned_partition_id: - raise ValueError( + raise CommitFailedException( f"Requirement failed: last assigned partition id has changed: expected {self.last_assigned_partition_id}, found {base_metadata.last_partition_id}" ) @@ -638,7 +638,7 @@ class AssertDefaultSpecId(TableRequirement): def validate(self, base_metadata: TableMetadata) -> None: if self.default_spec_id != base_metadata.default_spec_id: - raise ValueError( + raise CommitFailedException( f"Requirement failed: default spec id has changed: expected {self.default_spec_id}, found {base_metadata.default_spec_id}" ) @@ -651,7 +651,7 @@ class AssertDefaultSortOrderId(TableRequirement): def validate(self, base_metadata: TableMetadata) -> None: if self.default_sort_order_id != base_metadata.default_sort_order_id: - raise ValueError( + raise CommitFailedException( f"Requirement failed: default sort order id has changed: expected {self.default_sort_order_id}, found {base_metadata.default_sort_order_id}" ) diff --git a/tests/table/test_init.py b/tests/table/test_init.py index 713784e8d7..c2f201134c 100644 --- a/tests/table/test_init.py +++ b/tests/table/test_init.py @@ -22,6 +22,7 @@ import pytest from sortedcontainers import SortedList +from pyiceberg.exceptions import CommitFailedException from pyiceberg.expressions import ( AlwaysTrue, And, @@ -735,7 +736,7 @@ def test_generate_snapshot_id(table_v2: Table) -> None: def test_assert_create(table_v2: Table) -> None: AssertCreate().validate(None) - with pytest.raises(ValueError, match="Table already exists"): + with pytest.raises(CommitFailedException, match="Table already exists"): AssertCreate().validate(table_v2.metadata) @@ -744,7 +745,7 @@ def test_assert_table_uuid(table_v2: Table) -> None: AssertTableUUID(uuid=base_metadata.table_uuid).validate(base_metadata) with pytest.raises( - ValueError, + CommitFailedException, match="Table UUID does not match: 9c12d441-03fe-4693-9a96-a0705ddf69c2 != 9c12d441-03fe-4693-9a96-a0705ddf69c1", ): AssertTableUUID(uuid=uuid.UUID("9c12d441-03fe-4693-9a96-a0705ddf69c2")).validate(base_metadata) @@ -755,19 +756,19 @@ def test_assert_ref_snapshot_id(table_v2: Table) -> None: AssertRefSnapshotId(ref="main", snapshot_id=base_metadata.current_snapshot_id).validate(base_metadata) with pytest.raises( - ValueError, + CommitFailedException, match="Requirement failed: SnapshotRefType.BRANCH main was created concurrently", ): AssertRefSnapshotId(ref="main", snapshot_id=None).validate(base_metadata) with pytest.raises( - ValueError, + CommitFailedException, match="Requirement failed: SnapshotRefType.BRANCH main has changed: expected id 1, found 3055729675574597004", ): AssertRefSnapshotId(ref="main", snapshot_id=1).validate(base_metadata) with pytest.raises( - ValueError, + CommitFailedException, match="Requirement failed: branch or tag not_exist is missing, expected 1", ): AssertRefSnapshotId(ref="not_exist", snapshot_id=1).validate(base_metadata) @@ -778,7 +779,7 @@ def test_assert_last_assigned_field_id(table_v2: Table) -> None: AssertLastAssignedFieldId(last_assigned_field_id=base_metadata.last_column_id).validate(base_metadata) with pytest.raises( - ValueError, + CommitFailedException, match="Requirement failed: last assigned field id has changed: expected 1, found 3", ): AssertLastAssignedFieldId(last_assigned_field_id=1).validate(base_metadata) @@ -789,7 +790,7 @@ def test_assert_current_schema_id(table_v2: Table) -> None: AssertCurrentSchemaId(current_schema_id=base_metadata.current_schema_id).validate(base_metadata) with pytest.raises( - ValueError, + CommitFailedException, match="Requirement failed: current schema id has changed: expected 2, found 1", ): AssertCurrentSchemaId(current_schema_id=2).validate(base_metadata) @@ -800,7 +801,7 @@ def test_last_assigned_partition_id(table_v2: Table) -> None: AssertLastAssignedPartitionId(last_assigned_partition_id=base_metadata.last_partition_id).validate(base_metadata) with pytest.raises( - ValueError, + CommitFailedException, match="Requirement failed: last assigned partition id has changed: expected 1, found 1000", ): AssertLastAssignedPartitionId(last_assigned_partition_id=1).validate(base_metadata) @@ -811,7 +812,7 @@ def test_assert_default_spec_id(table_v2: Table) -> None: AssertDefaultSpecId(default_spec_id=base_metadata.default_spec_id).validate(base_metadata) with pytest.raises( - ValueError, + CommitFailedException, match="Requirement failed: default spec id has changed: expected 1, found 0", ): AssertDefaultSpecId(default_spec_id=1).validate(base_metadata) @@ -822,7 +823,7 @@ def test_assert_default_sort_order_id(table_v2: Table) -> None: AssertDefaultSortOrderId(default_sort_order_id=base_metadata.default_sort_order_id).validate(base_metadata) with pytest.raises( - ValueError, + CommitFailedException, match="Requirement failed: default sort order id has changed: expected 1, found 3", ): AssertDefaultSortOrderId(default_sort_order_id=1).validate(base_metadata) From 94cfc6914cd8aeb7c9a5bd453b0ee3d8e542c203 Mon Sep 17 00:00:00 2001 From: HonahX Date: Sun, 10 Dec 2023 15:28:52 -0800 Subject: [PATCH 3/5] add docstring --- pyiceberg/table/__init__.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/pyiceberg/table/__init__.py b/pyiceberg/table/__init__.py index 8aceb03b95..bf6102eeee 100644 --- a/pyiceberg/table/__init__.py +++ b/pyiceberg/table/__init__.py @@ -542,7 +542,14 @@ class TableRequirement(IcebergBaseModel): @abstractmethod def validate(self, base_metadata: TableMetadata) -> None: - """Validate the requirement against the base metadata.""" + """Validate the requirement against the base metadata. + + Args: + base_metadata: The base metadata to be validated against. + + Raises: + CommitFailedException: When the requirement is not met. + """ ... From 413935e6f1a78df07655a69591bb5197ba498991 Mon Sep 17 00:00:00 2001 From: HonahX Date: Sun, 10 Dec 2023 19:37:14 -0800 Subject: [PATCH 4/5] fix CI issue --- pyiceberg/table/refs.py | 4 ++++ tests/table/test_init.py | 4 ++-- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/pyiceberg/table/refs.py b/pyiceberg/table/refs.py index 6f17880cac..df18fadd31 100644 --- a/pyiceberg/table/refs.py +++ b/pyiceberg/table/refs.py @@ -34,6 +34,10 @@ def __repr__(self) -> str: """Return the string representation of the SnapshotRefType class.""" return f"SnapshotRefType.{self.name}" + def __str__(self) -> str: + """Return the string representation of the SnapshotRefType class.""" + return self.value + class SnapshotRef(IcebergBaseModel): snapshot_id: int = Field(alias="snapshot-id") diff --git a/tests/table/test_init.py b/tests/table/test_init.py index c2f201134c..5beea50665 100644 --- a/tests/table/test_init.py +++ b/tests/table/test_init.py @@ -757,13 +757,13 @@ def test_assert_ref_snapshot_id(table_v2: Table) -> None: with pytest.raises( CommitFailedException, - match="Requirement failed: SnapshotRefType.BRANCH main was created concurrently", + match="Requirement failed: branch main was created concurrently", ): AssertRefSnapshotId(ref="main", snapshot_id=None).validate(base_metadata) with pytest.raises( CommitFailedException, - match="Requirement failed: SnapshotRefType.BRANCH main has changed: expected id 1, found 3055729675574597004", + match="Requirement failed: branch main has changed: expected id 1, found 3055729675574597004", ): AssertRefSnapshotId(ref="main", snapshot_id=1).validate(base_metadata) From 52ceaf8da0b91a9ecccd191f31c48beef3b92ab0 Mon Sep 17 00:00:00 2001 From: HonahX Date: Sun, 10 Dec 2023 22:42:07 -0800 Subject: [PATCH 5/5] make base_metadata optional and add null check --- pyiceberg/table/__init__.py | 45 ++++++++++++++++++++++++------------- tests/table/test_init.py | 21 +++++++++++++++++ 2 files changed, 50 insertions(+), 16 deletions(-) diff --git a/pyiceberg/table/__init__.py b/pyiceberg/table/__init__.py index bf6102eeee..e4ca71f0a1 100644 --- a/pyiceberg/table/__init__.py +++ b/pyiceberg/table/__init__.py @@ -541,7 +541,7 @@ class TableRequirement(IcebergBaseModel): type: str @abstractmethod - def validate(self, base_metadata: TableMetadata) -> None: + def validate(self, base_metadata: Optional[TableMetadata]) -> None: """Validate the requirement against the base metadata. Args: @@ -569,8 +569,10 @@ class AssertTableUUID(TableRequirement): type: Literal["assert-table-uuid"] = Field(default="assert-table-uuid") uuid: uuid.UUID - def validate(self, base_metadata: TableMetadata) -> None: - if self.uuid != base_metadata.table_uuid: + def validate(self, base_metadata: Optional[TableMetadata]) -> None: + if base_metadata is None: + raise CommitFailedException("Requirement failed: current table metadata is missing") + elif self.uuid != base_metadata.table_uuid: raise CommitFailedException(f"Table UUID does not match: {self.uuid} != {base_metadata.table_uuid}") @@ -584,9 +586,10 @@ class AssertRefSnapshotId(TableRequirement): ref: str snapshot_id: Optional[int] = Field(default=None, alias="snapshot-id") - def validate(self, base_metadata: TableMetadata) -> None: - snapshot_ref = base_metadata.refs.get(self.ref) - if snapshot_ref is not None: + def validate(self, base_metadata: Optional[TableMetadata]) -> None: + if base_metadata is None: + raise CommitFailedException("Requirement failed: current table metadata is missing") + elif snapshot_ref := base_metadata.refs.get(self.ref): ref_type = snapshot_ref.snapshot_ref_type if self.snapshot_id is None: raise CommitFailedException(f"Requirement failed: {ref_type} {self.ref} was created concurrently") @@ -604,8 +607,10 @@ class AssertLastAssignedFieldId(TableRequirement): type: Literal["assert-last-assigned-field-id"] = Field(default="assert-last-assigned-field-id") last_assigned_field_id: int = Field(..., alias="last-assigned-field-id") - def validate(self, base_metadata: TableMetadata) -> None: - if base_metadata.last_column_id != self.last_assigned_field_id: + def validate(self, base_metadata: Optional[TableMetadata]) -> None: + if base_metadata is None: + raise CommitFailedException("Requirement failed: current table metadata is missing") + elif base_metadata.last_column_id != self.last_assigned_field_id: raise CommitFailedException( f"Requirement failed: last assigned field id has changed: expected {self.last_assigned_field_id}, found {base_metadata.last_column_id}" ) @@ -617,8 +622,10 @@ class AssertCurrentSchemaId(TableRequirement): type: Literal["assert-current-schema-id"] = Field(default="assert-current-schema-id") current_schema_id: int = Field(..., alias="current-schema-id") - def validate(self, base_metadata: TableMetadata) -> None: - if self.current_schema_id != base_metadata.current_schema_id: + def validate(self, base_metadata: Optional[TableMetadata]) -> None: + if base_metadata is None: + raise CommitFailedException("Requirement failed: current table metadata is missing") + elif self.current_schema_id != base_metadata.current_schema_id: raise CommitFailedException( f"Requirement failed: current schema id has changed: expected {self.current_schema_id}, found {base_metadata.current_schema_id}" ) @@ -630,8 +637,10 @@ class AssertLastAssignedPartitionId(TableRequirement): type: Literal["assert-last-assigned-partition-id"] = Field(default="assert-last-assigned-partition-id") last_assigned_partition_id: int = Field(..., alias="last-assigned-partition-id") - def validate(self, base_metadata: TableMetadata) -> None: - if base_metadata.last_partition_id != self.last_assigned_partition_id: + def validate(self, base_metadata: Optional[TableMetadata]) -> None: + if base_metadata is None: + raise CommitFailedException("Requirement failed: current table metadata is missing") + elif base_metadata.last_partition_id != self.last_assigned_partition_id: raise CommitFailedException( f"Requirement failed: last assigned partition id has changed: expected {self.last_assigned_partition_id}, found {base_metadata.last_partition_id}" ) @@ -643,8 +652,10 @@ class AssertDefaultSpecId(TableRequirement): type: Literal["assert-default-spec-id"] = Field(default="assert-default-spec-id") default_spec_id: int = Field(..., alias="default-spec-id") - def validate(self, base_metadata: TableMetadata) -> None: - if self.default_spec_id != base_metadata.default_spec_id: + def validate(self, base_metadata: Optional[TableMetadata]) -> None: + if base_metadata is None: + raise CommitFailedException("Requirement failed: current table metadata is missing") + elif self.default_spec_id != base_metadata.default_spec_id: raise CommitFailedException( f"Requirement failed: default spec id has changed: expected {self.default_spec_id}, found {base_metadata.default_spec_id}" ) @@ -656,8 +667,10 @@ class AssertDefaultSortOrderId(TableRequirement): type: Literal["assert-default-sort-order-id"] = Field(default="assert-default-sort-order-id") default_sort_order_id: int = Field(..., alias="default-sort-order-id") - def validate(self, base_metadata: TableMetadata) -> None: - if self.default_sort_order_id != base_metadata.default_sort_order_id: + def validate(self, base_metadata: Optional[TableMetadata]) -> None: + if base_metadata is None: + raise CommitFailedException("Requirement failed: current table metadata is missing") + elif self.default_sort_order_id != base_metadata.default_sort_order_id: raise CommitFailedException( f"Requirement failed: default sort order id has changed: expected {self.default_sort_order_id}, found {base_metadata.default_sort_order_id}" ) diff --git a/tests/table/test_init.py b/tests/table/test_init.py index 5beea50665..04d467c318 100644 --- a/tests/table/test_init.py +++ b/tests/table/test_init.py @@ -744,6 +744,9 @@ def test_assert_table_uuid(table_v2: Table) -> None: base_metadata = table_v2.metadata AssertTableUUID(uuid=base_metadata.table_uuid).validate(base_metadata) + with pytest.raises(CommitFailedException, match="Requirement failed: current table metadata is missing"): + AssertTableUUID(uuid=uuid.UUID("9c12d441-03fe-4693-9a96-a0705ddf69c2")).validate(None) + with pytest.raises( CommitFailedException, match="Table UUID does not match: 9c12d441-03fe-4693-9a96-a0705ddf69c2 != 9c12d441-03fe-4693-9a96-a0705ddf69c1", @@ -755,6 +758,9 @@ def test_assert_ref_snapshot_id(table_v2: Table) -> None: base_metadata = table_v2.metadata AssertRefSnapshotId(ref="main", snapshot_id=base_metadata.current_snapshot_id).validate(base_metadata) + with pytest.raises(CommitFailedException, match="Requirement failed: current table metadata is missing"): + AssertRefSnapshotId(ref="main", snapshot_id=1).validate(None) + with pytest.raises( CommitFailedException, match="Requirement failed: branch main was created concurrently", @@ -778,6 +784,9 @@ def test_assert_last_assigned_field_id(table_v2: Table) -> None: base_metadata = table_v2.metadata AssertLastAssignedFieldId(last_assigned_field_id=base_metadata.last_column_id).validate(base_metadata) + with pytest.raises(CommitFailedException, match="Requirement failed: current table metadata is missing"): + AssertLastAssignedFieldId(last_assigned_field_id=1).validate(None) + with pytest.raises( CommitFailedException, match="Requirement failed: last assigned field id has changed: expected 1, found 3", @@ -789,6 +798,9 @@ def test_assert_current_schema_id(table_v2: Table) -> None: base_metadata = table_v2.metadata AssertCurrentSchemaId(current_schema_id=base_metadata.current_schema_id).validate(base_metadata) + with pytest.raises(CommitFailedException, match="Requirement failed: current table metadata is missing"): + AssertCurrentSchemaId(current_schema_id=1).validate(None) + with pytest.raises( CommitFailedException, match="Requirement failed: current schema id has changed: expected 2, found 1", @@ -800,6 +812,9 @@ def test_last_assigned_partition_id(table_v2: Table) -> None: base_metadata = table_v2.metadata AssertLastAssignedPartitionId(last_assigned_partition_id=base_metadata.last_partition_id).validate(base_metadata) + with pytest.raises(CommitFailedException, match="Requirement failed: current table metadata is missing"): + AssertLastAssignedPartitionId(last_assigned_partition_id=1).validate(None) + with pytest.raises( CommitFailedException, match="Requirement failed: last assigned partition id has changed: expected 1, found 1000", @@ -811,6 +826,9 @@ def test_assert_default_spec_id(table_v2: Table) -> None: base_metadata = table_v2.metadata AssertDefaultSpecId(default_spec_id=base_metadata.default_spec_id).validate(base_metadata) + with pytest.raises(CommitFailedException, match="Requirement failed: current table metadata is missing"): + AssertDefaultSpecId(default_spec_id=1).validate(None) + with pytest.raises( CommitFailedException, match="Requirement failed: default spec id has changed: expected 1, found 0", @@ -822,6 +840,9 @@ def test_assert_default_sort_order_id(table_v2: Table) -> None: base_metadata = table_v2.metadata AssertDefaultSortOrderId(default_sort_order_id=base_metadata.default_sort_order_id).validate(base_metadata) + with pytest.raises(CommitFailedException, match="Requirement failed: current table metadata is missing"): + AssertDefaultSortOrderId(default_sort_order_id=1).validate(None) + with pytest.raises( CommitFailedException, match="Requirement failed: default sort order id has changed: expected 1, found 3",