diff --git a/pyiceberg/table/__init__.py b/pyiceberg/table/__init__.py index b0f2bcd1c4..6fbde32cc7 100644 --- a/pyiceberg/table/__init__.py +++ b/pyiceberg/table/__init__.py @@ -1412,13 +1412,19 @@ def commit(self) -> None: """Apply the pending changes and commit.""" new_schema = self._apply() - if new_schema != self._schema: - last_column_id = max(self._table.metadata.last_column_id, new_schema.highest_field_id) - updates = ( - AddSchemaUpdate(schema=new_schema, last_column_id=last_column_id), - SetCurrentSchemaUpdate(schema_id=-1), - ) + existing_schema_id = next((schema.schema_id for schema in self._table.metadata.schemas if schema == new_schema), None) + + # Check if it is different current schema ID + if existing_schema_id != self._table.schema().schema_id: requirements = (AssertCurrentSchemaId(current_schema_id=self._schema.schema_id),) + if existing_schema_id is None: + last_column_id = max(self._table.metadata.last_column_id, new_schema.highest_field_id) + updates = ( + AddSchemaUpdate(schema=new_schema, last_column_id=last_column_id), + SetCurrentSchemaUpdate(schema_id=-1), + ) + else: + updates = (SetCurrentSchemaUpdate(schema_id=existing_schema_id),) # type: ignore if self._transaction is not None: self._transaction._append_updates(*updates) # pylint: disable=W0212 diff --git a/tests/test_integration_schema.py b/tests/test_integration_schema.py index f0ccb1b0e8..d844e6d6c0 100644 --- a/tests/test_integration_schema.py +++ b/tests/test_integration_schema.py @@ -340,6 +340,34 @@ def test_no_changes_empty_commit(simple_table: Table, table_schema_simple: Schem assert simple_table.schema() == table_schema_simple +@pytest.mark.integration +def test_revert_changes(simple_table: Table, table_schema_simple: Schema) -> None: + with simple_table.update_schema() as update: + update.add_column(path="data", field_type=IntegerType(), required=False) + + with simple_table.update_schema(allow_incompatible_changes=True) as update: + update.delete_column(path="data") + + assert simple_table.schemas() == { + 0: Schema( + NestedField(field_id=1, name='foo', field_type=StringType(), required=False), + NestedField(field_id=2, name='bar', field_type=IntegerType(), required=True), + NestedField(field_id=3, name='baz', field_type=BooleanType(), required=False), + schema_id=0, + identifier_field_ids=[2], + ), + 1: Schema( + NestedField(field_id=1, name='foo', field_type=StringType(), required=False), + NestedField(field_id=2, name='bar', field_type=IntegerType(), required=True), + NestedField(field_id=3, name='baz', field_type=BooleanType(), required=False), + NestedField(field_id=4, name='data', field_type=IntegerType(), required=False), + schema_id=1, + identifier_field_ids=[2], + ), + } + assert simple_table.schema().schema_id == 0 + + @pytest.mark.integration def test_delete_field(simple_table: Table) -> None: with simple_table.update_schema() as schema_update: