From d53785ae113b1baf67bd256c4dccef9233415ad4 Mon Sep 17 00:00:00 2001 From: HonahX Date: Sun, 5 Nov 2023 00:32:47 -0700 Subject: [PATCH 01/22] Implement table metadata updater first draft --- pyiceberg/table/__init__.py | 240 +++++++++++++++++++++++++++++++++++- pyiceberg/table/metadata.py | 2 + tests/table/test_init.py | 11 +- 3 files changed, 250 insertions(+), 3 deletions(-) diff --git a/pyiceberg/table/__init__.py b/pyiceberg/table/__init__.py index ae35b34384..5a04c9ae7e 100644 --- a/pyiceberg/table/__init__.py +++ b/pyiceberg/table/__init__.py @@ -22,7 +22,7 @@ from copy import copy from dataclasses import dataclass from enum import Enum -from functools import cached_property +from functools import cached_property, singledispatchmethod from itertools import chain from typing import ( TYPE_CHECKING, @@ -69,7 +69,14 @@ promote, visit, ) -from pyiceberg.table.metadata import INITIAL_SEQUENCE_NUMBER, TableMetadata +from pyiceberg.table.metadata import ( + INITIAL_SEQUENCE_NUMBER, + SUPPORTED_TABLE_FORMAT_VERSION, + TableMetadata, + TableMetadataUtil, + TableMetadataV1, +) +from pyiceberg.table.refs import MAIN_BRANCH, SnapshotRef, SnapshotRefType from pyiceberg.table.snapshots import Snapshot, SnapshotLogEntry from pyiceberg.table.sorting import SortOrder from pyiceberg.typedef import ( @@ -349,15 +356,205 @@ class RemovePropertiesUpdate(TableUpdate): removals: List[str] +class TableMetadataUpdateBuilder: + _base_metadata: TableMetadata + _updates: List[TableUpdate] + _last_added_schema_id: Optional[int] + + def __init__(self, base_metadata: TableMetadata) -> None: + self._base_metadata = TableMetadataUtil.parse_obj(copy(base_metadata.model_dump())) + self._updates = [] + self._last_added_schema_id = None + + def _reuse_or_create_new_schema_id(self, new_schema: Schema) -> Tuple[int, bool]: + # if the schema already exists, use its id; otherwise use the highest id + 1 + new_schema_id = self._base_metadata.current_schema_id + for schema in self._base_metadata.schemas: + if schema == new_schema: + return schema.schema_id, False + elif schema.schema_id >= new_schema_id: + new_schema_id = schema.schema_id + 1 + return new_schema_id, True + + def _add_schema_internal(self, schema: Schema, last_column_id: int, update: TableUpdate) -> int: + if last_column_id < self._base_metadata.last_column_id: + raise ValueError(f"Invalid last column id {last_column_id}, must be >= {self._base_metadata.last_column_id}") + new_schema_id, schema_found = self._reuse_or_create_new_schema_id(schema) + if schema_found and last_column_id == self._base_metadata.last_column_id: + if self._last_added_schema_id is not None and any( + update.schema_.schema_id == new_schema_id for update in self._updates if isinstance(update, AddSchemaUpdate) + ): + self._last_added_schema_id = new_schema_id + return new_schema_id + + self._base_metadata.last_column_id = last_column_id + + new_schema = ( + schema + if new_schema_id == schema.schema_id + # TODO: double check the parameter passing here, schema.fields may be interpreted as the **data fileds + else Schema(*schema.fields, schema_id=new_schema_id, identifier_field_ids=schema.identifier_field_ids) + ) + + if not schema_found: + self._base_metadata.schemas.append(new_schema) + + self._updates.append(update) + self._last_added_schema_id = new_schema_id + return new_schema_id + + def _set_current_schema(self, schema_id: int) -> None: + if schema_id == -1: + if self._last_added_schema_id is None: + raise ValueError("Cannot set current schema to last added schema when no schema has been added") + return self._set_current_schema(self._last_added_schema_id) + + if schema_id == self._base_metadata.current_schema_id: + return + + schema = next(schema for schema in self._base_metadata.schemas if schema.schema_id == schema_id) + if schema is None: + raise ValueError(f"Schema with id {schema_id} does not exist") + + # TODO: rebuild sort_order and partition_spec + # So it seems the rebuild just refresh the inner field which hold the schema and some naming check for partition_spec + # Seems this is not necessary in pyiceberg case wince + + self._base_metadata.current_schema_id = schema_id + if self._last_added_schema_id is not None and self._last_added_schema_id == schema_id: + self._updates.append(SetCurrentSchemaUpdate(schema_id=-1)) + else: + self._updates.append(SetCurrentSchemaUpdate(schema_id=schema_id)) + + @singledispatchmethod + def update_table_metadata(self, update: TableUpdate) -> None: + raise TypeError(f"Unsupported update: {update}") + + @update_table_metadata.register(UpgradeFormatVersionUpdate) + def _(self, update: UpgradeFormatVersionUpdate) -> None: + if update.format_version > SUPPORTED_TABLE_FORMAT_VERSION: + raise ValueError(f"Unsupported table format version: {update.format_version}") + if update.format_version < self._base_metadata.format_version: + raise ValueError(f"Cannot downgrade v{self._base_metadata.format_version} table to v{update.format_version}") + if update.format_version == self._base_metadata.format_version: + return + # At this point, the base_metadata is guaranteed to be v1 + if isinstance(self._base_metadata, TableMetadataV1): + self._base_metadata = self._base_metadata.to_v2() + + raise ValueError(f"Cannot upgrade v{self._base_metadata.format_version} table to v{update.format_version}") + + @update_table_metadata.register(AddSchemaUpdate) + def _(self, update: AddSchemaUpdate) -> None: + self._add_schema_internal(update.schema_, update.last_column_id, update) + + @update_table_metadata.register(SetCurrentSchemaUpdate) + def _(self, update: SetCurrentSchemaUpdate) -> None: + self._set_current_schema(update.schema_id) + + @update_table_metadata.register(AddSnapshotUpdate) + def _(self, update: AddSnapshotUpdate) -> None: + if len(self._base_metadata.schemas) == 0: + raise ValueError("Attempting to add a snapshot before a schema is added") + if len(self._base_metadata.partition_specs) == 0: + raise ValueError("Attempting to add a snapshot before a partition spec is added") + if len(self._base_metadata.sort_orders) == 0: + raise ValueError("Attempting to add a snapshot before a sort order is added") + if any(update.snapshot.snapshot_id == snapshot.snapshot_id for snapshot in self._base_metadata.snapshots): + raise ValueError(f"Snapshot with id {update.snapshot.snapshot_id} already exists") + if ( + self._base_metadata.format_version == 2 + and update.snapshot.sequence_number is not None + and self._base_metadata.last_sequence_number is not None + and update.snapshot.sequence_number <= self._base_metadata.last_sequence_number + and update.snapshot.parent_snapshot_id is not None + ): + raise ValueError( + f"Cannot add snapshot with sequence number {update.snapshot.sequence_number} older than last sequence number {self._base_metadata.last_sequence_number}" + ) + + self._base_metadata.last_updated_ms = update.snapshot.timestamp + self._base_metadata.last_sequence_number = update.snapshot.sequence_number + self._base_metadata.snapshots.append(update.snapshot) + self._updates.append(update) + + @update_table_metadata.register(SetSnapshotRefUpdate) + def _(self, update: SetSnapshotRefUpdate) -> None: + ## TODO: may be some of the validation could be added to SnapshotRef class + ## TODO: may be we need to make some of the field in this update as optional or we can remove some of the checks + if update.type is None: + raise ValueError("Snapshot ref type must be set") + if update.min_snapshots_to_keep is not None and update.type == SnapshotRefType.TAG: + raise ValueError("Cannot set min snapshots to keep for branch refs") + if update.min_snapshots_to_keep is not None and update.min_snapshots_to_keep <= 0: + raise ValueError("Minimum snapshots to keep must be >= 0") + if update.max_snapshot_age_ms is not None and update.type == SnapshotRefType.TAG: + raise ValueError("Tags do not support setting maxSnapshotAgeMs") + if update.max_snapshot_age_ms is not None and update.max_snapshot_age_ms <= 0: + raise ValueError("Max snapshot age must be > 0 ms") + if update.max_age_ref_ms is not None and update.max_age_ref_ms <= 0: + raise ValueError("Max ref age must be > 0 ms") + snapshot_ref = SnapshotRef( + snapshot_id=update.snapshot_id, + snapshot_ref_type=update.type, + min_snapshots_to_keep=update.min_snapshots_to_keep, + max_snapshot_age_ms=update.max_snapshot_age_ms, + max_ref_age_ms=update.max_age_ref_ms, + ) + existing_ref = self._base_metadata.refs.get(snapshot_ref.ref_name) + if existing_ref is not None and existing_ref == snapshot_ref: + return + + snapshot = next( + snapshot for snapshot in self._base_metadata.snapshots if snapshot.snapshot_id == snapshot_ref.snapshot_id + ) + if snapshot is None: + raise ValueError(f"Cannot set {snapshot_ref.ref_name} to unknown snapshot {snapshot_ref.snapshot_id}") + + if any( + snapshot_ref.snapshot_id == prev_update.snapshot.snapshot_id + for prev_update in self._updates + if isinstance(self._updates, AddSnapshotUpdate) + ): + self._base_metadata.last_updated_ms = snapshot.timestamp + + if snapshot_ref.ref_name == MAIN_BRANCH: + self._base_metadata.current_snapshot_id = snapshot_ref.snapshot_id + # TODO: double-check if the default value of TableMetadata make the timestamp too early + # if self._base_metadata.last_updated_ms is None: + # self._base_metadata.last_updated_ms = datetime_to_millis(datetime.datetime.now().astimezone()) + self._base_metadata.snapshot_log.append( + SnapshotLogEntry( + snapshot_id=snapshot_ref.snapshot_id, + timestamp_ms=self._base_metadata.last_updated_ms, + ) + ) + + self._base_metadata.refs[snapshot_ref.ref_name] = snapshot_ref + self._updates.append(update) + + def build(self) -> TableMetadata: + return TableMetadataUtil.parse_obj(self._base_metadata.model_dump()) + + class TableRequirement(IcebergBaseModel): type: str + @abstractmethod + def validate(self, base_metadata: TableMetadata) -> None: + """Validate the requirement against the base metadata.""" + ... + class AssertCreate(TableRequirement): """The table must not already exist; used for create transactions.""" type: Literal["assert-create"] = Field(default="assert-create") + def validate(self, base_metadata: TableMetadata) -> None: + if base_metadata is not None: + raise ValueError("Table already exists") + class AssertTableUUID(TableRequirement): """The table UUID must match the requirement's `uuid`.""" @@ -365,6 +562,10 @@ class AssertTableUUID(TableRequirement): type: Literal["assert-table-uuid"] = Field(default="assert-table-uuid") uuid: str + def validate(self, base_metadata: TableMetadata) -> None: + if self.uuid != base_metadata.uuid: + raise ValueError(f"Table UUID does not match: {self.uuid} != {base_metadata.uuid}") + class AssertRefSnapshotId(TableRequirement): """The table branch or tag identified by the requirement's `ref` must reference the requirement's `snapshot-id`. @@ -376,6 +577,19 @@ class AssertRefSnapshotId(TableRequirement): ref: str snapshot_id: Optional[int] = Field(default=None, alias="snapshot-id") + def validate(self, base_metadata: TableMetadata) -> None: + snapshot_ref = base_metadata.refs.get(self.ref) + if snapshot_ref is not None: + ref_type = snapshot_ref.snapshot_ref_type + if self.snapshot_id is None: + raise ValueError(f"Requirement failed: {self.ref_tpe} {self.ref} was created concurrently") + elif self.snapshot_id != snapshot_ref.snapshot_id: + raise ValueError( + f"Requirement failed: {ref_type} {self.ref} has changed: expected id {self.snapshot_id}, found {snapshot_ref.snapshot_id}" + ) + elif self.snapshot_id is not None: + raise ValueError(f"Requirement failed: branch or tag {self.ref} is missing, expected {self.snapshot_id}") + class AssertLastAssignedFieldId(TableRequirement): """The table's last assigned column id must match the requirement's `last-assigned-field-id`.""" @@ -383,6 +597,9 @@ class AssertLastAssignedFieldId(TableRequirement): type: Literal["assert-last-assigned-field-id"] = Field(default="assert-last-assigned-field-id") last_assigned_field_id: int = Field(..., alias="last-assigned-field-id") + def validate(self, base_metadata: TableMetadata) -> None: + raise NotImplementedError("Not yet implemented") + class AssertCurrentSchemaId(TableRequirement): """The table's current schema id must match the requirement's `current-schema-id`.""" @@ -390,6 +607,9 @@ class AssertCurrentSchemaId(TableRequirement): type: Literal["assert-current-schema-id"] = Field(default="assert-current-schema-id") current_schema_id: int = Field(..., alias="current-schema-id") + def validate(self, base_metadata: TableMetadata) -> None: + raise NotImplementedError("Not yet implemented") + class AssertLastAssignedPartitionId(TableRequirement): """The table's last assigned partition id must match the requirement's `last-assigned-partition-id`.""" @@ -397,6 +617,9 @@ class AssertLastAssignedPartitionId(TableRequirement): type: Literal["assert-last-assigned-partition-id"] = Field(default="assert-last-assigned-partition-id") last_assigned_partition_id: int = Field(..., alias="last-assigned-partition-id") + def validate(self, base_metadata: TableMetadata) -> None: + raise NotImplementedError("Not yet implemented") + class AssertDefaultSpecId(TableRequirement): """The table's default spec id must match the requirement's `default-spec-id`.""" @@ -404,6 +627,9 @@ class AssertDefaultSpecId(TableRequirement): type: Literal["assert-default-spec-id"] = Field(default="assert-default-spec-id") default_spec_id: int = Field(..., alias="default-spec-id") + def validate(self, base_metadata: TableMetadata) -> None: + raise NotImplementedError("Not yet implemented") + class AssertDefaultSortOrderId(TableRequirement): """The table's default sort order id must match the requirement's `default-sort-order-id`.""" @@ -411,6 +637,9 @@ class AssertDefaultSortOrderId(TableRequirement): type: Literal["assert-default-sort-order-id"] = Field(default="assert-default-sort-order-id") default_sort_order_id: int = Field(..., alias="default-sort-order-id") + def validate(self, base_metadata: TableMetadata) -> None: + raise NotImplementedError("Not yet implemented") + class Namespace(IcebergRootModel[List[str]]): """Reference to one or more levels of a namespace.""" @@ -439,6 +668,13 @@ class CommitTableResponse(IcebergBaseModel): metadata_location: str = Field(alias="metadata-location") +def update_table_metadata(base_metadata: TableMetadata, updates: Tuple[TableUpdate, ...]) -> TableMetadata: + builder = TableMetadataUpdateBuilder(base_metadata) + for update in updates: + builder.update_table_metadata(update) + return builder.build() + + class Table: identifier: Identifier = Field() metadata: TableMetadata diff --git a/pyiceberg/table/metadata.py b/pyiceberg/table/metadata.py index 73d76d8606..271d40e25a 100644 --- a/pyiceberg/table/metadata.py +++ b/pyiceberg/table/metadata.py @@ -69,6 +69,8 @@ INITIAL_SPEC_ID = 0 DEFAULT_SCHEMA_ID = 0 +SUPPORTED_TABLE_FORMAT_VERSION = 2 + def cleanup_snapshot_id(data: Dict[str, Any]) -> Dict[str, Any]: """Run before validation.""" diff --git a/tests/table/test_init.py b/tests/table/test_init.py index 369df4fa92..a5a3e3e833 100644 --- a/tests/table/test_init.py +++ b/tests/table/test_init.py @@ -42,7 +42,7 @@ Table, UpdateSchema, _generate_snapshot_id, - _match_deletes_to_datafile, + _match_deletes_to_datafile, update_table_metadata, ) from pyiceberg.table.metadata import INITIAL_SEQUENCE_NUMBER from pyiceberg.table.snapshots import ( @@ -508,6 +508,15 @@ def test_add_nested_list_type_column(table: Table) -> None: ) assert new_schema.highest_field_id == 7 +def test_update_metadata_table_schema(table: Table) -> None: + transaction = table.transaction() + update = transaction.update_schema() + update.add_column(path="b", field_type=IntegerType()) + update.commit() + + new_metadata = update_table_metadata(table.metadata, transaction._updates) # pylint: disable=W0212 + print(new_metadata) + def test_generate_snapshot_id(table: Table) -> None: assert isinstance(_generate_snapshot_id(), int) From 274b91bb9d37570c722f6e49117a7065953700a9 Mon Sep 17 00:00:00 2001 From: HonahX Date: Sun, 5 Nov 2023 01:56:32 -0800 Subject: [PATCH 02/22] fix updater error and add tests --- pyiceberg/table/__init__.py | 100 ++++++++++++++++++++---------------- tests/table/test_init.py | 59 +++++++++++++++++++-- 2 files changed, 112 insertions(+), 47 deletions(-) diff --git a/pyiceberg/table/__init__.py b/pyiceberg/table/__init__.py index 5a04c9ae7e..203825fbae 100644 --- a/pyiceberg/table/__init__.py +++ b/pyiceberg/table/__init__.py @@ -16,6 +16,7 @@ # under the License. from __future__ import annotations +import datetime import itertools import uuid from abc import ABC, abstractmethod @@ -74,7 +75,6 @@ SUPPORTED_TABLE_FORMAT_VERSION, TableMetadata, TableMetadataUtil, - TableMetadataV1, ) from pyiceberg.table.refs import MAIN_BRANCH, SnapshotRef, SnapshotRefType from pyiceberg.table.snapshots import Snapshot, SnapshotLogEntry @@ -96,6 +96,7 @@ StructType, ) from pyiceberg.utils.concurrent import ExecutorFactory +from pyiceberg.utils.datetime import datetime_to_millis if TYPE_CHECKING: import pandas as pd @@ -357,19 +358,20 @@ class RemovePropertiesUpdate(TableUpdate): class TableMetadataUpdateBuilder: - _base_metadata: TableMetadata + _base_metadata: Dict[str, Any] _updates: List[TableUpdate] _last_added_schema_id: Optional[int] def __init__(self, base_metadata: TableMetadata) -> None: - self._base_metadata = TableMetadataUtil.parse_obj(copy(base_metadata.model_dump())) + self._base_metadata = copy(base_metadata.model_dump()) self._updates = [] self._last_added_schema_id = None def _reuse_or_create_new_schema_id(self, new_schema: Schema) -> Tuple[int, bool]: # if the schema already exists, use its id; otherwise use the highest id + 1 - new_schema_id = self._base_metadata.current_schema_id - for schema in self._base_metadata.schemas: + new_schema_id = self._base_metadata["current-schema-id"] + for raw_schema in self._base_metadata["schemas"]: + schema = Schema(**raw_schema) if schema == new_schema: return schema.schema_id, False elif schema.schema_id >= new_schema_id: @@ -377,17 +379,17 @@ def _reuse_or_create_new_schema_id(self, new_schema: Schema) -> Tuple[int, bool] return new_schema_id, True def _add_schema_internal(self, schema: Schema, last_column_id: int, update: TableUpdate) -> int: - if last_column_id < self._base_metadata.last_column_id: - raise ValueError(f"Invalid last column id {last_column_id}, must be >= {self._base_metadata.last_column_id}") - new_schema_id, schema_found = self._reuse_or_create_new_schema_id(schema) - if schema_found and last_column_id == self._base_metadata.last_column_id: + if last_column_id < self._base_metadata["last-column-id"]: + raise ValueError(f"Invalid last column id {last_column_id}, must be >= {self._base_metadata['last-column-id']}") + new_schema_id, is_new_schema = self._reuse_or_create_new_schema_id(schema) + if not is_new_schema and last_column_id == self._base_metadata["last-column-id"]: if self._last_added_schema_id is not None and any( update.schema_.schema_id == new_schema_id for update in self._updates if isinstance(update, AddSchemaUpdate) ): self._last_added_schema_id = new_schema_id return new_schema_id - self._base_metadata.last_column_id = last_column_id + self._base_metadata["last-column-id"] = last_column_id new_schema = ( schema @@ -396,8 +398,8 @@ def _add_schema_internal(self, schema: Schema, last_column_id: int, update: Tabl else Schema(*schema.fields, schema_id=new_schema_id, identifier_field_ids=schema.identifier_field_ids) ) - if not schema_found: - self._base_metadata.schemas.append(new_schema) + if is_new_schema: + self._base_metadata["schemas"].append(new_schema.model_dump()) self._updates.append(update) self._last_added_schema_id = new_schema_id @@ -409,10 +411,12 @@ def _set_current_schema(self, schema_id: int) -> None: raise ValueError("Cannot set current schema to last added schema when no schema has been added") return self._set_current_schema(self._last_added_schema_id) - if schema_id == self._base_metadata.current_schema_id: + if schema_id == self._base_metadata["current-schema-id"]: return - schema = next(schema for schema in self._base_metadata.schemas if schema.schema_id == schema_id) + schema = next( + (Schema(**raw_schema) for raw_schema in self._base_metadata["schemas"] if raw_schema["schema-id"] == schema_id), None + ) if schema is None: raise ValueError(f"Schema with id {schema_id} does not exist") @@ -420,7 +424,7 @@ def _set_current_schema(self, schema_id: int) -> None: # So it seems the rebuild just refresh the inner field which hold the schema and some naming check for partition_spec # Seems this is not necessary in pyiceberg case wince - self._base_metadata.current_schema_id = schema_id + self._base_metadata["current-schema-id"] = schema_id if self._last_added_schema_id is not None and self._last_added_schema_id == schema_id: self._updates.append(SetCurrentSchemaUpdate(schema_id=-1)) else: @@ -432,17 +436,17 @@ def update_table_metadata(self, update: TableUpdate) -> None: @update_table_metadata.register(UpgradeFormatVersionUpdate) def _(self, update: UpgradeFormatVersionUpdate) -> None: + current_format_version = self._base_metadata["format-version"] if update.format_version > SUPPORTED_TABLE_FORMAT_VERSION: raise ValueError(f"Unsupported table format version: {update.format_version}") - if update.format_version < self._base_metadata.format_version: - raise ValueError(f"Cannot downgrade v{self._base_metadata.format_version} table to v{update.format_version}") - if update.format_version == self._base_metadata.format_version: + if update.format_version < current_format_version: + raise ValueError(f"Cannot downgrade v{current_format_version} table to v{update.format_version}") + if update.format_version == current_format_version: return # At this point, the base_metadata is guaranteed to be v1 - if isinstance(self._base_metadata, TableMetadataV1): - self._base_metadata = self._base_metadata.to_v2() + self._base_metadata["format-version"] = 2 - raise ValueError(f"Cannot upgrade v{self._base_metadata.format_version} table to v{update.format_version}") + raise ValueError(f"Cannot upgrade v{current_format_version} table to v{update.format_version}") @update_table_metadata.register(AddSchemaUpdate) def _(self, update: AddSchemaUpdate) -> None: @@ -454,28 +458,31 @@ def _(self, update: SetCurrentSchemaUpdate) -> None: @update_table_metadata.register(AddSnapshotUpdate) def _(self, update: AddSnapshotUpdate) -> None: - if len(self._base_metadata.schemas) == 0: + if len(self._base_metadata["schemas"]) == 0: raise ValueError("Attempting to add a snapshot before a schema is added") - if len(self._base_metadata.partition_specs) == 0: + if len(self._base_metadata["partition-specs"]) == 0: raise ValueError("Attempting to add a snapshot before a partition spec is added") - if len(self._base_metadata.sort_orders) == 0: + if len(self._base_metadata["sort-orders"]) == 0: raise ValueError("Attempting to add a snapshot before a sort order is added") - if any(update.snapshot.snapshot_id == snapshot.snapshot_id for snapshot in self._base_metadata.snapshots): + if any( + update.snapshot.snapshot_id == Snapshot(**raw_snapshot).snapshot_id + for raw_snapshot in self._base_metadata["snapshots"] + ): raise ValueError(f"Snapshot with id {update.snapshot.snapshot_id} already exists") if ( - self._base_metadata.format_version == 2 + self._base_metadata["format-version"] == 2 and update.snapshot.sequence_number is not None - and self._base_metadata.last_sequence_number is not None - and update.snapshot.sequence_number <= self._base_metadata.last_sequence_number + and self._base_metadata["last-sequence-number"] is not None + and update.snapshot.sequence_number <= self._base_metadata["last-sequence-number"] and update.snapshot.parent_snapshot_id is not None ): raise ValueError( - f"Cannot add snapshot with sequence number {update.snapshot.sequence_number} older than last sequence number {self._base_metadata.last_sequence_number}" + f"Cannot add snapshot with sequence number {update.snapshot.sequence_number} older than last sequence number {self._base_metadata['last-sequence-number']}" ) - self._base_metadata.last_updated_ms = update.snapshot.timestamp - self._base_metadata.last_sequence_number = update.snapshot.sequence_number - self._base_metadata.snapshots.append(update.snapshot) + self._base_metadata["last-updated-ms"] = update.snapshot.timestamp_ms + self._base_metadata["last-sequence-number"] = update.snapshot.sequence_number + self._base_metadata["snapshots"].append(update.snapshot) self._updates.append(update) @update_table_metadata.register(SetSnapshotRefUpdate) @@ -501,12 +508,17 @@ def _(self, update: SetSnapshotRefUpdate) -> None: max_snapshot_age_ms=update.max_snapshot_age_ms, max_ref_age_ms=update.max_age_ref_ms, ) - existing_ref = self._base_metadata.refs.get(snapshot_ref.ref_name) + existing_ref = self._base_metadata["refs"].get(update.ref_name) if existing_ref is not None and existing_ref == snapshot_ref: return snapshot = next( - snapshot for snapshot in self._base_metadata.snapshots if snapshot.snapshot_id == snapshot_ref.snapshot_id + ( + Snapshot(**raw_snapshot) + for raw_snapshot in self._base_metadata["snapshots"] + if raw_snapshot["snapshot-id"] == snapshot_ref.snapshot_id + ), + None, ) if snapshot is None: raise ValueError(f"Cannot set {snapshot_ref.ref_name} to unknown snapshot {snapshot_ref.snapshot_id}") @@ -516,25 +528,25 @@ def _(self, update: SetSnapshotRefUpdate) -> None: for prev_update in self._updates if isinstance(self._updates, AddSnapshotUpdate) ): - self._base_metadata.last_updated_ms = snapshot.timestamp + self._base_metadata["last-updated-ms"] = snapshot.timestamp - if snapshot_ref.ref_name == MAIN_BRANCH: - self._base_metadata.current_snapshot_id = snapshot_ref.snapshot_id + if update.ref_name == MAIN_BRANCH: + self._base_metadata["current-snapshot-id"] = snapshot_ref.snapshot_id # TODO: double-check if the default value of TableMetadata make the timestamp too early - # if self._base_metadata.last_updated_ms is None: - # self._base_metadata.last_updated_ms = datetime_to_millis(datetime.datetime.now().astimezone()) - self._base_metadata.snapshot_log.append( + if self._base_metadata["last-updated-ms"] is None: + self._base_metadata["last-updated-ms"] = datetime_to_millis(datetime.datetime.now().astimezone()) + self._base_metadata["snapshot-log"].append( SnapshotLogEntry( snapshot_id=snapshot_ref.snapshot_id, - timestamp_ms=self._base_metadata.last_updated_ms, - ) + timestamp_ms=self._base_metadata["last-updated-ms"], + ).model_dump() ) - self._base_metadata.refs[snapshot_ref.ref_name] = snapshot_ref + self._base_metadata["refs"][update.ref_name] = snapshot_ref self._updates.append(update) def build(self) -> TableMetadata: - return TableMetadataUtil.parse_obj(self._base_metadata.model_dump()) + return TableMetadataUtil.parse_obj(self._base_metadata) class TableRequirement(IcebergBaseModel): diff --git a/tests/table/test_init.py b/tests/table/test_init.py index a5a3e3e833..7699fe1a6e 100644 --- a/tests/table/test_init.py +++ b/tests/table/test_init.py @@ -37,12 +37,15 @@ from pyiceberg.partitioning import PartitionField, PartitionSpec from pyiceberg.schema import Schema from pyiceberg.table import ( + AddSnapshotUpdate, SetPropertiesUpdate, + SetSnapshotRefUpdate, StaticTable, Table, UpdateSchema, _generate_snapshot_id, - _match_deletes_to_datafile, update_table_metadata, + _match_deletes_to_datafile, + update_table_metadata, ) from pyiceberg.table.metadata import INITIAL_SEQUENCE_NUMBER from pyiceberg.table.snapshots import ( @@ -508,14 +511,64 @@ def test_add_nested_list_type_column(table: Table) -> None: ) assert new_schema.highest_field_id == 7 + def test_update_metadata_table_schema(table: Table) -> None: transaction = table.transaction() update = transaction.update_schema() update.add_column(path="b", field_type=IntegerType()) update.commit() - new_metadata = update_table_metadata(table.metadata, transaction._updates) # pylint: disable=W0212 - print(new_metadata) + apply_schema: Schema = next(schema for schema in new_metadata.schemas if schema.schema_id == 2) + assert len(apply_schema.fields) == 4 + + assert apply_schema == Schema( + NestedField(field_id=1, name="x", field_type=LongType(), required=True), + NestedField(field_id=2, name="y", field_type=LongType(), required=True, doc="comment"), + NestedField(field_id=3, name="z", field_type=LongType(), required=True), + NestedField(field_id=4, name="b", field_type=IntegerType(), required=False), + identifier_field_ids=[1, 2], + ) + assert apply_schema.schema_id == 2 + assert apply_schema.highest_field_id == 4 + + assert new_metadata.current_schema_id == 2 + + +def test_update_metadata_add_snapshot(table: Table) -> None: + new_snapshot = Snapshot( + snapshot_id=25, + parent_snapshot_id=19, + sequence_number=200, + timestamp_ms=1602638573590, + manifest_list="s3:/a/b/c.avro", + summary=Summary(Operation.APPEND), + schema_id=3, + ) + + new_metadata = update_table_metadata(table.metadata, (AddSnapshotUpdate(snapshot=new_snapshot),)) + assert len(new_metadata.snapshots) == 3 + assert new_metadata.snapshots[2] == new_snapshot + assert new_metadata.last_sequence_number == new_snapshot.sequence_number + assert new_metadata.last_updated_ms == new_snapshot.timestamp_ms + + +def test_update_metadata_set_snapshot_ref(table: Table) -> None: + update = SetSnapshotRefUpdate( + ref_name="main", + type="branch", + snapshot_id=3051729675574597004, + max_age_ref_ms=123123123, + max_snapshot_age_ms=12312312312, + min_snapshots_to_keep=1, + ) + + new_metadata = update_table_metadata(table.metadata, (update,)) + assert len(new_metadata.snapshot_log) == 3 + assert new_metadata.snapshot_log[2] == SnapshotLogEntry( + snapshot_id=3051729675574597004, timestamp_ms=table.metadata.last_updated_ms + ) + assert new_metadata.current_snapshot_id == 3051729675574597004 + assert new_metadata.last_updated_ms == table.metadata.last_updated_ms def test_generate_snapshot_id(table: Table) -> None: From c3e13119e7fb7279f842450ac33b11734b9a9624 Mon Sep 17 00:00:00 2001 From: HonahX Date: Sat, 11 Nov 2023 00:22:28 -0800 Subject: [PATCH 03/22] implement apply_metadata_update which is simpler --- pyiceberg/table/__init__.py | 203 +++++++++++++++++++++++++++++++++++- 1 file changed, 199 insertions(+), 4 deletions(-) diff --git a/pyiceberg/table/__init__.py b/pyiceberg/table/__init__.py index 203825fbae..430d2cc0f9 100644 --- a/pyiceberg/table/__init__.py +++ b/pyiceberg/table/__init__.py @@ -23,7 +23,7 @@ from copy import copy from dataclasses import dataclass from enum import Enum -from functools import cached_property, singledispatchmethod +from functools import cached_property, singledispatch, singledispatchmethod from itertools import chain from typing import ( TYPE_CHECKING, @@ -357,6 +357,30 @@ class RemovePropertiesUpdate(TableUpdate): removals: List[str] +class TableMetadataUpdateContext: + updates: List[TableUpdate] + last_added_schema_id: Optional[int] + + def __init__(self) -> None: + self.updates = [] + self.last_added_schema_id = None + + def get_updates_by_action(self, update_type: TableUpdateAction) -> List[TableUpdate]: + return [update for update in self.updates if update.action == update_type] + + def is_added_snapshot(self, snapshot_id: int) -> bool: + return any( + update.snapshot.snapshot_id == snapshot_id + for update in self.updates + if update.action == TableUpdateAction.add_snapshot + ) + + def is_added_schema(self, schema_id: int) -> bool: + return any( + update.schema_.schema_id == schema_id for update in self.updates if update.action == TableUpdateAction.add_schema + ) + + class TableMetadataUpdateBuilder: _base_metadata: Dict[str, Any] _updates: List[TableUpdate] @@ -549,6 +573,176 @@ def build(self) -> TableMetadata: return TableMetadataUtil.parse_obj(self._base_metadata) +@singledispatch +def apply_table_update(update: TableUpdate, base_metadata: TableMetadata, context: TableMetadataUpdateContext) -> TableMetadata: + raise ValueError(f"Unsupported update: {update}") + + +@apply_table_update.register(UpgradeFormatVersionUpdate) +def _(update: UpgradeFormatVersionUpdate, base_metadata: TableMetadata, context: TableMetadataUpdateContext) -> TableMetadata: + current_format_version = base_metadata.format_version + if update.format_version > SUPPORTED_TABLE_FORMAT_VERSION: + raise ValueError(f"Unsupported table format version: {update.format_version}") + if update.format_version < current_format_version: + raise ValueError(f"Cannot downgrade v{current_format_version} table to v{update.format_version}") + if update.format_version == current_format_version: + return base_metadata + + if current_format_version == 1 and update.format_version == 2: + updated_metadata_data = copy(base_metadata.model_dump()) + updated_metadata_data["format-version"] = update.format_version + return TableMetadataUtil.parse_obj(updated_metadata_data) + + raise ValueError(f"Cannot upgrade v{current_format_version} table to v{update.format_version}") + + +@apply_table_update.register(AddSchemaUpdate) +def _(update: AddSchemaUpdate, base_metadata: TableMetadata, context: TableMetadataUpdateContext) -> TableMetadata: + def reuse_or_create_new_schema_id(new_schema: Schema) -> Tuple[int, bool]: + # if the schema already exists, use its id; otherwise use the highest id + 1 + new_schema_id = base_metadata.current_schema_id + for schema in base_metadata.schemas: + if schema == new_schema: + return schema.schema_id, True + elif schema.schema_id >= new_schema_id: + new_schema_id = schema.schema_id + 1 + return new_schema_id, False + + if update.last_column_id < base_metadata.last_column_id: + raise ValueError(f"Invalid last column id {update.last_column_id}, must be >= {base_metadata.last_column_id}") + new_schema_id, schema_found = reuse_or_create_new_schema_id(update.schema_) + if schema_found and update.last_column_id == base_metadata.last_column_id: + if context.last_added_schema_id is not None and context.is_added_schema(new_schema_id): + context.last_added_schema_id = new_schema_id + return base_metadata + + updated_metadata_data = copy(base_metadata.model_dump()) + updated_metadata_data["last-column-id"] = update.last_column_id + + new_schema = ( + update.schema_ + if new_schema_id == update.schema_.schema_id + # TODO: double check the parameter passing here, schema.fields may be interpreted as the **data fileds + else Schema(*update.schema_.fields, schema_id=new_schema_id, identifier_field_ids=update.schema_.identifier_field_ids) + ) + + if not schema_found: + updated_metadata_data["schemas"].append(new_schema.model_dump()) + + context.updates.append(update) + context.last_added_schema_id = new_schema_id + return TableMetadataUtil.parse_obj(updated_metadata_data) + + +@apply_table_update.register(SetCurrentSchemaUpdate) +def _(update: SetCurrentSchemaUpdate, base_metadata: TableMetadata, context: TableMetadataUpdateContext) -> TableMetadata: + if update.schema_id == -1: + if context.last_added_schema_id is None: + raise ValueError("Cannot set current schema to last added schema when no schema has been added") + return apply_table_update(SetCurrentSchemaUpdate(schema_id=context.last_added_schema_id), base_metadata, context) + + if update.schema_id == base_metadata.current_schema_id: + return base_metadata + + schema = next((schema for schema in base_metadata.schemas if schema.schema_id == update.schema_id), None) + if schema is None: + raise ValueError(f"Schema with id {update.schema_id} does not exist") + + # TODO: rebuild sort_order and partition_spec + # So it seems the rebuild just refresh the inner field which hold the schema and some naming check for partition_spec + # Seems this is not necessary in pyiceberg case wince + + updated_metadata_data = copy(base_metadata.model_dump()) + updated_metadata_data["current-schema-id"] = update.schema_id + if context.last_added_schema_id is not None and context.last_added_schema_id == update.schema_id: + context.updates.append(SetCurrentSchemaUpdate(schema_id=-1)) + else: + context.updates.append(update) + + return TableMetadataUtil.parse_obj(updated_metadata_data) + + +@apply_table_update.register(AddSnapshotUpdate) +def _(update: AddSnapshotUpdate, base_metadata: TableMetadata, context: TableMetadataUpdateContext) -> TableMetadata: + if len(base_metadata.schemas) == 0: + raise ValueError("Attempting to add a snapshot before a schema is added") + if len(base_metadata.partition_specs) == 0: + raise ValueError("Attempting to add a snapshot before a partition spec is added") + if len(base_metadata.sort_orders) == 0: + raise ValueError("Attempting to add a snapshot before a sort order is added") + if any(update.snapshot.snapshot_id == snapshot.snapshot_id for snapshot in base_metadata.snapshots): + raise ValueError(f"Snapshot with id {update.snapshot.snapshot_id} already exists") + if ( + base_metadata.format_version == 2 + and update.snapshot.sequence_number is not None + and update.snapshot.sequence_number <= base_metadata.last_sequence_number + and update.snapshot.parent_snapshot_id is not None + ): + raise ValueError( + f"Cannot add snapshot with sequence number {update.snapshot.sequence_number} older than last sequence number {base_metadata.last_sequence_number}" + ) + + updated_metadata_data = copy(base_metadata.model_dump()) + updated_metadata_data["last-updated-ms"] = update.snapshot.timestamp_ms + updated_metadata_data["last-sequence-number"] = update.snapshot.sequence_number + updated_metadata_data["snapshots"].append(update.snapshot.model_dump()) + context.updates.append(update) + return TableMetadataUtil.parse_obj(updated_metadata_data) + + +@apply_table_update.register(SetSnapshotRefUpdate) +def _(update: SetSnapshotRefUpdate, base_metadata: TableMetadata, context: TableMetadataUpdateContext) -> TableMetadata: + if update.type is None: + raise ValueError("Snapshot ref type must be set") + if update.min_snapshots_to_keep is not None and update.type == SnapshotRefType.TAG: + raise ValueError("Cannot set min snapshots to keep for branch refs") + if update.min_snapshots_to_keep is not None and update.min_snapshots_to_keep <= 0: + raise ValueError("Minimum snapshots to keep must be >= 0") + if update.max_snapshot_age_ms is not None and update.type == SnapshotRefType.TAG: + raise ValueError("Tags do not support setting maxSnapshotAgeMs") + if update.max_snapshot_age_ms is not None and update.max_snapshot_age_ms <= 0: + raise ValueError("Max snapshot age must be > 0 ms") + if update.max_age_ref_ms is not None and update.max_age_ref_ms <= 0: + raise ValueError("Max ref age must be > 0 ms") + snapshot_ref = SnapshotRef( + snapshot_id=update.snapshot_id, + snapshot_ref_type=update.type, + min_snapshots_to_keep=update.min_snapshots_to_keep, + max_snapshot_age_ms=update.max_snapshot_age_ms, + max_ref_age_ms=update.max_age_ref_ms, + ) + existing_ref = base_metadata.refs.get(update.ref_name) + if existing_ref is not None and existing_ref == snapshot_ref: + return base_metadata + + snapshot = next( + (snapshot for snapshot in base_metadata.snapshots if snapshot.snapshot_id == snapshot_ref.snapshot_id), + None, + ) + if snapshot is None: + raise ValueError(f"Cannot set {snapshot_ref.ref_name} to unknown snapshot {snapshot_ref.snapshot_id}") + + update_metadata_data = copy(base_metadata.model_dump()) + if context.is_added_snapshot(snapshot_ref.snapshot_id): + update_metadata_data["last-updated-ms"] = snapshot.timestamp + + if update.ref_name == MAIN_BRANCH: + update_metadata_data["current-snapshot-id"] = snapshot_ref.snapshot_id + # TODO: double-check if the default value of TableMetadata make the timestamp too early + # if base_metadata.last_updated_ms is None: + # update_metadata_data["last-updated-ms"] = datetime_to_millis(datetime.datetime.now().astimezone()) + update_metadata_data["snapshot-log"].append( + SnapshotLogEntry( + snapshot_id=snapshot_ref.snapshot_id, + timestamp_ms=update_metadata_data["last-updated-ms"], + ).model_dump() + ) + + update_metadata_data["refs"][update.ref_name] = snapshot_ref.model_dump() + context.updates.append(update) + return TableMetadataUtil.parse_obj(update_metadata_data) + + class TableRequirement(IcebergBaseModel): type: str @@ -681,10 +875,11 @@ class CommitTableResponse(IcebergBaseModel): def update_table_metadata(base_metadata: TableMetadata, updates: Tuple[TableUpdate, ...]) -> TableMetadata: - builder = TableMetadataUpdateBuilder(base_metadata) + context = TableMetadataUpdateContext() + new_metadata = base_metadata for update in updates: - builder.update_table_metadata(update) - return builder.build() + new_metadata = apply_table_update(update, new_metadata, context) + return new_metadata class Table: From 2b7a7d17aeecb929ae34ecd7ecf92d42982ed8f9 Mon Sep 17 00:00:00 2001 From: HonahX Date: Sat, 11 Nov 2023 00:24:25 -0800 Subject: [PATCH 04/22] remove old implementation --- pyiceberg/table/__init__.py | 196 +----------------------------------- 1 file changed, 1 insertion(+), 195 deletions(-) diff --git a/pyiceberg/table/__init__.py b/pyiceberg/table/__init__.py index 430d2cc0f9..1c99302713 100644 --- a/pyiceberg/table/__init__.py +++ b/pyiceberg/table/__init__.py @@ -16,14 +16,13 @@ # under the License. from __future__ import annotations -import datetime import itertools import uuid from abc import ABC, abstractmethod from copy import copy from dataclasses import dataclass from enum import Enum -from functools import cached_property, singledispatch, singledispatchmethod +from functools import cached_property, singledispatch from itertools import chain from typing import ( TYPE_CHECKING, @@ -96,7 +95,6 @@ StructType, ) from pyiceberg.utils.concurrent import ExecutorFactory -from pyiceberg.utils.datetime import datetime_to_millis if TYPE_CHECKING: import pandas as pd @@ -381,198 +379,6 @@ def is_added_schema(self, schema_id: int) -> bool: ) -class TableMetadataUpdateBuilder: - _base_metadata: Dict[str, Any] - _updates: List[TableUpdate] - _last_added_schema_id: Optional[int] - - def __init__(self, base_metadata: TableMetadata) -> None: - self._base_metadata = copy(base_metadata.model_dump()) - self._updates = [] - self._last_added_schema_id = None - - def _reuse_or_create_new_schema_id(self, new_schema: Schema) -> Tuple[int, bool]: - # if the schema already exists, use its id; otherwise use the highest id + 1 - new_schema_id = self._base_metadata["current-schema-id"] - for raw_schema in self._base_metadata["schemas"]: - schema = Schema(**raw_schema) - if schema == new_schema: - return schema.schema_id, False - elif schema.schema_id >= new_schema_id: - new_schema_id = schema.schema_id + 1 - return new_schema_id, True - - def _add_schema_internal(self, schema: Schema, last_column_id: int, update: TableUpdate) -> int: - if last_column_id < self._base_metadata["last-column-id"]: - raise ValueError(f"Invalid last column id {last_column_id}, must be >= {self._base_metadata['last-column-id']}") - new_schema_id, is_new_schema = self._reuse_or_create_new_schema_id(schema) - if not is_new_schema and last_column_id == self._base_metadata["last-column-id"]: - if self._last_added_schema_id is not None and any( - update.schema_.schema_id == new_schema_id for update in self._updates if isinstance(update, AddSchemaUpdate) - ): - self._last_added_schema_id = new_schema_id - return new_schema_id - - self._base_metadata["last-column-id"] = last_column_id - - new_schema = ( - schema - if new_schema_id == schema.schema_id - # TODO: double check the parameter passing here, schema.fields may be interpreted as the **data fileds - else Schema(*schema.fields, schema_id=new_schema_id, identifier_field_ids=schema.identifier_field_ids) - ) - - if is_new_schema: - self._base_metadata["schemas"].append(new_schema.model_dump()) - - self._updates.append(update) - self._last_added_schema_id = new_schema_id - return new_schema_id - - def _set_current_schema(self, schema_id: int) -> None: - if schema_id == -1: - if self._last_added_schema_id is None: - raise ValueError("Cannot set current schema to last added schema when no schema has been added") - return self._set_current_schema(self._last_added_schema_id) - - if schema_id == self._base_metadata["current-schema-id"]: - return - - schema = next( - (Schema(**raw_schema) for raw_schema in self._base_metadata["schemas"] if raw_schema["schema-id"] == schema_id), None - ) - if schema is None: - raise ValueError(f"Schema with id {schema_id} does not exist") - - # TODO: rebuild sort_order and partition_spec - # So it seems the rebuild just refresh the inner field which hold the schema and some naming check for partition_spec - # Seems this is not necessary in pyiceberg case wince - - self._base_metadata["current-schema-id"] = schema_id - if self._last_added_schema_id is not None and self._last_added_schema_id == schema_id: - self._updates.append(SetCurrentSchemaUpdate(schema_id=-1)) - else: - self._updates.append(SetCurrentSchemaUpdate(schema_id=schema_id)) - - @singledispatchmethod - def update_table_metadata(self, update: TableUpdate) -> None: - raise TypeError(f"Unsupported update: {update}") - - @update_table_metadata.register(UpgradeFormatVersionUpdate) - def _(self, update: UpgradeFormatVersionUpdate) -> None: - current_format_version = self._base_metadata["format-version"] - if update.format_version > SUPPORTED_TABLE_FORMAT_VERSION: - raise ValueError(f"Unsupported table format version: {update.format_version}") - if update.format_version < current_format_version: - raise ValueError(f"Cannot downgrade v{current_format_version} table to v{update.format_version}") - if update.format_version == current_format_version: - return - # At this point, the base_metadata is guaranteed to be v1 - self._base_metadata["format-version"] = 2 - - raise ValueError(f"Cannot upgrade v{current_format_version} table to v{update.format_version}") - - @update_table_metadata.register(AddSchemaUpdate) - def _(self, update: AddSchemaUpdate) -> None: - self._add_schema_internal(update.schema_, update.last_column_id, update) - - @update_table_metadata.register(SetCurrentSchemaUpdate) - def _(self, update: SetCurrentSchemaUpdate) -> None: - self._set_current_schema(update.schema_id) - - @update_table_metadata.register(AddSnapshotUpdate) - def _(self, update: AddSnapshotUpdate) -> None: - if len(self._base_metadata["schemas"]) == 0: - raise ValueError("Attempting to add a snapshot before a schema is added") - if len(self._base_metadata["partition-specs"]) == 0: - raise ValueError("Attempting to add a snapshot before a partition spec is added") - if len(self._base_metadata["sort-orders"]) == 0: - raise ValueError("Attempting to add a snapshot before a sort order is added") - if any( - update.snapshot.snapshot_id == Snapshot(**raw_snapshot).snapshot_id - for raw_snapshot in self._base_metadata["snapshots"] - ): - raise ValueError(f"Snapshot with id {update.snapshot.snapshot_id} already exists") - if ( - self._base_metadata["format-version"] == 2 - and update.snapshot.sequence_number is not None - and self._base_metadata["last-sequence-number"] is not None - and update.snapshot.sequence_number <= self._base_metadata["last-sequence-number"] - and update.snapshot.parent_snapshot_id is not None - ): - raise ValueError( - f"Cannot add snapshot with sequence number {update.snapshot.sequence_number} older than last sequence number {self._base_metadata['last-sequence-number']}" - ) - - self._base_metadata["last-updated-ms"] = update.snapshot.timestamp_ms - self._base_metadata["last-sequence-number"] = update.snapshot.sequence_number - self._base_metadata["snapshots"].append(update.snapshot) - self._updates.append(update) - - @update_table_metadata.register(SetSnapshotRefUpdate) - def _(self, update: SetSnapshotRefUpdate) -> None: - ## TODO: may be some of the validation could be added to SnapshotRef class - ## TODO: may be we need to make some of the field in this update as optional or we can remove some of the checks - if update.type is None: - raise ValueError("Snapshot ref type must be set") - if update.min_snapshots_to_keep is not None and update.type == SnapshotRefType.TAG: - raise ValueError("Cannot set min snapshots to keep for branch refs") - if update.min_snapshots_to_keep is not None and update.min_snapshots_to_keep <= 0: - raise ValueError("Minimum snapshots to keep must be >= 0") - if update.max_snapshot_age_ms is not None and update.type == SnapshotRefType.TAG: - raise ValueError("Tags do not support setting maxSnapshotAgeMs") - if update.max_snapshot_age_ms is not None and update.max_snapshot_age_ms <= 0: - raise ValueError("Max snapshot age must be > 0 ms") - if update.max_age_ref_ms is not None and update.max_age_ref_ms <= 0: - raise ValueError("Max ref age must be > 0 ms") - snapshot_ref = SnapshotRef( - snapshot_id=update.snapshot_id, - snapshot_ref_type=update.type, - min_snapshots_to_keep=update.min_snapshots_to_keep, - max_snapshot_age_ms=update.max_snapshot_age_ms, - max_ref_age_ms=update.max_age_ref_ms, - ) - existing_ref = self._base_metadata["refs"].get(update.ref_name) - if existing_ref is not None and existing_ref == snapshot_ref: - return - - snapshot = next( - ( - Snapshot(**raw_snapshot) - for raw_snapshot in self._base_metadata["snapshots"] - if raw_snapshot["snapshot-id"] == snapshot_ref.snapshot_id - ), - None, - ) - if snapshot is None: - raise ValueError(f"Cannot set {snapshot_ref.ref_name} to unknown snapshot {snapshot_ref.snapshot_id}") - - if any( - snapshot_ref.snapshot_id == prev_update.snapshot.snapshot_id - for prev_update in self._updates - if isinstance(self._updates, AddSnapshotUpdate) - ): - self._base_metadata["last-updated-ms"] = snapshot.timestamp - - if update.ref_name == MAIN_BRANCH: - self._base_metadata["current-snapshot-id"] = snapshot_ref.snapshot_id - # TODO: double-check if the default value of TableMetadata make the timestamp too early - if self._base_metadata["last-updated-ms"] is None: - self._base_metadata["last-updated-ms"] = datetime_to_millis(datetime.datetime.now().astimezone()) - self._base_metadata["snapshot-log"].append( - SnapshotLogEntry( - snapshot_id=snapshot_ref.snapshot_id, - timestamp_ms=self._base_metadata["last-updated-ms"], - ).model_dump() - ) - - self._base_metadata["refs"][update.ref_name] = snapshot_ref - self._updates.append(update) - - def build(self) -> TableMetadata: - return TableMetadataUtil.parse_obj(self._base_metadata) - - @singledispatch def apply_table_update(update: TableUpdate, base_metadata: TableMetadata, context: TableMetadataUpdateContext) -> TableMetadata: raise ValueError(f"Unsupported update: {update}") From 4fc25df27a3c974f9bf2d4dcbf940badce961dcf Mon Sep 17 00:00:00 2001 From: HonahX Date: Sat, 11 Nov 2023 00:26:13 -0800 Subject: [PATCH 05/22] re-organize method place --- pyiceberg/table/__init__.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/pyiceberg/table/__init__.py b/pyiceberg/table/__init__.py index 1c99302713..5a25edb283 100644 --- a/pyiceberg/table/__init__.py +++ b/pyiceberg/table/__init__.py @@ -549,6 +549,14 @@ def _(update: SetSnapshotRefUpdate, base_metadata: TableMetadata, context: Table return TableMetadataUtil.parse_obj(update_metadata_data) +def update_table_metadata(base_metadata: TableMetadata, updates: Tuple[TableUpdate, ...]) -> TableMetadata: + context = TableMetadataUpdateContext() + new_metadata = base_metadata + for update in updates: + new_metadata = apply_table_update(update, new_metadata, context) + return new_metadata + + class TableRequirement(IcebergBaseModel): type: str @@ -680,14 +688,6 @@ class CommitTableResponse(IcebergBaseModel): metadata_location: str = Field(alias="metadata-location") -def update_table_metadata(base_metadata: TableMetadata, updates: Tuple[TableUpdate, ...]) -> TableMetadata: - context = TableMetadataUpdateContext() - new_metadata = base_metadata - for update in updates: - new_metadata = apply_table_update(update, new_metadata, context) - return new_metadata - - class Table: identifier: Identifier = Field() metadata: TableMetadata From facb43b89815abca72d2ba0417861082104630fb Mon Sep 17 00:00:00 2001 From: HonahX Date: Sat, 11 Nov 2023 00:27:53 -0800 Subject: [PATCH 06/22] fix nit --- pyiceberg/table/__init__.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/pyiceberg/table/__init__.py b/pyiceberg/table/__init__.py index 5a25edb283..69e55e9774 100644 --- a/pyiceberg/table/__init__.py +++ b/pyiceberg/table/__init__.py @@ -406,13 +406,13 @@ def _(update: UpgradeFormatVersionUpdate, base_metadata: TableMetadata, context: def _(update: AddSchemaUpdate, base_metadata: TableMetadata, context: TableMetadataUpdateContext) -> TableMetadata: def reuse_or_create_new_schema_id(new_schema: Schema) -> Tuple[int, bool]: # if the schema already exists, use its id; otherwise use the highest id + 1 - new_schema_id = base_metadata.current_schema_id + result_schema_id = base_metadata.current_schema_id for schema in base_metadata.schemas: if schema == new_schema: return schema.schema_id, True - elif schema.schema_id >= new_schema_id: - new_schema_id = schema.schema_id + 1 - return new_schema_id, False + elif schema.schema_id >= result_schema_id: + result_schema_id = schema.schema_id + 1 + return result_schema_id, False if update.last_column_id < base_metadata.last_column_id: raise ValueError(f"Invalid last column id {update.last_column_id}, must be >= {base_metadata.last_column_id}") From 116c6fdadbdc4becc7acd4e5c23079ba5afa007a Mon Sep 17 00:00:00 2001 From: HonahX Date: Sat, 11 Nov 2023 10:27:48 -0800 Subject: [PATCH 07/22] fix test --- pyiceberg/table/__init__.py | 6 +++--- tests/table/test_init.py | 13 ++++++++++--- 2 files changed, 13 insertions(+), 6 deletions(-) diff --git a/pyiceberg/table/__init__.py b/pyiceberg/table/__init__.py index 69e55e9774..833116beeb 100644 --- a/pyiceberg/table/__init__.py +++ b/pyiceberg/table/__init__.py @@ -325,7 +325,7 @@ class SetSnapshotRefUpdate(TableUpdate): ref_name: str = Field(alias="ref-name") type: Literal["tag", "branch"] snapshot_id: int = Field(alias="snapshot-id") - max_age_ref_ms: int = Field(alias="max-ref-age-ms") + max_ref_age_ms: int = Field(alias="max-ref-age-ms") max_snapshot_age_ms: int = Field(alias="max-snapshot-age-ms") min_snapshots_to_keep: int = Field(alias="min-snapshots-to-keep") @@ -508,14 +508,14 @@ def _(update: SetSnapshotRefUpdate, base_metadata: TableMetadata, context: Table raise ValueError("Tags do not support setting maxSnapshotAgeMs") if update.max_snapshot_age_ms is not None and update.max_snapshot_age_ms <= 0: raise ValueError("Max snapshot age must be > 0 ms") - if update.max_age_ref_ms is not None and update.max_age_ref_ms <= 0: + if update.max_ref_age_ms is not None and update.max_ref_age_ms <= 0: raise ValueError("Max ref age must be > 0 ms") snapshot_ref = SnapshotRef( snapshot_id=update.snapshot_id, snapshot_ref_type=update.type, min_snapshots_to_keep=update.min_snapshots_to_keep, max_snapshot_age_ms=update.max_snapshot_age_ms, - max_ref_age_ms=update.max_age_ref_ms, + max_ref_age_ms=update.max_ref_age_ms, ) existing_ref = base_metadata.refs.get(update.ref_name) if existing_ref is not None and existing_ref == snapshot_ref: diff --git a/tests/table/test_init.py b/tests/table/test_init.py index 7699fe1a6e..98f39408b6 100644 --- a/tests/table/test_init.py +++ b/tests/table/test_init.py @@ -45,7 +45,7 @@ UpdateSchema, _generate_snapshot_id, _match_deletes_to_datafile, - update_table_metadata, + update_table_metadata, SnapshotRef, ) from pyiceberg.table.metadata import INITIAL_SEQUENCE_NUMBER from pyiceberg.table.snapshots import ( @@ -547,7 +547,7 @@ def test_update_metadata_add_snapshot(table: Table) -> None: new_metadata = update_table_metadata(table.metadata, (AddSnapshotUpdate(snapshot=new_snapshot),)) assert len(new_metadata.snapshots) == 3 - assert new_metadata.snapshots[2] == new_snapshot + assert new_metadata.snapshots[-1] == new_snapshot assert new_metadata.last_sequence_number == new_snapshot.sequence_number assert new_metadata.last_updated_ms == new_snapshot.timestamp_ms @@ -557,7 +557,7 @@ def test_update_metadata_set_snapshot_ref(table: Table) -> None: ref_name="main", type="branch", snapshot_id=3051729675574597004, - max_age_ref_ms=123123123, + max_ref_age_ms=123123123, max_snapshot_age_ms=12312312312, min_snapshots_to_keep=1, ) @@ -569,6 +569,13 @@ def test_update_metadata_set_snapshot_ref(table: Table) -> None: ) assert new_metadata.current_snapshot_id == 3051729675574597004 assert new_metadata.last_updated_ms == table.metadata.last_updated_ms + assert new_metadata.refs[update.ref_name] == SnapshotRef( + snapshot_id=3051729675574597004, + snapshot_ref_type="branch", + min_snapshots_to_keep=1, + max_snapshot_age_ms=12312312312, + max_ref_age_ms=123123123, + ) def test_generate_snapshot_id(table: Table) -> None: From 66a4f46352b70421fecfceaf46f720e2c4b8c938 Mon Sep 17 00:00:00 2001 From: HonahX Date: Sat, 11 Nov 2023 17:25:45 -0800 Subject: [PATCH 08/22] add another test --- pyiceberg/table/__init__.py | 57 +++++++++++++++++++++++-------------- pyiceberg/table/metadata.py | 8 ++++++ tests/table/test_init.py | 32 ++++++++++++++++++++- 3 files changed, 74 insertions(+), 23 deletions(-) diff --git a/pyiceberg/table/__init__.py b/pyiceberg/table/__init__.py index 833116beeb..743479f709 100644 --- a/pyiceberg/table/__init__.py +++ b/pyiceberg/table/__init__.py @@ -386,26 +386,31 @@ def apply_table_update(update: TableUpdate, base_metadata: TableMetadata, contex @apply_table_update.register(UpgradeFormatVersionUpdate) def _(update: UpgradeFormatVersionUpdate, base_metadata: TableMetadata, context: TableMetadataUpdateContext) -> TableMetadata: - current_format_version = base_metadata.format_version if update.format_version > SUPPORTED_TABLE_FORMAT_VERSION: raise ValueError(f"Unsupported table format version: {update.format_version}") - if update.format_version < current_format_version: - raise ValueError(f"Cannot downgrade v{current_format_version} table to v{update.format_version}") - if update.format_version == current_format_version: - return base_metadata - if current_format_version == 1 and update.format_version == 2: - updated_metadata_data = copy(base_metadata.model_dump()) - updated_metadata_data["format-version"] = update.format_version - return TableMetadataUtil.parse_obj(updated_metadata_data) + if update.format_version < base_metadata.format_version: + raise ValueError(f"Cannot downgrade v{base_metadata.format_version} table to v{update.format_version}") - raise ValueError(f"Cannot upgrade v{current_format_version} table to v{update.format_version}") + if update.format_version == base_metadata.format_version: + return base_metadata + + updated_metadata_data = copy(base_metadata.model_dump()) + updated_metadata_data["format-version"] = update.format_version + return TableMetadataUtil.parse_obj(updated_metadata_data) @apply_table_update.register(AddSchemaUpdate) def _(update: AddSchemaUpdate, base_metadata: TableMetadata, context: TableMetadataUpdateContext) -> TableMetadata: def reuse_or_create_new_schema_id(new_schema: Schema) -> Tuple[int, bool]: - # if the schema already exists, use its id; otherwise use the highest id + 1 + """Reuse schema id if schema already exists, otherwise create a new one. + + Args: + new_schema: The new schema to be added. + + Returns: + The new schema id and whether the schema already exists. + """ result_schema_id = base_metadata.current_schema_id for schema in base_metadata.schemas: if schema == new_schema: @@ -416,6 +421,7 @@ def reuse_or_create_new_schema_id(new_schema: Schema) -> Tuple[int, bool]: if update.last_column_id < base_metadata.last_column_id: raise ValueError(f"Invalid last column id {update.last_column_id}, must be >= {base_metadata.last_column_id}") + new_schema_id, schema_found = reuse_or_create_new_schema_id(update.schema_) if schema_found and update.last_column_id == base_metadata.last_column_id: if context.last_added_schema_id is not None and context.is_added_schema(new_schema_id): @@ -428,7 +434,6 @@ def reuse_or_create_new_schema_id(new_schema: Schema) -> Tuple[int, bool]: new_schema = ( update.schema_ if new_schema_id == update.schema_.schema_id - # TODO: double check the parameter passing here, schema.fields may be interpreted as the **data fileds else Schema(*update.schema_.fields, schema_id=new_schema_id, identifier_field_ids=update.schema_.identifier_field_ids) ) @@ -450,7 +455,7 @@ def _(update: SetCurrentSchemaUpdate, base_metadata: TableMetadata, context: Tab if update.schema_id == base_metadata.current_schema_id: return base_metadata - schema = next((schema for schema in base_metadata.schemas if schema.schema_id == update.schema_id), None) + schema = base_metadata.schema_by_id(update.schema_id) if schema is None: raise ValueError(f"Schema with id {update.schema_id} does not exist") @@ -460,6 +465,7 @@ def _(update: SetCurrentSchemaUpdate, base_metadata: TableMetadata, context: Tab updated_metadata_data = copy(base_metadata.model_dump()) updated_metadata_data["current-schema-id"] = update.schema_id + if context.last_added_schema_id is not None and context.last_added_schema_id == update.schema_id: context.updates.append(SetCurrentSchemaUpdate(schema_id=-1)) else: @@ -472,12 +478,16 @@ def _(update: SetCurrentSchemaUpdate, base_metadata: TableMetadata, context: Tab def _(update: AddSnapshotUpdate, base_metadata: TableMetadata, context: TableMetadataUpdateContext) -> TableMetadata: if len(base_metadata.schemas) == 0: raise ValueError("Attempting to add a snapshot before a schema is added") + if len(base_metadata.partition_specs) == 0: raise ValueError("Attempting to add a snapshot before a partition spec is added") + if len(base_metadata.sort_orders) == 0: raise ValueError("Attempting to add a snapshot before a sort order is added") - if any(update.snapshot.snapshot_id == snapshot.snapshot_id for snapshot in base_metadata.snapshots): + + if base_metadata.snapshot_by_id(update.snapshot.snapshot_id) is not None: raise ValueError(f"Snapshot with id {update.snapshot.snapshot_id} already exists") + if ( base_metadata.format_version == 2 and update.snapshot.sequence_number is not None @@ -500,16 +510,22 @@ def _(update: AddSnapshotUpdate, base_metadata: TableMetadata, context: TableMet def _(update: SetSnapshotRefUpdate, base_metadata: TableMetadata, context: TableMetadataUpdateContext) -> TableMetadata: if update.type is None: raise ValueError("Snapshot ref type must be set") + if update.min_snapshots_to_keep is not None and update.type == SnapshotRefType.TAG: raise ValueError("Cannot set min snapshots to keep for branch refs") + if update.min_snapshots_to_keep is not None and update.min_snapshots_to_keep <= 0: raise ValueError("Minimum snapshots to keep must be >= 0") + if update.max_snapshot_age_ms is not None and update.type == SnapshotRefType.TAG: raise ValueError("Tags do not support setting maxSnapshotAgeMs") + if update.max_snapshot_age_ms is not None and update.max_snapshot_age_ms <= 0: raise ValueError("Max snapshot age must be > 0 ms") + if update.max_ref_age_ms is not None and update.max_ref_age_ms <= 0: raise ValueError("Max ref age must be > 0 ms") + snapshot_ref = SnapshotRef( snapshot_id=update.snapshot_id, snapshot_ref_type=update.type, @@ -517,14 +533,12 @@ def _(update: SetSnapshotRefUpdate, base_metadata: TableMetadata, context: Table max_snapshot_age_ms=update.max_snapshot_age_ms, max_ref_age_ms=update.max_ref_age_ms, ) + existing_ref = base_metadata.refs.get(update.ref_name) if existing_ref is not None and existing_ref == snapshot_ref: return base_metadata - snapshot = next( - (snapshot for snapshot in base_metadata.snapshots if snapshot.snapshot_id == snapshot_ref.snapshot_id), - None, - ) + snapshot = base_metadata.snapshot_by_id(snapshot_ref.snapshot_id) if snapshot is None: raise ValueError(f"Cannot set {snapshot_ref.ref_name} to unknown snapshot {snapshot_ref.snapshot_id}") @@ -552,8 +566,10 @@ def _(update: SetSnapshotRefUpdate, base_metadata: TableMetadata, context: Table def update_table_metadata(base_metadata: TableMetadata, updates: Tuple[TableUpdate, ...]) -> TableMetadata: context = TableMetadataUpdateContext() new_metadata = base_metadata + for update in updates: new_metadata = apply_table_update(update, new_metadata, context) + return new_metadata @@ -800,10 +816,7 @@ def current_snapshot(self) -> Optional[Snapshot]: def snapshot_by_id(self, snapshot_id: int) -> Optional[Snapshot]: """Get the snapshot of this table with the given id, or None if there is no matching snapshot.""" - try: - return next(snapshot for snapshot in self.metadata.snapshots if snapshot.snapshot_id == snapshot_id) - except StopIteration: - return None + return self.metadata.snapshot_by_id(snapshot_id) # pylint: disable=W0212 def snapshot_by_name(self, name: str) -> Optional[Snapshot]: """Return the snapshot referenced by the given name or null if no such reference exists.""" diff --git a/pyiceberg/table/metadata.py b/pyiceberg/table/metadata.py index 271d40e25a..91a0dfad57 100644 --- a/pyiceberg/table/metadata.py +++ b/pyiceberg/table/metadata.py @@ -218,6 +218,14 @@ class TableMetadataCommonFields(IcebergBaseModel): There is always a main branch reference pointing to the current-snapshot-id even if the refs map is null.""" + def snapshot_by_id(self, snapshot_id: int) -> Optional[Snapshot]: + """Get the snapshot by id.""" + return next((snapshot for snapshot in self.snapshots if snapshot.snapshot_id == snapshot_id), None) + + def schema_by_id(self, schema_id: int) -> Optional[Schema]: + """Get the schema by id.""" + return next((schema for schema in self.schemas if schema.schema_id == schema_id), None) + class TableMetadataV1(TableMetadataCommonFields, IcebergBaseModel): """Represents version 1 of the Table Metadata. diff --git a/tests/table/test_init.py b/tests/table/test_init.py index 98f39408b6..e0a545cbe9 100644 --- a/tests/table/test_init.py +++ b/tests/table/test_init.py @@ -40,12 +40,15 @@ AddSnapshotUpdate, SetPropertiesUpdate, SetSnapshotRefUpdate, + SnapshotRef, StaticTable, Table, + TableMetadataUpdateContext, UpdateSchema, _generate_snapshot_id, _match_deletes_to_datafile, - update_table_metadata, SnapshotRef, + apply_table_update, + update_table_metadata, ) from pyiceberg.table.metadata import INITIAL_SEQUENCE_NUMBER from pyiceberg.table.snapshots import ( @@ -512,6 +515,33 @@ def test_add_nested_list_type_column(table: Table) -> None: assert new_schema.highest_field_id == 7 +def test_apply_add_schema_update(table: Table) -> None: + transaction = table.transaction() + update = transaction.update_schema() + update.add_column(path="b", field_type=IntegerType()) + update.commit() + + test_context = TableMetadataUpdateContext() + + new_table_metadata = apply_table_update( + transaction._updates[0], base_metadata=table.metadata, context=test_context + ) # pylint: disable=W0212 + assert len(new_table_metadata.schemas) == 3 + assert new_table_metadata.current_schema_id == 1 + assert len(test_context.updates) == 1 + assert test_context.updates[0] == transaction._updates[0] # pylint: disable=W0212 + assert test_context.last_added_schema_id == 2 + + new_table_metadata = apply_table_update( + transaction._updates[1], base_metadata=new_table_metadata, context=test_context + ) # pylint: disable=W0212 + assert len(new_table_metadata.schemas) == 3 + assert new_table_metadata.current_schema_id == 2 + assert len(test_context.updates) == 2 + assert test_context.updates[1] == transaction._updates[1] # pylint: disable=W0212 + assert test_context.last_added_schema_id == 2 + + def test_update_metadata_table_schema(table: Table) -> None: transaction = table.transaction() update = transaction.update_schema() From 2882d0db4815831ff2e40a0fbf297500635e95e8 Mon Sep 17 00:00:00 2001 From: HonahX Date: Sat, 11 Nov 2023 22:32:03 -0800 Subject: [PATCH 09/22] clear TODO --- pyiceberg/table/__init__.py | 46 ++++++++++++++++++++++++++----------- pyiceberg/table/metadata.py | 15 +++++++----- 2 files changed, 42 insertions(+), 19 deletions(-) diff --git a/pyiceberg/table/__init__.py b/pyiceberg/table/__init__.py index 743479f709..c9de95bdb3 100644 --- a/pyiceberg/table/__init__.py +++ b/pyiceberg/table/__init__.py @@ -16,6 +16,7 @@ # under the License. from __future__ import annotations +import datetime import itertools import uuid from abc import ABC, abstractmethod @@ -95,6 +96,7 @@ StructType, ) from pyiceberg.utils.concurrent import ExecutorFactory +from pyiceberg.utils.datetime import datetime_to_millis if TYPE_CHECKING: import pandas as pd @@ -381,7 +383,18 @@ def is_added_schema(self, schema_id: int) -> bool: @singledispatch def apply_table_update(update: TableUpdate, base_metadata: TableMetadata, context: TableMetadataUpdateContext) -> TableMetadata: - raise ValueError(f"Unsupported update: {update}") + """Apply a table update to the table metadata. + + Args: + update: The update to be applied. + base_metadata: The base metadata to be updated. + context: Contains previous updates, last_added_snapshot_id and other change tracking information in the current transaction. + + Returns: + The updated metadata. + + """ + raise NotImplementedError(f"Unsupported table update: {update}") @apply_table_update.register(UpgradeFormatVersionUpdate) @@ -455,14 +468,10 @@ def _(update: SetCurrentSchemaUpdate, base_metadata: TableMetadata, context: Tab if update.schema_id == base_metadata.current_schema_id: return base_metadata - schema = base_metadata.schema_by_id(update.schema_id) + schema = base_metadata.schemas_by_id.get(update.schema_id) if schema is None: raise ValueError(f"Schema with id {update.schema_id} does not exist") - # TODO: rebuild sort_order and partition_spec - # So it seems the rebuild just refresh the inner field which hold the schema and some naming check for partition_spec - # Seems this is not necessary in pyiceberg case wince - updated_metadata_data = copy(base_metadata.model_dump()) updated_metadata_data["current-schema-id"] = update.schema_id @@ -485,7 +494,7 @@ def _(update: AddSnapshotUpdate, base_metadata: TableMetadata, context: TableMet if len(base_metadata.sort_orders) == 0: raise ValueError("Attempting to add a snapshot before a sort order is added") - if base_metadata.snapshot_by_id(update.snapshot.snapshot_id) is not None: + if base_metadata.snapshots_by_id.get(update.snapshot.snapshot_id) is not None: raise ValueError(f"Snapshot with id {update.snapshot.snapshot_id} already exists") if ( @@ -495,7 +504,8 @@ def _(update: AddSnapshotUpdate, base_metadata: TableMetadata, context: TableMet and update.snapshot.parent_snapshot_id is not None ): raise ValueError( - f"Cannot add snapshot with sequence number {update.snapshot.sequence_number} older than last sequence number {base_metadata.last_sequence_number}" + f"Cannot add snapshot with sequence number {update.snapshot.sequence_number} " + f"older than last sequence number {base_metadata.last_sequence_number}" ) updated_metadata_data = copy(base_metadata.model_dump()) @@ -538,19 +548,20 @@ def _(update: SetSnapshotRefUpdate, base_metadata: TableMetadata, context: Table if existing_ref is not None and existing_ref == snapshot_ref: return base_metadata - snapshot = base_metadata.snapshot_by_id(snapshot_ref.snapshot_id) + snapshot = base_metadata.snapshots_by_id.get(snapshot_ref.snapshot_id) if snapshot is None: raise ValueError(f"Cannot set {snapshot_ref.ref_name} to unknown snapshot {snapshot_ref.snapshot_id}") update_metadata_data = copy(base_metadata.model_dump()) + update_last_updated_ms = True if context.is_added_snapshot(snapshot_ref.snapshot_id): update_metadata_data["last-updated-ms"] = snapshot.timestamp + update_last_updated_ms = False if update.ref_name == MAIN_BRANCH: update_metadata_data["current-snapshot-id"] = snapshot_ref.snapshot_id - # TODO: double-check if the default value of TableMetadata make the timestamp too early - # if base_metadata.last_updated_ms is None: - # update_metadata_data["last-updated-ms"] = datetime_to_millis(datetime.datetime.now().astimezone()) + if update_last_updated_ms: + update_metadata_data["last-updated-ms"] = datetime_to_millis(datetime.datetime.now().astimezone()) update_metadata_data["snapshot-log"].append( SnapshotLogEntry( snapshot_id=snapshot_ref.snapshot_id, @@ -564,6 +575,15 @@ def _(update: SetSnapshotRefUpdate, base_metadata: TableMetadata, context: Table def update_table_metadata(base_metadata: TableMetadata, updates: Tuple[TableUpdate, ...]) -> TableMetadata: + """Update the table metadata with the given updates in one transaction. + + Args: + base_metadata: The base metadata to be updated. + updates: The updates in one transaction. + + Returns: + The updated metadata. + """ context = TableMetadataUpdateContext() new_metadata = base_metadata @@ -816,7 +836,7 @@ def current_snapshot(self) -> Optional[Snapshot]: def snapshot_by_id(self, snapshot_id: int) -> Optional[Snapshot]: """Get the snapshot of this table with the given id, or None if there is no matching snapshot.""" - return self.metadata.snapshot_by_id(snapshot_id) # pylint: disable=W0212 + return self.metadata.snapshots_by_id.get(snapshot_id) def snapshot_by_name(self, name: str) -> Optional[Snapshot]: """Return the snapshot referenced by the given name or null if no such reference exists.""" diff --git a/pyiceberg/table/metadata.py b/pyiceberg/table/metadata.py index 91a0dfad57..cb3799bf1c 100644 --- a/pyiceberg/table/metadata.py +++ b/pyiceberg/table/metadata.py @@ -19,6 +19,7 @@ import datetime import uuid from copy import copy +from functools import cached_property from typing import ( Any, Dict, @@ -218,13 +219,15 @@ class TableMetadataCommonFields(IcebergBaseModel): There is always a main branch reference pointing to the current-snapshot-id even if the refs map is null.""" - def snapshot_by_id(self, snapshot_id: int) -> Optional[Snapshot]: - """Get the snapshot by id.""" - return next((snapshot for snapshot in self.snapshots if snapshot.snapshot_id == snapshot_id), None) + @cached_property + def snapshots_by_id(self) -> Dict[int, Snapshot]: + """Index the snapshots by snapshot_id.""" + return {snapshot.snapshot_id: snapshot for snapshot in self.snapshots} - def schema_by_id(self, schema_id: int) -> Optional[Schema]: - """Get the schema by id.""" - return next((schema for schema in self.schemas if schema.schema_id == schema_id), None) + @cached_property + def schemas_by_id(self) -> Dict[int, Schema]: + """Index the schemas by schema_id.""" + return {schema.schema_id: schema for schema in self.schemas} class TableMetadataV1(TableMetadataCommonFields, IcebergBaseModel): From 8a8d4ffed2f230c387c06e627692066deefc8072 Mon Sep 17 00:00:00 2001 From: HonahX Date: Sat, 11 Nov 2023 23:35:37 -0800 Subject: [PATCH 10/22] add a combined test --- pyiceberg/table/__init__.py | 7 +- tests/conftest.py | 42 ++++++- tests/table/test_init.py | 209 ++++++++++++++++++++++------------- tests/table/test_metadata.py | 25 ----- 4 files changed, 178 insertions(+), 105 deletions(-) diff --git a/pyiceberg/table/__init__.py b/pyiceberg/table/__init__.py index c9de95bdb3..5473561e76 100644 --- a/pyiceberg/table/__init__.py +++ b/pyiceberg/table/__init__.py @@ -365,9 +365,6 @@ def __init__(self) -> None: self.updates = [] self.last_added_schema_id = None - def get_updates_by_action(self, update_type: TableUpdateAction) -> List[TableUpdate]: - return [update for update in self.updates if update.action == update_type] - def is_added_snapshot(self, snapshot_id: int) -> bool: return any( update.snapshot.snapshot_id == snapshot_id @@ -410,6 +407,8 @@ def _(update: UpgradeFormatVersionUpdate, base_metadata: TableMetadata, context: updated_metadata_data = copy(base_metadata.model_dump()) updated_metadata_data["format-version"] = update.format_version + + context.updates.append(update) return TableMetadataUtil.parse_obj(updated_metadata_data) @@ -555,7 +554,7 @@ def _(update: SetSnapshotRefUpdate, base_metadata: TableMetadata, context: Table update_metadata_data = copy(base_metadata.model_dump()) update_last_updated_ms = True if context.is_added_snapshot(snapshot_ref.snapshot_id): - update_metadata_data["last-updated-ms"] = snapshot.timestamp + update_metadata_data["last-updated-ms"] = snapshot.timestamp_ms update_last_updated_ms = False if update.ref_name == MAIN_BRANCH: diff --git a/tests/conftest.py b/tests/conftest.py index 79c01dc747..68a10be6b7 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -73,7 +73,7 @@ from pyiceberg.schema import Accessor, Schema from pyiceberg.serializers import ToOutputFile from pyiceberg.table import FileScanTask, Table -from pyiceberg.table.metadata import TableMetadataV2 +from pyiceberg.table.metadata import TableMetadataV1, TableMetadataV2 from pyiceberg.types import ( BinaryType, BooleanType, @@ -353,6 +353,32 @@ def all_avro_types() -> Dict[str, Any]: } +EXAMPLE_TABLE_METADATA_V1 = { + "format-version": 1, + "table-uuid": "d20125c8-7284-442c-9aea-15fee620737c", + "location": "s3://bucket/test/location", + "last-updated-ms": 1602638573874, + "last-column-id": 3, + "schema": { + "type": "struct", + "fields": [ + {"id": 1, "name": "x", "required": True, "type": "long"}, + {"id": 2, "name": "y", "required": True, "type": "long", "doc": "comment"}, + {"id": 3, "name": "z", "required": True, "type": "long"}, + ], + }, + "partition-spec": [{"name": "x", "transform": "identity", "source-id": 1, "field-id": 1000}], + "properties": {}, + "current-snapshot-id": -1, + "snapshots": [{"snapshot-id": 1925, "timestamp-ms": 1602638573822}], +} + + +@pytest.fixture(scope="session") +def example_table_metadata_v1() -> Dict[str, Any]: + return EXAMPLE_TABLE_METADATA_V1 + + EXAMPLE_TABLE_METADATA_V2 = { "format-version": 2, "table-uuid": "9c12d441-03fe-4693-9a96-a0705ddf69c1", @@ -1651,7 +1677,19 @@ def example_task(data_file: str) -> FileScanTask: @pytest.fixture -def table(example_table_metadata_v2: Dict[str, Any]) -> Table: +def table_v1(example_table_metadata_v1: Dict[str, Any]) -> Table: + table_metadata = TableMetadataV1(**example_table_metadata_v1) + return Table( + identifier=("database", "table"), + metadata=table_metadata, + metadata_location=f"{table_metadata.location}/uuid.metadata.json", + io=load_file_io(), + catalog=NoopCatalog("NoopCatalog"), + ) + + +@pytest.fixture +def table_v2(example_table_metadata_v2: Dict[str, Any]) -> Table: table_metadata = TableMetadataV2(**example_table_metadata_v2) return Table( identifier=("database", "table"), diff --git a/tests/table/test_init.py b/tests/table/test_init.py index e0a545cbe9..e9522108eb 100644 --- a/tests/table/test_init.py +++ b/tests/table/test_init.py @@ -85,8 +85,8 @@ ) -def test_schema(table: Table) -> None: - assert table.schema() == Schema( +def test_schema(table_v2: Table) -> None: + assert table_v2.schema() == Schema( NestedField(field_id=1, name="x", field_type=LongType(), required=True), NestedField(field_id=2, name="y", field_type=LongType(), required=True, doc="comment"), NestedField(field_id=3, name="z", field_type=LongType(), required=True), @@ -95,8 +95,8 @@ def test_schema(table: Table) -> None: ) -def test_schemas(table: Table) -> None: - assert table.schemas() == { +def test_schemas(table_v2: Table) -> None: + assert table_v2.schemas() == { 0: Schema( NestedField(field_id=1, name="x", field_type=LongType(), required=True), schema_id=0, @@ -112,20 +112,20 @@ def test_schemas(table: Table) -> None: } -def test_spec(table: Table) -> None: - assert table.spec() == PartitionSpec( +def test_spec(table_v2: Table) -> None: + assert table_v2.spec() == PartitionSpec( PartitionField(source_id=1, field_id=1000, transform=IdentityTransform(), name="x"), spec_id=0 ) -def test_specs(table: Table) -> None: - assert table.specs() == { +def test_specs(table_v2: Table) -> None: + assert table_v2.specs() == { 0: PartitionSpec(PartitionField(source_id=1, field_id=1000, transform=IdentityTransform(), name="x"), spec_id=0) } -def test_sort_order(table: Table) -> None: - assert table.sort_order() == SortOrder( +def test_sort_order(table_v2: Table) -> None: + assert table_v2.sort_order() == SortOrder( SortField(source_id=2, transform=IdentityTransform(), direction=SortDirection.ASC, null_order=NullOrder.NULLS_FIRST), SortField( source_id=3, @@ -137,8 +137,8 @@ def test_sort_order(table: Table) -> None: ) -def test_sort_orders(table: Table) -> None: - assert table.sort_orders() == { +def test_sort_orders(table_v2: Table) -> None: + assert table_v2.sort_orders() == { 3: SortOrder( SortField(source_id=2, transform=IdentityTransform(), direction=SortDirection.ASC, null_order=NullOrder.NULLS_FIRST), SortField( @@ -152,12 +152,12 @@ def test_sort_orders(table: Table) -> None: } -def test_location(table: Table) -> None: - assert table.location() == "s3://bucket/test/location" +def test_location(table_v2: Table) -> None: + assert table_v2.location() == "s3://bucket/test/location" -def test_current_snapshot(table: Table) -> None: - assert table.current_snapshot() == Snapshot( +def test_current_snapshot(table_v2: Table) -> None: + assert table_v2.current_snapshot() == Snapshot( snapshot_id=3055729675574597004, parent_snapshot_id=3051729675574597004, sequence_number=1, @@ -168,8 +168,8 @@ def test_current_snapshot(table: Table) -> None: ) -def test_snapshot_by_id(table: Table) -> None: - assert table.snapshot_by_id(3055729675574597004) == Snapshot( +def test_snapshot_by_id(table_v2: Table) -> None: + assert table_v2.snapshot_by_id(3055729675574597004) == Snapshot( snapshot_id=3055729675574597004, parent_snapshot_id=3051729675574597004, sequence_number=1, @@ -180,12 +180,12 @@ def test_snapshot_by_id(table: Table) -> None: ) -def test_snapshot_by_id_does_not_exist(table: Table) -> None: - assert table.snapshot_by_id(-1) is None +def test_snapshot_by_id_does_not_exist(table_v2: Table) -> None: + assert table_v2.snapshot_by_id(-1) is None -def test_snapshot_by_name(table: Table) -> None: - assert table.snapshot_by_name("test") == Snapshot( +def test_snapshot_by_name(table_v2: Table) -> None: + assert table_v2.snapshot_by_name("test") == Snapshot( snapshot_id=3051729675574597004, parent_snapshot_id=None, sequence_number=0, @@ -196,11 +196,11 @@ def test_snapshot_by_name(table: Table) -> None: ) -def test_snapshot_by_name_does_not_exist(table: Table) -> None: - assert table.snapshot_by_name("doesnotexist") is None +def test_snapshot_by_name_does_not_exist(table_v2: Table) -> None: + assert table_v2.snapshot_by_name("doesnotexist") is None -def test_repr(table: Table) -> None: +def test_repr(table_v2: Table) -> None: expected = """table( 1: x: required long, 2: y: required long (comment), @@ -209,37 +209,37 @@ def test_repr(table: Table) -> None: partition by: [x], sort order: [2 ASC NULLS FIRST, bucket[4](3) DESC NULLS LAST], snapshot: Operation.APPEND: id=3055729675574597004, parent_id=3051729675574597004, schema_id=1""" - assert repr(table) == expected + assert repr(table_v2) == expected -def test_history(table: Table) -> None: - assert table.history() == [ +def test_history(table_v2: Table) -> None: + assert table_v2.history() == [ SnapshotLogEntry(snapshot_id=3051729675574597004, timestamp_ms=1515100955770), SnapshotLogEntry(snapshot_id=3055729675574597004, timestamp_ms=1555100955770), ] -def test_table_scan_select(table: Table) -> None: - scan = table.scan() +def test_table_scan_select(table_v2: Table) -> None: + scan = table_v2.scan() assert scan.selected_fields == ("*",) assert scan.select("a", "b").selected_fields == ("a", "b") assert scan.select("a", "c").select("a").selected_fields == ("a",) -def test_table_scan_row_filter(table: Table) -> None: - scan = table.scan() +def test_table_scan_row_filter(table_v2: Table) -> None: + scan = table_v2.scan() assert scan.row_filter == AlwaysTrue() assert scan.filter(EqualTo("x", 10)).row_filter == EqualTo("x", 10) assert scan.filter(EqualTo("x", 10)).filter(In("y", (10, 11))).row_filter == And(EqualTo("x", 10), In("y", (10, 11))) -def test_table_scan_ref(table: Table) -> None: - scan = table.scan() +def test_table_scan_ref(table_v2: Table) -> None: + scan = table_v2.scan() assert scan.use_ref("test").snapshot_id == 3051729675574597004 -def test_table_scan_ref_does_not_exists(table: Table) -> None: - scan = table.scan() +def test_table_scan_ref_does_not_exists(table_v2: Table) -> None: + scan = table_v2.scan() with pytest.raises(ValueError) as exc_info: _ = scan.use_ref("boom") @@ -247,8 +247,8 @@ def test_table_scan_ref_does_not_exists(table: Table) -> None: assert "Cannot scan unknown ref=boom" in str(exc_info.value) -def test_table_scan_projection_full_schema(table: Table) -> None: - scan = table.scan() +def test_table_scan_projection_full_schema(table_v2: Table) -> None: + scan = table_v2.scan() assert scan.select("x", "y", "z").projection() == Schema( NestedField(field_id=1, name="x", field_type=LongType(), required=True), NestedField(field_id=2, name="y", field_type=LongType(), required=True, doc="comment"), @@ -258,8 +258,8 @@ def test_table_scan_projection_full_schema(table: Table) -> None: ) -def test_table_scan_projection_single_column(table: Table) -> None: - scan = table.scan() +def test_table_scan_projection_single_column(table_v2: Table) -> None: + scan = table_v2.scan() assert scan.select("y").projection() == Schema( NestedField(field_id=2, name="y", field_type=LongType(), required=True, doc="comment"), schema_id=1, @@ -267,8 +267,8 @@ def test_table_scan_projection_single_column(table: Table) -> None: ) -def test_table_scan_projection_single_column_case_sensitive(table: Table) -> None: - scan = table.scan() +def test_table_scan_projection_single_column_case_sensitive(table_v2: Table) -> None: + scan = table_v2.scan() assert scan.with_case_sensitive(False).select("Y").projection() == Schema( NestedField(field_id=2, name="y", field_type=LongType(), required=True, doc="comment"), schema_id=1, @@ -276,8 +276,8 @@ def test_table_scan_projection_single_column_case_sensitive(table: Table) -> Non ) -def test_table_scan_projection_unknown_column(table: Table) -> None: - scan = table.scan() +def test_table_scan_projection_unknown_column(table_v2: Table) -> None: + scan = table_v2.scan() with pytest.raises(ValueError) as exc_info: _ = scan.select("a").projection() @@ -285,16 +285,16 @@ def test_table_scan_projection_unknown_column(table: Table) -> None: assert "Could not find column: 'a'" in str(exc_info.value) -def test_static_table_same_as_table(table: Table, metadata_location: str) -> None: +def test_static_table_same_as_table(table_v2: Table, metadata_location: str) -> None: static_table = StaticTable.from_metadata(metadata_location) assert isinstance(static_table, Table) - assert static_table.metadata == table.metadata + assert static_table.metadata == table_v2.metadata -def test_static_table_gz_same_as_table(table: Table, metadata_location_gz: str) -> None: +def test_static_table_gz_same_as_table(table_v2: Table, metadata_location_gz: str) -> None: static_table = StaticTable.from_metadata(metadata_location_gz) assert isinstance(static_table, Table) - assert static_table.metadata == table.metadata + assert static_table.metadata == table_v2.metadata def test_static_table_io_does_not_exist(metadata_location: str) -> None: @@ -415,8 +415,8 @@ def test_serialize_set_properties_updates() -> None: assert SetPropertiesUpdate(updates={"abc": "🤪"}).model_dump_json() == """{"action":"set-properties","updates":{"abc":"🤪"}}""" -def test_add_column(table: Table) -> None: - update = UpdateSchema(table) +def test_add_column(table_v2: Table) -> None: + update = UpdateSchema(table_v2) update.add_column(path="b", field_type=IntegerType()) apply_schema: Schema = update._apply() # pylint: disable=W0212 assert len(apply_schema.fields) == 4 @@ -432,7 +432,7 @@ def test_add_column(table: Table) -> None: assert apply_schema.highest_field_id == 4 -def test_add_primitive_type_column(table: Table) -> None: +def test_add_primitive_type_column(table_v2: Table) -> None: primitive_type: Dict[str, PrimitiveType] = { "boolean": BooleanType(), "int": IntegerType(), @@ -450,7 +450,7 @@ def test_add_primitive_type_column(table: Table) -> None: for name, type_ in primitive_type.items(): field_name = f"new_column_{name}" - update = UpdateSchema(table) + update = UpdateSchema(table_v2) update.add_column(path=field_name, field_type=type_, doc=f"new_column_{name}") new_schema = update._apply() # pylint: disable=W0212 @@ -459,10 +459,10 @@ def test_add_primitive_type_column(table: Table) -> None: assert field.doc == f"new_column_{name}" -def test_add_nested_type_column(table: Table) -> None: +def test_add_nested_type_column(table_v2: Table) -> None: # add struct type column field_name = "new_column_struct" - update = UpdateSchema(table) + update = UpdateSchema(table_v2) struct_ = StructType( NestedField(1, "lat", DoubleType()), NestedField(2, "long", DoubleType()), @@ -477,10 +477,10 @@ def test_add_nested_type_column(table: Table) -> None: assert schema_.highest_field_id == 6 -def test_add_nested_map_type_column(table: Table) -> None: +def test_add_nested_map_type_column(table_v2: Table) -> None: # add map type column field_name = "new_column_map" - update = UpdateSchema(table) + update = UpdateSchema(table_v2) map_ = MapType(1, StringType(), 2, IntegerType(), False) update.add_column(path=field_name, field_type=map_) new_schema = update._apply() # pylint: disable=W0212 @@ -489,10 +489,10 @@ def test_add_nested_map_type_column(table: Table) -> None: assert new_schema.highest_field_id == 6 -def test_add_nested_list_type_column(table: Table) -> None: +def test_add_nested_list_type_column(table_v2: Table) -> None: # add list type column field_name = "new_column_list" - update = UpdateSchema(table) + update = UpdateSchema(table_v2) list_ = ListType( element_id=101, element_type=StructType( @@ -515,8 +515,8 @@ def test_add_nested_list_type_column(table: Table) -> None: assert new_schema.highest_field_id == 7 -def test_apply_add_schema_update(table: Table) -> None: - transaction = table.transaction() +def test_apply_add_schema_update(table_v2: Table) -> None: + transaction = table_v2.transaction() update = transaction.update_schema() update.add_column(path="b", field_type=IntegerType()) update.commit() @@ -524,7 +524,7 @@ def test_apply_add_schema_update(table: Table) -> None: test_context = TableMetadataUpdateContext() new_table_metadata = apply_table_update( - transaction._updates[0], base_metadata=table.metadata, context=test_context + transaction._updates[0], base_metadata=table_v2.metadata, context=test_context ) # pylint: disable=W0212 assert len(new_table_metadata.schemas) == 3 assert new_table_metadata.current_schema_id == 1 @@ -542,12 +542,12 @@ def test_apply_add_schema_update(table: Table) -> None: assert test_context.last_added_schema_id == 2 -def test_update_metadata_table_schema(table: Table) -> None: - transaction = table.transaction() +def test_update_metadata_table_schema(table_v2: Table) -> None: + transaction = table_v2.transaction() update = transaction.update_schema() update.add_column(path="b", field_type=IntegerType()) update.commit() - new_metadata = update_table_metadata(table.metadata, transaction._updates) # pylint: disable=W0212 + new_metadata = update_table_metadata(table_v2.metadata, transaction._updates) # pylint: disable=W0212 apply_schema: Schema = next(schema for schema in new_metadata.schemas if schema.schema_id == 2) assert len(apply_schema.fields) == 4 @@ -564,7 +564,7 @@ def test_update_metadata_table_schema(table: Table) -> None: assert new_metadata.current_schema_id == 2 -def test_update_metadata_add_snapshot(table: Table) -> None: +def test_update_metadata_add_snapshot(table_v2: Table) -> None: new_snapshot = Snapshot( snapshot_id=25, parent_snapshot_id=19, @@ -575,14 +575,14 @@ def test_update_metadata_add_snapshot(table: Table) -> None: schema_id=3, ) - new_metadata = update_table_metadata(table.metadata, (AddSnapshotUpdate(snapshot=new_snapshot),)) + new_metadata = update_table_metadata(table_v2.metadata, (AddSnapshotUpdate(snapshot=new_snapshot),)) assert len(new_metadata.snapshots) == 3 assert new_metadata.snapshots[-1] == new_snapshot assert new_metadata.last_sequence_number == new_snapshot.sequence_number assert new_metadata.last_updated_ms == new_snapshot.timestamp_ms -def test_update_metadata_set_snapshot_ref(table: Table) -> None: +def test_update_metadata_set_snapshot_ref(table_v2: Table) -> None: update = SetSnapshotRefUpdate( ref_name="main", type="branch", @@ -592,13 +592,11 @@ def test_update_metadata_set_snapshot_ref(table: Table) -> None: min_snapshots_to_keep=1, ) - new_metadata = update_table_metadata(table.metadata, (update,)) + new_metadata = update_table_metadata(table_v2.metadata, (update,)) assert len(new_metadata.snapshot_log) == 3 - assert new_metadata.snapshot_log[2] == SnapshotLogEntry( - snapshot_id=3051729675574597004, timestamp_ms=table.metadata.last_updated_ms - ) + assert new_metadata.snapshot_log[2].snapshot_id == 3051729675574597004 assert new_metadata.current_snapshot_id == 3051729675574597004 - assert new_metadata.last_updated_ms == table.metadata.last_updated_ms + assert new_metadata.last_updated_ms > table_v2.metadata.last_updated_ms assert new_metadata.refs[update.ref_name] == SnapshotRef( snapshot_id=3051729675574597004, snapshot_ref_type="branch", @@ -608,6 +606,69 @@ def test_update_metadata_set_snapshot_ref(table: Table) -> None: ) -def test_generate_snapshot_id(table: Table) -> None: +def test_update_metadata_with_multiple_updates(table_v1: Table) -> None: + base_metadata = table_v1.metadata + transaction = table_v1.transaction() + transaction.upgrade_table_version(format_version=2) + + schema_update_1 = transaction.update_schema() + schema_update_1.add_column(path="b", field_type=IntegerType()) + schema_update_1.commit() + + test_updates = transaction._updates # pylint: disable=W0212 + + new_snapshot = Snapshot( + snapshot_id=25, + parent_snapshot_id=19, + sequence_number=200, + timestamp_ms=1602638573590, + manifest_list="s3:/a/b/c.avro", + summary=Summary(Operation.APPEND), + schema_id=3, + ) + + test_updates += ( + AddSnapshotUpdate(snapshot=new_snapshot), + SetSnapshotRefUpdate( + ref_name="main", + type="branch", + snapshot_id=25, + max_ref_age_ms=123123123, + max_snapshot_age_ms=12312312312, + min_snapshots_to_keep=1, + ), + ) + + new_metadata = update_table_metadata(base_metadata, test_updates) + + # UpgradeFormatVersionUpdate + assert new_metadata.format_version == 2 + + # UpdateSchema + assert len(new_metadata.schemas) == 2 + assert new_metadata.current_schema_id == 1 + assert new_metadata.schemas_by_id[new_metadata.current_schema_id].highest_field_id == 4 + + # AddSchemaUpdate + assert len(new_metadata.snapshots) == 2 + assert new_metadata.snapshots[-1] == new_snapshot + assert new_metadata.last_sequence_number == new_snapshot.sequence_number + assert new_metadata.last_updated_ms == new_snapshot.timestamp_ms + + # SetSnapshotRefUpdate + assert len(new_metadata.snapshot_log) == 1 + assert new_metadata.snapshot_log[0].snapshot_id == 25 + assert new_metadata.current_snapshot_id == 25 + assert new_metadata.last_updated_ms == 1602638573590 + assert new_metadata.refs["main"] == SnapshotRef( + snapshot_id=25, + snapshot_ref_type="branch", + min_snapshots_to_keep=1, + max_snapshot_age_ms=12312312312, + max_ref_age_ms=123123123, + ) + + +def test_generate_snapshot_id(table_v2: Table) -> None: assert isinstance(_generate_snapshot_id(), int) - assert isinstance(table.new_snapshot_id(), int) + assert isinstance(table_v2.new_snapshot_id(), int) diff --git a/tests/table/test_metadata.py b/tests/table/test_metadata.py index 2273843645..9f35dcbe8d 100644 --- a/tests/table/test_metadata.py +++ b/tests/table/test_metadata.py @@ -51,31 +51,6 @@ StructType, ) -EXAMPLE_TABLE_METADATA_V1 = { - "format-version": 1, - "table-uuid": "d20125c8-7284-442c-9aea-15fee620737c", - "location": "s3://bucket/test/location", - "last-updated-ms": 1602638573874, - "last-column-id": 3, - "schema": { - "type": "struct", - "fields": [ - {"id": 1, "name": "x", "required": True, "type": "long"}, - {"id": 2, "name": "y", "required": True, "type": "long", "doc": "comment"}, - {"id": 3, "name": "z", "required": True, "type": "long"}, - ], - }, - "partition-spec": [{"name": "x", "transform": "identity", "source-id": 1, "field-id": 1000}], - "properties": {}, - "current-snapshot-id": -1, - "snapshots": [{"snapshot-id": 1925, "timestamp-ms": 1602638573822}], -} - - -@pytest.fixture(scope="session") -def example_table_metadata_v1() -> Dict[str, Any]: - return EXAMPLE_TABLE_METADATA_V1 - def test_from_dict_v1(example_table_metadata_v1: Dict[str, Any]) -> None: """Test initialization of a TableMetadata instance from a dictionary""" From 1cfe9d20b7a189f39d5d566f4290e46b2ca888d5 Mon Sep 17 00:00:00 2001 From: HonahX Date: Sat, 11 Nov 2023 23:44:10 -0800 Subject: [PATCH 11/22] Fix merge conflict --- pyiceberg/table/__init__.py | 2 -- tests/conftest.py | 1 - 2 files changed, 3 deletions(-) diff --git a/pyiceberg/table/__init__.py b/pyiceberg/table/__init__.py index 6f659d05d1..f81b8e836c 100644 --- a/pyiceberg/table/__init__.py +++ b/pyiceberg/table/__init__.py @@ -77,8 +77,6 @@ TableMetadataUtil, ) from pyiceberg.table.refs import MAIN_BRANCH, SnapshotRef, SnapshotRefType -from pyiceberg.table.metadata import INITIAL_SEQUENCE_NUMBER, TableMetadata -from pyiceberg.table.refs import SnapshotRef from pyiceberg.table.snapshots import Snapshot, SnapshotLogEntry from pyiceberg.table.sorting import SortOrder from pyiceberg.typedef import ( diff --git a/tests/conftest.py b/tests/conftest.py index 3bb608fd53..11b371586c 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -74,7 +74,6 @@ from pyiceberg.serializers import ToOutputFile from pyiceberg.table import FileScanTask, Table from pyiceberg.table.metadata import TableMetadataV1, TableMetadataV2 -from pyiceberg.table.metadata import TableMetadataV2 from pyiceberg.typedef import UTF8 from pyiceberg.types import ( BinaryType, From 8476d9b5e47026f921904f18a00f2eb4b787d96b Mon Sep 17 00:00:00 2001 From: HonahX Date: Sat, 11 Nov 2023 23:49:55 -0800 Subject: [PATCH 12/22] remove table requirement validation for PR simplification --- pyiceberg/table/__init__.py | 41 ------------------------------------- 1 file changed, 41 deletions(-) diff --git a/pyiceberg/table/__init__.py b/pyiceberg/table/__init__.py index f81b8e836c..6d2a5125e5 100644 --- a/pyiceberg/table/__init__.py +++ b/pyiceberg/table/__init__.py @@ -595,21 +595,12 @@ def update_table_metadata(base_metadata: TableMetadata, updates: Tuple[TableUpda class TableRequirement(IcebergBaseModel): type: str - @abstractmethod - def validate(self, base_metadata: TableMetadata) -> None: - """Validate the requirement against the base metadata.""" - ... - class AssertCreate(TableRequirement): """The table must not already exist; used for create transactions.""" type: Literal["assert-create"] = Field(default="assert-create") - def validate(self, base_metadata: TableMetadata) -> None: - if base_metadata is not None: - raise ValueError("Table already exists") - class AssertTableUUID(TableRequirement): """The table UUID must match the requirement's `uuid`.""" @@ -617,10 +608,6 @@ class AssertTableUUID(TableRequirement): type: Literal["assert-table-uuid"] = Field(default="assert-table-uuid") uuid: str - def validate(self, base_metadata: TableMetadata) -> None: - if self.uuid != base_metadata.uuid: - raise ValueError(f"Table UUID does not match: {self.uuid} != {base_metadata.uuid}") - class AssertRefSnapshotId(TableRequirement): """The table branch or tag identified by the requirement's `ref` must reference the requirement's `snapshot-id`. @@ -632,19 +619,6 @@ class AssertRefSnapshotId(TableRequirement): ref: str snapshot_id: Optional[int] = Field(default=None, alias="snapshot-id") - def validate(self, base_metadata: TableMetadata) -> None: - snapshot_ref = base_metadata.refs.get(self.ref) - if snapshot_ref is not None: - ref_type = snapshot_ref.snapshot_ref_type - if self.snapshot_id is None: - raise ValueError(f"Requirement failed: {self.ref_tpe} {self.ref} was created concurrently") - elif self.snapshot_id != snapshot_ref.snapshot_id: - raise ValueError( - f"Requirement failed: {ref_type} {self.ref} has changed: expected id {self.snapshot_id}, found {snapshot_ref.snapshot_id}" - ) - elif self.snapshot_id is not None: - raise ValueError(f"Requirement failed: branch or tag {self.ref} is missing, expected {self.snapshot_id}") - class AssertLastAssignedFieldId(TableRequirement): """The table's last assigned column id must match the requirement's `last-assigned-field-id`.""" @@ -652,9 +626,6 @@ class AssertLastAssignedFieldId(TableRequirement): type: Literal["assert-last-assigned-field-id"] = Field(default="assert-last-assigned-field-id") last_assigned_field_id: int = Field(..., alias="last-assigned-field-id") - def validate(self, base_metadata: TableMetadata) -> None: - raise NotImplementedError("Not yet implemented") - class AssertCurrentSchemaId(TableRequirement): """The table's current schema id must match the requirement's `current-schema-id`.""" @@ -662,9 +633,6 @@ class AssertCurrentSchemaId(TableRequirement): type: Literal["assert-current-schema-id"] = Field(default="assert-current-schema-id") current_schema_id: int = Field(..., alias="current-schema-id") - def validate(self, base_metadata: TableMetadata) -> None: - raise NotImplementedError("Not yet implemented") - class AssertLastAssignedPartitionId(TableRequirement): """The table's last assigned partition id must match the requirement's `last-assigned-partition-id`.""" @@ -672,9 +640,6 @@ class AssertLastAssignedPartitionId(TableRequirement): type: Literal["assert-last-assigned-partition-id"] = Field(default="assert-last-assigned-partition-id") last_assigned_partition_id: int = Field(..., alias="last-assigned-partition-id") - def validate(self, base_metadata: TableMetadata) -> None: - raise NotImplementedError("Not yet implemented") - class AssertDefaultSpecId(TableRequirement): """The table's default spec id must match the requirement's `default-spec-id`.""" @@ -682,9 +647,6 @@ class AssertDefaultSpecId(TableRequirement): type: Literal["assert-default-spec-id"] = Field(default="assert-default-spec-id") default_spec_id: int = Field(..., alias="default-spec-id") - def validate(self, base_metadata: TableMetadata) -> None: - raise NotImplementedError("Not yet implemented") - class AssertDefaultSortOrderId(TableRequirement): """The table's default sort order id must match the requirement's `default-sort-order-id`.""" @@ -692,9 +654,6 @@ class AssertDefaultSortOrderId(TableRequirement): type: Literal["assert-default-sort-order-id"] = Field(default="assert-default-sort-order-id") default_sort_order_id: int = Field(..., alias="default-sort-order-id") - def validate(self, base_metadata: TableMetadata) -> None: - raise NotImplementedError("Not yet implemented") - class Namespace(IcebergRootModel[List[str]]): """Reference to one or more levels of a namespace.""" From 77c198c89db38d05813bad27c1c9b4cbda235071 Mon Sep 17 00:00:00 2001 From: HonahX Date: Tue, 21 Nov 2023 21:28:47 -0800 Subject: [PATCH 13/22] make context private and solve elif issue --- pyiceberg/table/__init__.py | 66 ++++++++++++++++--------------------- tests/table/test_init.py | 12 +++---- 2 files changed, 34 insertions(+), 44 deletions(-) diff --git a/pyiceberg/table/__init__.py b/pyiceberg/table/__init__.py index 6d2a5125e5..c60b69f593 100644 --- a/pyiceberg/table/__init__.py +++ b/pyiceberg/table/__init__.py @@ -357,29 +357,29 @@ class RemovePropertiesUpdate(TableUpdate): removals: List[str] -class TableMetadataUpdateContext: - updates: List[TableUpdate] +class _TableMetadataUpdateContext: + _updates: List[TableUpdate] last_added_schema_id: Optional[int] def __init__(self) -> None: - self.updates = [] + self._updates = [] self.last_added_schema_id = None def is_added_snapshot(self, snapshot_id: int) -> bool: return any( update.snapshot.snapshot_id == snapshot_id - for update in self.updates + for update in self._updates if update.action == TableUpdateAction.add_snapshot ) def is_added_schema(self, schema_id: int) -> bool: return any( - update.schema_.schema_id == schema_id for update in self.updates if update.action == TableUpdateAction.add_schema + update.schema_.schema_id == schema_id for update in self._updates if update.action == TableUpdateAction.add_schema ) @singledispatch -def apply_table_update(update: TableUpdate, base_metadata: TableMetadata, context: TableMetadataUpdateContext) -> TableMetadata: +def apply_table_update(update: TableUpdate, base_metadata: TableMetadata, context: _TableMetadataUpdateContext) -> TableMetadata: """Apply a table update to the table metadata. Args: @@ -395,7 +395,7 @@ def apply_table_update(update: TableUpdate, base_metadata: TableMetadata, contex @apply_table_update.register(UpgradeFormatVersionUpdate) -def _(update: UpgradeFormatVersionUpdate, base_metadata: TableMetadata, context: TableMetadataUpdateContext) -> TableMetadata: +def _(update: UpgradeFormatVersionUpdate, base_metadata: TableMetadata, context: _TableMetadataUpdateContext) -> TableMetadata: if update.format_version > SUPPORTED_TABLE_FORMAT_VERSION: raise ValueError(f"Unsupported table format version: {update.format_version}") @@ -408,12 +408,12 @@ def _(update: UpgradeFormatVersionUpdate, base_metadata: TableMetadata, context: updated_metadata_data = copy(base_metadata.model_dump()) updated_metadata_data["format-version"] = update.format_version - context.updates.append(update) + context._updates.append(update) return TableMetadataUtil.parse_obj(updated_metadata_data) @apply_table_update.register(AddSchemaUpdate) -def _(update: AddSchemaUpdate, base_metadata: TableMetadata, context: TableMetadataUpdateContext) -> TableMetadata: +def _(update: AddSchemaUpdate, base_metadata: TableMetadata, context: _TableMetadataUpdateContext) -> TableMetadata: def reuse_or_create_new_schema_id(new_schema: Schema) -> Tuple[int, bool]: """Reuse schema id if schema already exists, otherwise create a new one. @@ -452,19 +452,18 @@ def reuse_or_create_new_schema_id(new_schema: Schema) -> Tuple[int, bool]: if not schema_found: updated_metadata_data["schemas"].append(new_schema.model_dump()) - context.updates.append(update) + context._updates.append(update) context.last_added_schema_id = new_schema_id return TableMetadataUtil.parse_obj(updated_metadata_data) @apply_table_update.register(SetCurrentSchemaUpdate) -def _(update: SetCurrentSchemaUpdate, base_metadata: TableMetadata, context: TableMetadataUpdateContext) -> TableMetadata: +def _(update: SetCurrentSchemaUpdate, base_metadata: TableMetadata, context: _TableMetadataUpdateContext) -> TableMetadata: if update.schema_id == -1: if context.last_added_schema_id is None: raise ValueError("Cannot set current schema to last added schema when no schema has been added") return apply_table_update(SetCurrentSchemaUpdate(schema_id=context.last_added_schema_id), base_metadata, context) - - if update.schema_id == base_metadata.current_schema_id: + elif update.schema_id == base_metadata.current_schema_id: return base_metadata schema = base_metadata.schemas_by_id.get(update.schema_id) @@ -475,28 +474,24 @@ def _(update: SetCurrentSchemaUpdate, base_metadata: TableMetadata, context: Tab updated_metadata_data["current-schema-id"] = update.schema_id if context.last_added_schema_id is not None and context.last_added_schema_id == update.schema_id: - context.updates.append(SetCurrentSchemaUpdate(schema_id=-1)) + context._updates.append(SetCurrentSchemaUpdate(schema_id=-1)) else: - context.updates.append(update) + context._updates.append(update) return TableMetadataUtil.parse_obj(updated_metadata_data) @apply_table_update.register(AddSnapshotUpdate) -def _(update: AddSnapshotUpdate, base_metadata: TableMetadata, context: TableMetadataUpdateContext) -> TableMetadata: +def _(update: AddSnapshotUpdate, base_metadata: TableMetadata, context: _TableMetadataUpdateContext) -> TableMetadata: if len(base_metadata.schemas) == 0: raise ValueError("Attempting to add a snapshot before a schema is added") - - if len(base_metadata.partition_specs) == 0: + elif len(base_metadata.partition_specs) == 0: raise ValueError("Attempting to add a snapshot before a partition spec is added") - - if len(base_metadata.sort_orders) == 0: + elif len(base_metadata.sort_orders) == 0: raise ValueError("Attempting to add a snapshot before a sort order is added") - - if base_metadata.snapshots_by_id.get(update.snapshot.snapshot_id) is not None: + elif base_metadata.snapshots_by_id.get(update.snapshot.snapshot_id) is not None: raise ValueError(f"Snapshot with id {update.snapshot.snapshot_id} already exists") - - if ( + elif ( base_metadata.format_version == 2 and update.snapshot.sequence_number is not None and update.snapshot.sequence_number <= base_metadata.last_sequence_number @@ -511,28 +506,23 @@ def _(update: AddSnapshotUpdate, base_metadata: TableMetadata, context: TableMet updated_metadata_data["last-updated-ms"] = update.snapshot.timestamp_ms updated_metadata_data["last-sequence-number"] = update.snapshot.sequence_number updated_metadata_data["snapshots"].append(update.snapshot.model_dump()) - context.updates.append(update) + context._updates.append(update) return TableMetadataUtil.parse_obj(updated_metadata_data) @apply_table_update.register(SetSnapshotRefUpdate) -def _(update: SetSnapshotRefUpdate, base_metadata: TableMetadata, context: TableMetadataUpdateContext) -> TableMetadata: +def _(update: SetSnapshotRefUpdate, base_metadata: TableMetadata, context: _TableMetadataUpdateContext) -> TableMetadata: if update.type is None: raise ValueError("Snapshot ref type must be set") - - if update.min_snapshots_to_keep is not None and update.type == SnapshotRefType.TAG: + elif update.min_snapshots_to_keep is not None and update.type == SnapshotRefType.TAG: raise ValueError("Cannot set min snapshots to keep for branch refs") - - if update.min_snapshots_to_keep is not None and update.min_snapshots_to_keep <= 0: + elif update.min_snapshots_to_keep is not None and update.min_snapshots_to_keep <= 0: raise ValueError("Minimum snapshots to keep must be >= 0") - - if update.max_snapshot_age_ms is not None and update.type == SnapshotRefType.TAG: + elif update.max_snapshot_age_ms is not None and update.type == SnapshotRefType.TAG: raise ValueError("Tags do not support setting maxSnapshotAgeMs") - - if update.max_snapshot_age_ms is not None and update.max_snapshot_age_ms <= 0: + elif update.max_snapshot_age_ms is not None and update.max_snapshot_age_ms <= 0: raise ValueError("Max snapshot age must be > 0 ms") - - if update.max_ref_age_ms is not None and update.max_ref_age_ms <= 0: + elif update.max_ref_age_ms is not None and update.max_ref_age_ms <= 0: raise ValueError("Max ref age must be > 0 ms") snapshot_ref = SnapshotRef( @@ -569,7 +559,7 @@ def _(update: SetSnapshotRefUpdate, base_metadata: TableMetadata, context: Table ) update_metadata_data["refs"][update.ref_name] = snapshot_ref.model_dump() - context.updates.append(update) + context._updates.append(update) return TableMetadataUtil.parse_obj(update_metadata_data) @@ -583,7 +573,7 @@ def update_table_metadata(base_metadata: TableMetadata, updates: Tuple[TableUpda Returns: The updated metadata. """ - context = TableMetadataUpdateContext() + context = _TableMetadataUpdateContext() new_metadata = base_metadata for update in updates: diff --git a/tests/table/test_init.py b/tests/table/test_init.py index e9522108eb..c533a1f623 100644 --- a/tests/table/test_init.py +++ b/tests/table/test_init.py @@ -43,7 +43,7 @@ SnapshotRef, StaticTable, Table, - TableMetadataUpdateContext, + _TableMetadataUpdateContext, UpdateSchema, _generate_snapshot_id, _match_deletes_to_datafile, @@ -521,15 +521,15 @@ def test_apply_add_schema_update(table_v2: Table) -> None: update.add_column(path="b", field_type=IntegerType()) update.commit() - test_context = TableMetadataUpdateContext() + test_context = _TableMetadataUpdateContext() new_table_metadata = apply_table_update( transaction._updates[0], base_metadata=table_v2.metadata, context=test_context ) # pylint: disable=W0212 assert len(new_table_metadata.schemas) == 3 assert new_table_metadata.current_schema_id == 1 - assert len(test_context.updates) == 1 - assert test_context.updates[0] == transaction._updates[0] # pylint: disable=W0212 + assert len(test_context._updates) == 1 + assert test_context._updates[0] == transaction._updates[0] # pylint: disable=W0212 assert test_context.last_added_schema_id == 2 new_table_metadata = apply_table_update( @@ -537,8 +537,8 @@ def test_apply_add_schema_update(table_v2: Table) -> None: ) # pylint: disable=W0212 assert len(new_table_metadata.schemas) == 3 assert new_table_metadata.current_schema_id == 2 - assert len(test_context.updates) == 2 - assert test_context.updates[1] == transaction._updates[1] # pylint: disable=W0212 + assert len(test_context._updates) == 2 + assert test_context._updates[1] == transaction._updates[1] # pylint: disable=W0212 assert test_context.last_added_schema_id == 2 From be482ca88c7c514df5f6ac23d2e5893a5721bfb9 Mon Sep 17 00:00:00 2001 From: HonahX Date: Tue, 21 Nov 2023 21:31:37 -0800 Subject: [PATCH 14/22] remove private field access --- pyiceberg/table/__init__.py | 15 +++++++++------ tests/table/test_init.py | 2 +- 2 files changed, 10 insertions(+), 7 deletions(-) diff --git a/pyiceberg/table/__init__.py b/pyiceberg/table/__init__.py index c60b69f593..604277e748 100644 --- a/pyiceberg/table/__init__.py +++ b/pyiceberg/table/__init__.py @@ -365,6 +365,9 @@ def __init__(self) -> None: self._updates = [] self.last_added_schema_id = None + def add_update(self, update: TableUpdate) -> None: + self._updates.append(update) + def is_added_snapshot(self, snapshot_id: int) -> bool: return any( update.snapshot.snapshot_id == snapshot_id @@ -408,7 +411,7 @@ def _(update: UpgradeFormatVersionUpdate, base_metadata: TableMetadata, context: updated_metadata_data = copy(base_metadata.model_dump()) updated_metadata_data["format-version"] = update.format_version - context._updates.append(update) + context.add_update(update) return TableMetadataUtil.parse_obj(updated_metadata_data) @@ -452,7 +455,7 @@ def reuse_or_create_new_schema_id(new_schema: Schema) -> Tuple[int, bool]: if not schema_found: updated_metadata_data["schemas"].append(new_schema.model_dump()) - context._updates.append(update) + context.add_update(update) context.last_added_schema_id = new_schema_id return TableMetadataUtil.parse_obj(updated_metadata_data) @@ -474,9 +477,9 @@ def _(update: SetCurrentSchemaUpdate, base_metadata: TableMetadata, context: _Ta updated_metadata_data["current-schema-id"] = update.schema_id if context.last_added_schema_id is not None and context.last_added_schema_id == update.schema_id: - context._updates.append(SetCurrentSchemaUpdate(schema_id=-1)) + context.add_update(SetCurrentSchemaUpdate(schema_id=-1)) else: - context._updates.append(update) + context.add_update(update) return TableMetadataUtil.parse_obj(updated_metadata_data) @@ -506,7 +509,7 @@ def _(update: AddSnapshotUpdate, base_metadata: TableMetadata, context: _TableMe updated_metadata_data["last-updated-ms"] = update.snapshot.timestamp_ms updated_metadata_data["last-sequence-number"] = update.snapshot.sequence_number updated_metadata_data["snapshots"].append(update.snapshot.model_dump()) - context._updates.append(update) + context.add_update(update) return TableMetadataUtil.parse_obj(updated_metadata_data) @@ -559,7 +562,7 @@ def _(update: SetSnapshotRefUpdate, base_metadata: TableMetadata, context: _Tabl ) update_metadata_data["refs"][update.ref_name] = snapshot_ref.model_dump() - context._updates.append(update) + context.add_update(update) return TableMetadataUtil.parse_obj(update_metadata_data) diff --git a/tests/table/test_init.py b/tests/table/test_init.py index c533a1f623..dbe334659a 100644 --- a/tests/table/test_init.py +++ b/tests/table/test_init.py @@ -43,10 +43,10 @@ SnapshotRef, StaticTable, Table, - _TableMetadataUpdateContext, UpdateSchema, _generate_snapshot_id, _match_deletes_to_datafile, + _TableMetadataUpdateContext, apply_table_update, update_table_metadata, ) From e2b085d4f6641fc5e81aa0862e8022e01ebbd9cb Mon Sep 17 00:00:00 2001 From: HonahX Date: Tue, 21 Nov 2023 22:54:15 -0800 Subject: [PATCH 15/22] push snapshot ref validation to its builder using pydantic --- pyiceberg/table/__init__.py | 22 ++++----------- pyiceberg/table/refs.py | 22 ++++++++++++--- tests/table/test_refs.py | 54 +++++++++++++++++++++++++++++++++++++ 3 files changed, 77 insertions(+), 21 deletions(-) diff --git a/pyiceberg/table/__init__.py b/pyiceberg/table/__init__.py index 604277e748..5c094e9283 100644 --- a/pyiceberg/table/__init__.py +++ b/pyiceberg/table/__init__.py @@ -42,6 +42,7 @@ from pydantic import Field, SerializeAsAny from sortedcontainers import SortedList +from typing_extensions import Annotated from pyiceberg.exceptions import ResolveError, ValidationError from pyiceberg.expressions import ( @@ -76,7 +77,7 @@ TableMetadata, TableMetadataUtil, ) -from pyiceberg.table.refs import MAIN_BRANCH, SnapshotRef, SnapshotRefType +from pyiceberg.table.refs import MAIN_BRANCH, SnapshotRef from pyiceberg.table.snapshots import Snapshot, SnapshotLogEntry from pyiceberg.table.sorting import SortOrder from pyiceberg.typedef import ( @@ -327,9 +328,9 @@ class SetSnapshotRefUpdate(TableUpdate): ref_name: str = Field(alias="ref-name") type: Literal["tag", "branch"] snapshot_id: int = Field(alias="snapshot-id") - max_ref_age_ms: int = Field(alias="max-ref-age-ms") - max_snapshot_age_ms: int = Field(alias="max-snapshot-age-ms") - min_snapshots_to_keep: int = Field(alias="min-snapshots-to-keep") + max_ref_age_ms: Annotated[Optional[int], Field(alias="max-ref-age-ms", default=None)] + max_snapshot_age_ms: Annotated[Optional[int], Field(alias="max-snapshot-age-ms", default=None)] + min_snapshots_to_keep: Annotated[Optional[int], Field(alias="min-snapshots-to-keep", default=None)] class RemoveSnapshotsUpdate(TableUpdate): @@ -515,19 +516,6 @@ def _(update: AddSnapshotUpdate, base_metadata: TableMetadata, context: _TableMe @apply_table_update.register(SetSnapshotRefUpdate) def _(update: SetSnapshotRefUpdate, base_metadata: TableMetadata, context: _TableMetadataUpdateContext) -> TableMetadata: - if update.type is None: - raise ValueError("Snapshot ref type must be set") - elif update.min_snapshots_to_keep is not None and update.type == SnapshotRefType.TAG: - raise ValueError("Cannot set min snapshots to keep for branch refs") - elif update.min_snapshots_to_keep is not None and update.min_snapshots_to_keep <= 0: - raise ValueError("Minimum snapshots to keep must be >= 0") - elif update.max_snapshot_age_ms is not None and update.type == SnapshotRefType.TAG: - raise ValueError("Tags do not support setting maxSnapshotAgeMs") - elif update.max_snapshot_age_ms is not None and update.max_snapshot_age_ms <= 0: - raise ValueError("Max snapshot age must be > 0 ms") - elif update.max_ref_age_ms is not None and update.max_ref_age_ms <= 0: - raise ValueError("Max ref age must be > 0 ms") - snapshot_ref = SnapshotRef( snapshot_id=update.snapshot_id, snapshot_ref_type=update.type, diff --git a/pyiceberg/table/refs.py b/pyiceberg/table/refs.py index b9692ca975..6f17880cac 100644 --- a/pyiceberg/table/refs.py +++ b/pyiceberg/table/refs.py @@ -17,8 +17,10 @@ from enum import Enum from typing import Optional -from pydantic import Field +from pydantic import Field, model_validator +from typing_extensions import Annotated +from pyiceberg.exceptions import ValidationError from pyiceberg.typedef import IcebergBaseModel MAIN_BRANCH = "main" @@ -36,6 +38,18 @@ def __repr__(self) -> str: class SnapshotRef(IcebergBaseModel): snapshot_id: int = Field(alias="snapshot-id") snapshot_ref_type: SnapshotRefType = Field(alias="type") - min_snapshots_to_keep: Optional[int] = Field(alias="min-snapshots-to-keep", default=None) - max_snapshot_age_ms: Optional[int] = Field(alias="max-snapshot-age-ms", default=None) - max_ref_age_ms: Optional[int] = Field(alias="max-ref-age-ms", default=None) + min_snapshots_to_keep: Annotated[Optional[int], Field(alias="min-snapshots-to-keep", default=None, gt=0)] + max_snapshot_age_ms: Annotated[Optional[int], Field(alias="max-snapshot-age-ms", default=None, gt=0)] + max_ref_age_ms: Annotated[Optional[int], Field(alias="max-ref-age-ms", default=None, gt=0)] + + @model_validator(mode='after') + def check_min_snapshots_to_keep(self) -> 'SnapshotRef': + if self.min_snapshots_to_keep is not None and self.snapshot_ref_type == SnapshotRefType.TAG: + raise ValidationError("Tags do not support setting minSnapshotsToKeep") + return self + + @model_validator(mode='after') + def check_max_snapshot_age_ms(self) -> 'SnapshotRef': + if self.max_snapshot_age_ms is not None and self.snapshot_ref_type == SnapshotRefType.TAG: + raise ValidationError("Tags do not support setting maxSnapshotAgeMs") + return self diff --git a/tests/table/test_refs.py b/tests/table/test_refs.py index d106f0237a..e6b7006a99 100644 --- a/tests/table/test_refs.py +++ b/tests/table/test_refs.py @@ -15,6 +15,10 @@ # specific language governing permissions and limitations # under the License. # pylint:disable=eval-used +import pytest +from pydantic import ValidationError + +from pyiceberg import exceptions from pyiceberg.table.refs import SnapshotRef, SnapshotRefType @@ -32,3 +36,53 @@ def test_snapshot_with_properties_repr() -> None: == """SnapshotRef(snapshot_id=3051729675574597004, snapshot_ref_type=SnapshotRefType.TAG, min_snapshots_to_keep=None, max_snapshot_age_ms=None, max_ref_age_ms=10000000)""" ) assert snapshot_ref == eval(repr(snapshot_ref)) + + +def test_snapshot_with_invalid_field() -> None: + # min_snapshots_to_keep, if present, must be greater than 0 + with pytest.raises(ValidationError): + SnapshotRef( + snapshot_id=3051729675574597004, + snapshot_ref_type=SnapshotRefType.TAG, + min_snapshots_to_keep=-1, + max_snapshot_age_ms=None, + max_ref_age_ms=10000000, + ) + + # max_snapshot_age_ms, if present, must be greater than 0 + with pytest.raises(ValidationError): + SnapshotRef( + snapshot_id=3051729675574597004, + snapshot_ref_type=SnapshotRefType.TAG, + min_snapshots_to_keep=1, + max_snapshot_age_ms=-1, + max_ref_age_ms=10000000, + ) + + # max_ref_age_ms, if present, must be greater than 0 + with pytest.raises(ValidationError): + SnapshotRef( + snapshot_id=3051729675574597004, + snapshot_ref_type=SnapshotRefType.TAG, + min_snapshots_to_keep=None, + max_snapshot_age_ms=None, + max_ref_age_ms=-1, + ) + + with pytest.raises(exceptions.ValidationError, match="Tags do not support setting minSnapshotsToKeep"): + SnapshotRef( + snapshot_id=3051729675574597004, + snapshot_ref_type=SnapshotRefType.TAG, + min_snapshots_to_keep=1, + max_snapshot_age_ms=None, + max_ref_age_ms=10000000, + ) + + with pytest.raises(exceptions.ValidationError, match="Tags do not support setting maxSnapshotAgeMs"): + SnapshotRef( + snapshot_id=3051729675574597004, + snapshot_ref_type=SnapshotRefType.TAG, + min_snapshots_to_keep=None, + max_snapshot_age_ms=1, + max_ref_age_ms=100000, + ) From 965b16daa3ec71cbd12f4b02abf16c6859602ad7 Mon Sep 17 00:00:00 2001 From: HonahX Date: Tue, 21 Nov 2023 23:24:04 -0800 Subject: [PATCH 16/22] fix comment --- pyiceberg/table/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyiceberg/table/__init__.py b/pyiceberg/table/__init__.py index 5c094e9283..c3fc256f6a 100644 --- a/pyiceberg/table/__init__.py +++ b/pyiceberg/table/__init__.py @@ -562,7 +562,7 @@ def update_table_metadata(base_metadata: TableMetadata, updates: Tuple[TableUpda updates: The updates in one transaction. Returns: - The updated metadata. + The metadata with the updates applied. """ context = _TableMetadataUpdateContext() new_metadata = base_metadata From 53efa28831a85663f34b1b6768d3079b6fc1fb7a Mon Sep 17 00:00:00 2001 From: HonahX Date: Tue, 21 Nov 2023 23:46:04 -0800 Subject: [PATCH 17/22] remove unnecessary code for AddSchemaUpdate update --- pyiceberg/table/__init__.py | 40 ++----------------------------------- 1 file changed, 2 insertions(+), 38 deletions(-) diff --git a/pyiceberg/table/__init__.py b/pyiceberg/table/__init__.py index c3fc256f6a..2740e40691 100644 --- a/pyiceberg/table/__init__.py +++ b/pyiceberg/table/__init__.py @@ -376,11 +376,6 @@ def is_added_snapshot(self, snapshot_id: int) -> bool: if update.action == TableUpdateAction.add_snapshot ) - def is_added_schema(self, schema_id: int) -> bool: - return any( - update.schema_.schema_id == schema_id for update in self._updates if update.action == TableUpdateAction.add_schema - ) - @singledispatch def apply_table_update(update: TableUpdate, base_metadata: TableMetadata, context: _TableMetadataUpdateContext) -> TableMetadata: @@ -418,46 +413,15 @@ def _(update: UpgradeFormatVersionUpdate, base_metadata: TableMetadata, context: @apply_table_update.register(AddSchemaUpdate) def _(update: AddSchemaUpdate, base_metadata: TableMetadata, context: _TableMetadataUpdateContext) -> TableMetadata: - def reuse_or_create_new_schema_id(new_schema: Schema) -> Tuple[int, bool]: - """Reuse schema id if schema already exists, otherwise create a new one. - - Args: - new_schema: The new schema to be added. - - Returns: - The new schema id and whether the schema already exists. - """ - result_schema_id = base_metadata.current_schema_id - for schema in base_metadata.schemas: - if schema == new_schema: - return schema.schema_id, True - elif schema.schema_id >= result_schema_id: - result_schema_id = schema.schema_id + 1 - return result_schema_id, False - if update.last_column_id < base_metadata.last_column_id: raise ValueError(f"Invalid last column id {update.last_column_id}, must be >= {base_metadata.last_column_id}") - new_schema_id, schema_found = reuse_or_create_new_schema_id(update.schema_) - if schema_found and update.last_column_id == base_metadata.last_column_id: - if context.last_added_schema_id is not None and context.is_added_schema(new_schema_id): - context.last_added_schema_id = new_schema_id - return base_metadata - updated_metadata_data = copy(base_metadata.model_dump()) updated_metadata_data["last-column-id"] = update.last_column_id - - new_schema = ( - update.schema_ - if new_schema_id == update.schema_.schema_id - else Schema(*update.schema_.fields, schema_id=new_schema_id, identifier_field_ids=update.schema_.identifier_field_ids) - ) - - if not schema_found: - updated_metadata_data["schemas"].append(new_schema.model_dump()) + updated_metadata_data["schemas"].append(update.schema_.model_dump()) context.add_update(update) - context.last_added_schema_id = new_schema_id + context.last_added_schema_id = update.schema_.schema_id return TableMetadataUtil.parse_obj(updated_metadata_data) From b7fd063c9801f22de2d5b84be4836c8faff66b3d Mon Sep 17 00:00:00 2001 From: HonahX Date: Tue, 21 Nov 2023 23:52:38 -0800 Subject: [PATCH 18/22] replace if with elif --- pyiceberg/table/__init__.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/pyiceberg/table/__init__.py b/pyiceberg/table/__init__.py index 2740e40691..cc99eeb0b3 100644 --- a/pyiceberg/table/__init__.py +++ b/pyiceberg/table/__init__.py @@ -397,11 +397,9 @@ def apply_table_update(update: TableUpdate, base_metadata: TableMetadata, contex def _(update: UpgradeFormatVersionUpdate, base_metadata: TableMetadata, context: _TableMetadataUpdateContext) -> TableMetadata: if update.format_version > SUPPORTED_TABLE_FORMAT_VERSION: raise ValueError(f"Unsupported table format version: {update.format_version}") - - if update.format_version < base_metadata.format_version: + elif update.format_version < base_metadata.format_version: raise ValueError(f"Cannot downgrade v{base_metadata.format_version} table to v{update.format_version}") - - if update.format_version == base_metadata.format_version: + elif update.format_version == base_metadata.format_version: return base_metadata updated_metadata_data = copy(base_metadata.model_dump()) From bedd0cc85a39a28e05bf972da219de079488fb43 Mon Sep 17 00:00:00 2001 From: HonahX Date: Sun, 3 Dec 2023 22:09:41 -0800 Subject: [PATCH 19/22] enhance the set current schema update implementation and some other changes --- pyiceberg/table/__init__.py | 34 +++++++++++++++++++--------------- pyiceberg/table/metadata.py | 15 ++++++--------- tests/table/test_init.py | 2 +- 3 files changed, 26 insertions(+), 25 deletions(-) diff --git a/pyiceberg/table/__init__.py b/pyiceberg/table/__init__.py index cc99eeb0b3..507a465076 100644 --- a/pyiceberg/table/__init__.py +++ b/pyiceberg/table/__init__.py @@ -376,6 +376,11 @@ def is_added_snapshot(self, snapshot_id: int) -> bool: if update.action == TableUpdateAction.add_snapshot ) + def is_added_schema(self, schema_id: int) -> bool: + return any( + update.schema_.schema_id == schema_id for update in self._updates if update.action == TableUpdateAction.add_schema + ) + @singledispatch def apply_table_update(update: TableUpdate, base_metadata: TableMetadata, context: _TableMetadataUpdateContext) -> TableMetadata: @@ -425,25 +430,24 @@ def _(update: AddSchemaUpdate, base_metadata: TableMetadata, context: _TableMeta @apply_table_update.register(SetCurrentSchemaUpdate) def _(update: SetCurrentSchemaUpdate, base_metadata: TableMetadata, context: _TableMetadataUpdateContext) -> TableMetadata: - if update.schema_id == -1: - if context.last_added_schema_id is None: + new_schema_id = update.schema_id + if new_schema_id == -1: + # The last added schema should be in base_metadata.schemas at this point + new_schema_id = max(schema.schema_id for schema in base_metadata.schemas) + if not context.is_added_schema(new_schema_id): raise ValueError("Cannot set current schema to last added schema when no schema has been added") - return apply_table_update(SetCurrentSchemaUpdate(schema_id=context.last_added_schema_id), base_metadata, context) - elif update.schema_id == base_metadata.current_schema_id: + + if update.schema_id == base_metadata.current_schema_id: return base_metadata - schema = base_metadata.schemas_by_id.get(update.schema_id) + schema = base_metadata.schema_by_id(new_schema_id) if schema is None: - raise ValueError(f"Schema with id {update.schema_id} does not exist") + raise ValueError(f"Schema with id {new_schema_id} does not exist") updated_metadata_data = copy(base_metadata.model_dump()) - updated_metadata_data["current-schema-id"] = update.schema_id - - if context.last_added_schema_id is not None and context.last_added_schema_id == update.schema_id: - context.add_update(SetCurrentSchemaUpdate(schema_id=-1)) - else: - context.add_update(update) + updated_metadata_data["current-schema-id"] = new_schema_id + context.add_update(update) return TableMetadataUtil.parse_obj(updated_metadata_data) @@ -455,7 +459,7 @@ def _(update: AddSnapshotUpdate, base_metadata: TableMetadata, context: _TableMe raise ValueError("Attempting to add a snapshot before a partition spec is added") elif len(base_metadata.sort_orders) == 0: raise ValueError("Attempting to add a snapshot before a sort order is added") - elif base_metadata.snapshots_by_id.get(update.snapshot.snapshot_id) is not None: + elif base_metadata.snapshot_by_id(update.snapshot.snapshot_id) is not None: raise ValueError(f"Snapshot with id {update.snapshot.snapshot_id} already exists") elif ( base_metadata.format_version == 2 @@ -490,7 +494,7 @@ def _(update: SetSnapshotRefUpdate, base_metadata: TableMetadata, context: _Tabl if existing_ref is not None and existing_ref == snapshot_ref: return base_metadata - snapshot = base_metadata.snapshots_by_id.get(snapshot_ref.snapshot_id) + snapshot = base_metadata.snapshot_by_id(snapshot_ref.snapshot_id) if snapshot is None: raise ValueError(f"Cannot set {snapshot_ref.ref_name} to unknown snapshot {snapshot_ref.snapshot_id}") @@ -737,7 +741,7 @@ def current_snapshot(self) -> Optional[Snapshot]: def snapshot_by_id(self, snapshot_id: int) -> Optional[Snapshot]: """Get the snapshot of this table with the given id, or None if there is no matching snapshot.""" - return self.metadata.snapshots_by_id.get(snapshot_id) + return self.metadata.snapshot_by_id(snapshot_id) def snapshot_by_name(self, name: str) -> Optional[Snapshot]: """Return the snapshot referenced by the given name or null if no such reference exists.""" diff --git a/pyiceberg/table/metadata.py b/pyiceberg/table/metadata.py index cb3799bf1c..43e29c7b03 100644 --- a/pyiceberg/table/metadata.py +++ b/pyiceberg/table/metadata.py @@ -19,7 +19,6 @@ import datetime import uuid from copy import copy -from functools import cached_property from typing import ( Any, Dict, @@ -219,15 +218,13 @@ class TableMetadataCommonFields(IcebergBaseModel): There is always a main branch reference pointing to the current-snapshot-id even if the refs map is null.""" - @cached_property - def snapshots_by_id(self) -> Dict[int, Snapshot]: - """Index the snapshots by snapshot_id.""" - return {snapshot.snapshot_id: snapshot for snapshot in self.snapshots} + def snapshot_by_id(self, snapshot_id: int) -> Optional[Snapshot]: + """Get the snapshot by snapshot_id.""" + return next((snapshot for snapshot in self.snapshots if snapshot.snapshot_id == snapshot_id), None) - @cached_property - def schemas_by_id(self) -> Dict[int, Schema]: - """Index the schemas by schema_id.""" - return {schema.schema_id: schema for schema in self.schemas} + def schema_by_id(self, schema_id: int) -> Optional[Schema]: + """Get the schema by schema_id.""" + return next((schema for schema in self.schemas if schema.schema_id == schema_id), None) class TableMetadataV1(TableMetadataCommonFields, IcebergBaseModel): diff --git a/tests/table/test_init.py b/tests/table/test_init.py index dbe334659a..7154fcdf1a 100644 --- a/tests/table/test_init.py +++ b/tests/table/test_init.py @@ -647,7 +647,7 @@ def test_update_metadata_with_multiple_updates(table_v1: Table) -> None: # UpdateSchema assert len(new_metadata.schemas) == 2 assert new_metadata.current_schema_id == 1 - assert new_metadata.schemas_by_id[new_metadata.current_schema_id].highest_field_id == 4 + assert new_metadata.schema_by_id(new_metadata.current_schema_id).highest_field_id == 4 # type: ignore # AddSchemaUpdate assert len(new_metadata.snapshots) == 2 From aecc7c10ce9b95acca78c832c6f24d7100558d0f Mon Sep 17 00:00:00 2001 From: HonahX Date: Sun, 3 Dec 2023 22:29:03 -0800 Subject: [PATCH 20/22] make apply_table_update private --- pyiceberg/table/__init__.py | 14 +++++++------- tests/table/test_init.py | 6 +++--- 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/pyiceberg/table/__init__.py b/pyiceberg/table/__init__.py index 28b7538c2c..07605252b9 100644 --- a/pyiceberg/table/__init__.py +++ b/pyiceberg/table/__init__.py @@ -383,7 +383,7 @@ def is_added_schema(self, schema_id: int) -> bool: @singledispatch -def apply_table_update(update: TableUpdate, base_metadata: TableMetadata, context: _TableMetadataUpdateContext) -> TableMetadata: +def _apply_table_update(update: TableUpdate, base_metadata: TableMetadata, context: _TableMetadataUpdateContext) -> TableMetadata: """Apply a table update to the table metadata. Args: @@ -398,7 +398,7 @@ def apply_table_update(update: TableUpdate, base_metadata: TableMetadata, contex raise NotImplementedError(f"Unsupported table update: {update}") -@apply_table_update.register(UpgradeFormatVersionUpdate) +@_apply_table_update.register(UpgradeFormatVersionUpdate) def _(update: UpgradeFormatVersionUpdate, base_metadata: TableMetadata, context: _TableMetadataUpdateContext) -> TableMetadata: if update.format_version > SUPPORTED_TABLE_FORMAT_VERSION: raise ValueError(f"Unsupported table format version: {update.format_version}") @@ -414,7 +414,7 @@ def _(update: UpgradeFormatVersionUpdate, base_metadata: TableMetadata, context: return TableMetadataUtil.parse_obj(updated_metadata_data) -@apply_table_update.register(AddSchemaUpdate) +@_apply_table_update.register(AddSchemaUpdate) def _(update: AddSchemaUpdate, base_metadata: TableMetadata, context: _TableMetadataUpdateContext) -> TableMetadata: if update.last_column_id < base_metadata.last_column_id: raise ValueError(f"Invalid last column id {update.last_column_id}, must be >= {base_metadata.last_column_id}") @@ -428,7 +428,7 @@ def _(update: AddSchemaUpdate, base_metadata: TableMetadata, context: _TableMeta return TableMetadataUtil.parse_obj(updated_metadata_data) -@apply_table_update.register(SetCurrentSchemaUpdate) +@_apply_table_update.register(SetCurrentSchemaUpdate) def _(update: SetCurrentSchemaUpdate, base_metadata: TableMetadata, context: _TableMetadataUpdateContext) -> TableMetadata: new_schema_id = update.schema_id if new_schema_id == -1: @@ -451,7 +451,7 @@ def _(update: SetCurrentSchemaUpdate, base_metadata: TableMetadata, context: _Ta return TableMetadataUtil.parse_obj(updated_metadata_data) -@apply_table_update.register(AddSnapshotUpdate) +@_apply_table_update.register(AddSnapshotUpdate) def _(update: AddSnapshotUpdate, base_metadata: TableMetadata, context: _TableMetadataUpdateContext) -> TableMetadata: if len(base_metadata.schemas) == 0: raise ValueError("Attempting to add a snapshot before a schema is added") @@ -480,7 +480,7 @@ def _(update: AddSnapshotUpdate, base_metadata: TableMetadata, context: _TableMe return TableMetadataUtil.parse_obj(updated_metadata_data) -@apply_table_update.register(SetSnapshotRefUpdate) +@_apply_table_update.register(SetSnapshotRefUpdate) def _(update: SetSnapshotRefUpdate, base_metadata: TableMetadata, context: _TableMetadataUpdateContext) -> TableMetadata: snapshot_ref = SnapshotRef( snapshot_id=update.snapshot_id, @@ -534,7 +534,7 @@ def update_table_metadata(base_metadata: TableMetadata, updates: Tuple[TableUpda new_metadata = base_metadata for update in updates: - new_metadata = apply_table_update(update, new_metadata, context) + new_metadata = _apply_table_update(update, new_metadata, context) return new_metadata diff --git a/tests/table/test_init.py b/tests/table/test_init.py index 7154fcdf1a..939e75e879 100644 --- a/tests/table/test_init.py +++ b/tests/table/test_init.py @@ -44,10 +44,10 @@ StaticTable, Table, UpdateSchema, + _apply_table_update, _generate_snapshot_id, _match_deletes_to_datafile, _TableMetadataUpdateContext, - apply_table_update, update_table_metadata, ) from pyiceberg.table.metadata import INITIAL_SEQUENCE_NUMBER @@ -523,7 +523,7 @@ def test_apply_add_schema_update(table_v2: Table) -> None: test_context = _TableMetadataUpdateContext() - new_table_metadata = apply_table_update( + new_table_metadata = _apply_table_update( transaction._updates[0], base_metadata=table_v2.metadata, context=test_context ) # pylint: disable=W0212 assert len(new_table_metadata.schemas) == 3 @@ -532,7 +532,7 @@ def test_apply_add_schema_update(table_v2: Table) -> None: assert test_context._updates[0] == transaction._updates[0] # pylint: disable=W0212 assert test_context.last_added_schema_id == 2 - new_table_metadata = apply_table_update( + new_table_metadata = _apply_table_update( transaction._updates[1], base_metadata=new_table_metadata, context=test_context ) # pylint: disable=W0212 assert len(new_table_metadata.schemas) == 3 From 18aced5885d515c77188bbaf8520b31d49090a6e Mon Sep 17 00:00:00 2001 From: HonahX Date: Sun, 3 Dec 2023 22:51:43 -0800 Subject: [PATCH 21/22] fix an error --- pyiceberg/table/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyiceberg/table/__init__.py b/pyiceberg/table/__init__.py index 07605252b9..1ce268d1c8 100644 --- a/pyiceberg/table/__init__.py +++ b/pyiceberg/table/__init__.py @@ -437,7 +437,7 @@ def _(update: SetCurrentSchemaUpdate, base_metadata: TableMetadata, context: _Ta if not context.is_added_schema(new_schema_id): raise ValueError("Cannot set current schema to last added schema when no schema has been added") - if update.schema_id == base_metadata.current_schema_id: + if new_schema_id == base_metadata.current_schema_id: return base_metadata schema = base_metadata.schema_by_id(new_schema_id) From 325eefe0b0c8d7a2b800d185a6c4d937cd0da661 Mon Sep 17 00:00:00 2001 From: HonahX Date: Mon, 4 Dec 2023 08:56:20 -0800 Subject: [PATCH 22/22] remove unnecessary last_added_schema_id --- pyiceberg/table/__init__.py | 5 +---- tests/table/test_init.py | 4 ++-- 2 files changed, 3 insertions(+), 6 deletions(-) diff --git a/pyiceberg/table/__init__.py b/pyiceberg/table/__init__.py index 1ce268d1c8..9aa6c1c9c5 100644 --- a/pyiceberg/table/__init__.py +++ b/pyiceberg/table/__init__.py @@ -360,11 +360,9 @@ class RemovePropertiesUpdate(TableUpdate): class _TableMetadataUpdateContext: _updates: List[TableUpdate] - last_added_schema_id: Optional[int] def __init__(self) -> None: self._updates = [] - self.last_added_schema_id = None def add_update(self, update: TableUpdate) -> None: self._updates.append(update) @@ -389,7 +387,7 @@ def _apply_table_update(update: TableUpdate, base_metadata: TableMetadata, conte Args: update: The update to be applied. base_metadata: The base metadata to be updated. - context: Contains previous updates, last_added_snapshot_id and other change tracking information in the current transaction. + context: Contains previous updates and other change tracking information in the current transaction. Returns: The updated metadata. @@ -424,7 +422,6 @@ def _(update: AddSchemaUpdate, base_metadata: TableMetadata, context: _TableMeta updated_metadata_data["schemas"].append(update.schema_.model_dump()) context.add_update(update) - context.last_added_schema_id = update.schema_.schema_id return TableMetadataUtil.parse_obj(updated_metadata_data) diff --git a/tests/table/test_init.py b/tests/table/test_init.py index 939e75e879..6d188befeb 100644 --- a/tests/table/test_init.py +++ b/tests/table/test_init.py @@ -530,7 +530,7 @@ def test_apply_add_schema_update(table_v2: Table) -> None: assert new_table_metadata.current_schema_id == 1 assert len(test_context._updates) == 1 assert test_context._updates[0] == transaction._updates[0] # pylint: disable=W0212 - assert test_context.last_added_schema_id == 2 + assert test_context.is_added_schema(2) new_table_metadata = _apply_table_update( transaction._updates[1], base_metadata=new_table_metadata, context=test_context @@ -539,7 +539,7 @@ def test_apply_add_schema_update(table_v2: Table) -> None: assert new_table_metadata.current_schema_id == 2 assert len(test_context._updates) == 2 assert test_context._updates[1] == transaction._updates[1] # pylint: disable=W0212 - assert test_context.last_added_schema_id == 2 + assert test_context.is_added_schema(2) def test_update_metadata_table_schema(table_v2: Table) -> None: