diff --git a/pyiceberg/table/__init__.py b/pyiceberg/table/__init__.py
index 9aa6c1c9c5..436266fb08 100644
--- a/pyiceberg/table/__init__.py
+++ b/pyiceberg/table/__init__.py
@@ -417,12 +417,13 @@ def _(update: AddSchemaUpdate, base_metadata: TableMetadata, context: _TableMeta
     if update.last_column_id < base_metadata.last_column_id:
         raise ValueError(f"Invalid last column id {update.last_column_id}, must be >= {base_metadata.last_column_id}")
 
-    updated_metadata_data = copy(base_metadata.model_dump())
-    updated_metadata_data["last-column-id"] = update.last_column_id
-    updated_metadata_data["schemas"].append(update.schema_.model_dump())
-
     context.add_update(update)
-    return TableMetadataUtil.parse_obj(updated_metadata_data)
+    return base_metadata.model_copy(
+        update={
+            "last_column_id": update.last_column_id,
+            "schemas": base_metadata.schemas + [update.schema_],
+        }
+    )
 
 
 @_apply_table_update.register(SetCurrentSchemaUpdate)
@@ -441,11 +442,8 @@ def _(update: SetCurrentSchemaUpdate, base_metadata: TableMetadata, context: _Ta
     if schema is None:
         raise ValueError(f"Schema with id {new_schema_id} does not exist")
 
-    updated_metadata_data = copy(base_metadata.model_dump())
-    updated_metadata_data["current-schema-id"] = new_schema_id
-
     context.add_update(update)
-    return TableMetadataUtil.parse_obj(updated_metadata_data)
+    return base_metadata.model_copy(update={"current_schema_id": new_schema_id})
 
 
 @_apply_table_update.register(AddSnapshotUpdate)
@@ -469,12 +467,14 @@ def _(update: AddSnapshotUpdate, base_metadata: TableMetadata, context: _TableMe
             f"older than last sequence number {base_metadata.last_sequence_number}"
         )
 
-    updated_metadata_data = copy(base_metadata.model_dump())
-    updated_metadata_data["last-updated-ms"] = update.snapshot.timestamp_ms
-    updated_metadata_data["last-sequence-number"] = update.snapshot.sequence_number
-    updated_metadata_data["snapshots"].append(update.snapshot.model_dump())
     context.add_update(update)
-    return TableMetadataUtil.parse_obj(updated_metadata_data)
+    return base_metadata.model_copy(
+        update={
+            "last_updated_ms": update.snapshot.timestamp_ms,
+            "last_sequence_number": update.snapshot.sequence_number,
+            "snapshots": base_metadata.snapshots + [update.snapshot],
+        }
+    )
 
 
 @_apply_table_update.register(SetSnapshotRefUpdate)
@@ -493,28 +493,27 @@ def _(update: SetSnapshotRefUpdate, base_metadata: TableMetadata, context: _Tabl
     snapshot = base_metadata.snapshot_by_id(snapshot_ref.snapshot_id)
     if snapshot is None:
-        raise ValueError(f"Cannot set {snapshot_ref.ref_name} to unknown snapshot {snapshot_ref.snapshot_id}")
+        raise ValueError(f"Cannot set {update.ref_name} to unknown snapshot {snapshot_ref.snapshot_id}")
 
-    update_metadata_data = copy(base_metadata.model_dump())
-    update_last_updated_ms = True
+    metadata_updates: Dict[str, Any] = {}
 
     if context.is_added_snapshot(snapshot_ref.snapshot_id):
-        update_metadata_data["last-updated-ms"] = snapshot.timestamp_ms
-        update_last_updated_ms = False
+        metadata_updates["last_updated_ms"] = snapshot.timestamp_ms
 
     if update.ref_name == MAIN_BRANCH:
-        update_metadata_data["current-snapshot-id"] = snapshot_ref.snapshot_id
-        if update_last_updated_ms:
-            update_metadata_data["last-updated-ms"] = datetime_to_millis(datetime.datetime.now().astimezone())
-        update_metadata_data["snapshot-log"].append(
+        metadata_updates["current_snapshot_id"] = snapshot_ref.snapshot_id
+        if "last_updated_ms" not in metadata_updates:
+            metadata_updates["last_updated_ms"] = datetime_to_millis(datetime.datetime.now().astimezone())
+
+        metadata_updates["snapshot_log"] = base_metadata.snapshot_log + [
             SnapshotLogEntry(
                 snapshot_id=snapshot_ref.snapshot_id,
-                timestamp_ms=update_metadata_data["last-updated-ms"],
-            ).model_dump()
-        )
+                timestamp_ms=metadata_updates["last_updated_ms"],
+            )
+        ]
 
-    update_metadata_data["refs"][update.ref_name] = snapshot_ref.model_dump()
+    metadata_updates["refs"] = {**base_metadata.refs, update.ref_name: snapshot_ref}
     context.add_update(update)
-    return TableMetadataUtil.parse_obj(update_metadata_data)
+    return base_metadata.model_copy(update=metadata_updates)
 
 
 def update_table_metadata(base_metadata: TableMetadata, updates: Tuple[TableUpdate, ...]) -> TableMetadata:
@@ -533,7 +532,7 @@ def update_table_metadata(base_metadata: TableMetadata, updates: Tuple[TableUpda
     for update in updates:
         new_metadata = _apply_table_update(update, new_metadata, context)
 
-    return new_metadata
+    return new_metadata.model_copy(deep=True)
 
 
 class TableRequirement(IcebergBaseModel):
diff --git a/tests/table/test_init.py b/tests/table/test_init.py
index 6d188befeb..8d13a82f3a 100644
--- a/tests/table/test_init.py
+++ b/tests/table/test_init.py
@@ -15,6 +15,7 @@
 # specific language governing permissions and limitations
 # under the License.
 # pylint:disable=redefined-outer-name
+from copy import copy
 from typing import Dict
 
 import pytest
@@ -50,7 +51,7 @@
     _TableMetadataUpdateContext,
     update_table_metadata,
 )
-from pyiceberg.table.metadata import INITIAL_SEQUENCE_NUMBER
+from pyiceberg.table.metadata import INITIAL_SEQUENCE_NUMBER, TableMetadataUtil, TableMetadataV2
 from pyiceberg.table.snapshots import (
     Operation,
     Snapshot,
@@ -640,9 +641,12 @@ def test_update_metadata_with_multiple_updates(table_v1: Table) -> None:
     )
 
     new_metadata = update_table_metadata(base_metadata, test_updates)
+    # rebuild the metadata to trigger validation
+    new_metadata = TableMetadataUtil.parse_obj(copy(new_metadata.model_dump()))
 
     # UpgradeFormatVersionUpdate
     assert new_metadata.format_version == 2
+    assert isinstance(new_metadata, TableMetadataV2)
 
     # UpdateSchema
     assert len(new_metadata.schemas) == 2
@@ -669,6 +673,51 @@ def test_update_metadata_with_multiple_updates(table_v1: Table) -> None:
     )
 
 
+def test_metadata_isolation_from_illegal_updates(table_v1: Table) -> None:
+    base_metadata = table_v1.metadata
+    base_metadata_backup = base_metadata.model_copy(deep=True)
+
+    # Apply legal updates on the table metadata
+    transaction = table_v1.transaction()
+    schema_update_1 = transaction.update_schema()
+    schema_update_1.add_column(path="b", field_type=IntegerType())
+    schema_update_1.commit()
+    test_updates = transaction._updates  # pylint: disable=W0212
+    new_snapshot = Snapshot(
+        snapshot_id=25,
+        parent_snapshot_id=19,
+        sequence_number=200,
+        timestamp_ms=1602638573590,
+        manifest_list="s3:/a/b/c.avro",
+        summary=Summary(Operation.APPEND),
+        schema_id=3,
+    )
+    test_updates += (
+        AddSnapshotUpdate(snapshot=new_snapshot),
+        SetSnapshotRefUpdate(
+            ref_name="main",
+            type="branch",
+            snapshot_id=25,
+            max_ref_age_ms=123123123,
+            max_snapshot_age_ms=12312312312,
+            min_snapshots_to_keep=1,
+        ),
+    )
+    new_metadata = update_table_metadata(base_metadata, test_updates)
+
+    # Check that the original metadata is not modified
+    assert base_metadata == base_metadata_backup
+
+    # Perform illegal update on the new metadata:
+    # TableMetadata should be immutable, but the pydantic's frozen config cannot prevent
+    # operations such as list append.
+    new_metadata.partition_specs.append(PartitionSpec(spec_id=0))
+    assert len(new_metadata.partition_specs) == 2
+
+    # The original metadata should not be affected by the illegal update on the new metadata
+    assert len(base_metadata.partition_specs) == 1
+
+
 def test_generate_snapshot_id(table_v2: Table) -> None:
     assert isinstance(_generate_snapshot_id(), int)
     assert isinstance(table_v2.new_snapshot_id(), int)