diff --git a/pyiceberg/expressions/__init__.py b/pyiceberg/expressions/__init__.py
index 5adf3a8a48..830637aa99 100644
--- a/pyiceberg/expressions/__init__.py
+++ b/pyiceberg/expressions/__init__.py
@@ -135,6 +135,10 @@ def __repr__(self) -> str:
     def ref(self) -> BoundReference[L]:
         return self
 
+    def __hash__(self) -> int:
+        """Return hash value of the BoundReference class."""
+        return hash(str(self))
+
 
 class UnboundTerm(Term[Any], Unbound[BoundTerm[L]], ABC):
     """Represents an unbound term."""
diff --git a/pyiceberg/io/pyarrow.py b/pyiceberg/io/pyarrow.py
index 2451bf7df7..f3b85eb499 100644
--- a/pyiceberg/io/pyarrow.py
+++ b/pyiceberg/io/pyarrow.py
@@ -73,11 +73,7 @@
 
 from pyiceberg.conversions import to_bytes
 from pyiceberg.exceptions import ResolveError
-from pyiceberg.expressions import (
-    AlwaysTrue,
-    BooleanExpression,
-    BoundTerm,
-)
+from pyiceberg.expressions import AlwaysTrue, BooleanExpression, BoundIsNaN, BoundIsNull, BoundTerm, Not, Or
 from pyiceberg.expressions.literals import Literal
 from pyiceberg.expressions.visitors import (
     BoundBooleanExpressionVisitor,
@@ -576,11 +572,11 @@ def _convert_scalar(value: Any, iceberg_type: IcebergType) -> pa.scalar:
 
 
 class _ConvertToArrowExpression(BoundBooleanExpressionVisitor[pc.Expression]):
-    def visit_in(self, term: BoundTerm[pc.Expression], literals: Set[Any]) -> pc.Expression:
+    def visit_in(self, term: BoundTerm[Any], literals: Set[Any]) -> pc.Expression:
         pyarrow_literals = pa.array(literals, type=schema_to_pyarrow(term.ref().field.field_type))
         return pc.field(term.ref().field.name).isin(pyarrow_literals)
 
-    def visit_not_in(self, term: BoundTerm[pc.Expression], literals: Set[Any]) -> pc.Expression:
+    def visit_not_in(self, term: BoundTerm[Any], literals: Set[Any]) -> pc.Expression:
         pyarrow_literals = pa.array(literals, type=schema_to_pyarrow(term.ref().field.field_type))
         return ~pc.field(term.ref().field.name).isin(pyarrow_literals)
 
@@ -638,10 +634,152 @@ def visit_or(self, left_result: pc.Expression, right_result: pc.Expression) -> p
         return left_result | right_result
 
 
+class _NullNaNUnmentionedTermsCollector(BoundBooleanExpressionVisitor[None]):
+    # BoundTerms which have either is_null or is_not_null appearing at least once in the boolean expr.
+    is_null_or_not_bound_terms: set[BoundTerm[Any]]
+    # The remaining BoundTerms appearing in the boolean expr.
+    null_unmentioned_bound_terms: set[BoundTerm[Any]]
+    # BoundTerms which have either is_nan or is_not_nan appearing at least once in the boolean expr.
+    is_nan_or_not_bound_terms: set[BoundTerm[Any]]
+    # The remaining BoundTerms appearing in the boolean expr.
+    nan_unmentioned_bound_terms: set[BoundTerm[Any]]
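+
+    # Illustrative example: collecting over And(GreaterThan(x, 5), IsNull(y)) leaves x in
+    # both null_unmentioned_bound_terms and nan_unmentioned_bound_terms, while y lands in
+    # is_null_or_not_bound_terms and nan_unmentioned_bound_terms.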
+
+    def __init__(self) -> None:
+        super().__init__()
+        self.is_null_or_not_bound_terms = set()
+        self.null_unmentioned_bound_terms = set()
+        self.is_nan_or_not_bound_terms = set()
+        self.nan_unmentioned_bound_terms = set()
+
+    def _handle_explicit_is_null_or_not(self, term: BoundTerm[Any]) -> None:
+        """Handle the predicate case where either is_null or is_not_null is included."""
+        if term in self.null_unmentioned_bound_terms:
+            self.null_unmentioned_bound_terms.remove(term)
+        self.is_null_or_not_bound_terms.add(term)
+
+    def _handle_null_unmentioned(self, term: BoundTerm[Any]) -> None:
+        """Handle the predicate case where neither is_null nor is_not_null is included."""
+        if term not in self.is_null_or_not_bound_terms:
+            self.null_unmentioned_bound_terms.add(term)
+
+    def _handle_explicit_is_nan_or_not(self, term: BoundTerm[Any]) -> None:
+        """Handle the predicate case where either is_nan or is_not_nan is included."""
+        if term in self.nan_unmentioned_bound_terms:
+            self.nan_unmentioned_bound_terms.remove(term)
+        self.is_nan_or_not_bound_terms.add(term)
+
+    def _handle_nan_unmentioned(self, term: BoundTerm[Any]) -> None:
+        """Handle the predicate case where neither is_nan nor is_not_nan is included."""
+        if term not in self.is_nan_or_not_bound_terms:
+            self.nan_unmentioned_bound_terms.add(term)
+
+    def visit_in(self, term: BoundTerm[Any], literals: Set[Any]) -> None:
+        self._handle_null_unmentioned(term)
+        self._handle_nan_unmentioned(term)
+
+    def visit_not_in(self, term: BoundTerm[Any], literals: Set[Any]) -> None:
+        self._handle_null_unmentioned(term)
+        self._handle_nan_unmentioned(term)
+
+    def visit_is_nan(self, term: BoundTerm[Any]) -> None:
+        self._handle_null_unmentioned(term)
+        self._handle_explicit_is_nan_or_not(term)
+
+    def visit_not_nan(self, term: BoundTerm[Any]) -> None:
+        self._handle_null_unmentioned(term)
+        self._handle_explicit_is_nan_or_not(term)
+
+    def visit_is_null(self, term: BoundTerm[Any]) -> None:
+        self._handle_explicit_is_null_or_not(term)
+        self._handle_nan_unmentioned(term)
+
+    def visit_not_null(self, term: BoundTerm[Any]) -> None:
+        self._handle_explicit_is_null_or_not(term)
+        self._handle_nan_unmentioned(term)
+
+    def visit_equal(self, term: BoundTerm[Any], literal: Literal[Any]) -> None:
+        self._handle_null_unmentioned(term)
+        self._handle_nan_unmentioned(term)
+
+    def visit_not_equal(self, term: BoundTerm[Any], literal: Literal[Any]) -> None:
+        self._handle_null_unmentioned(term)
+        self._handle_nan_unmentioned(term)
+
+    def visit_greater_than_or_equal(self, term: BoundTerm[Any], literal: Literal[Any]) -> None:
+        self._handle_null_unmentioned(term)
+        self._handle_nan_unmentioned(term)
+
+    def visit_greater_than(self, term: BoundTerm[Any], literal: Literal[Any]) -> None:
+        self._handle_null_unmentioned(term)
+        self._handle_nan_unmentioned(term)
+
+    def visit_less_than(self, term: BoundTerm[Any], literal: Literal[Any]) -> None:
+        self._handle_null_unmentioned(term)
+        self._handle_nan_unmentioned(term)
+
+    def visit_less_than_or_equal(self, term: BoundTerm[Any], literal: Literal[Any]) -> None:
+        self._handle_null_unmentioned(term)
+        self._handle_nan_unmentioned(term)
+
+    def visit_starts_with(self, term: BoundTerm[Any], literal: Literal[Any]) -> None:
+        self._handle_null_unmentioned(term)
+        self._handle_nan_unmentioned(term)
+
+    def visit_not_starts_with(self, term: BoundTerm[Any], literal: Literal[Any]) -> None:
+        self._handle_null_unmentioned(term)
+        self._handle_nan_unmentioned(term)
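+
+    # The remaining visits are no-ops: AlwaysTrue/AlwaysFalse carry no terms, and and/or/not
+    # merely combine child results that were already recorded at the predicate leaves above.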
+
+    def visit_true(self) -> None:
+        return
+
+    def visit_false(self) -> None:
+        return
+
+    def visit_not(self, child_result: None) -> None:
+        return
+
+    def visit_and(self, left_result: None, right_result: None) -> None:
+        return
+
+    def visit_or(self, left_result: None, right_result: None) -> None:
+        return
+
+    def collect(
+        self,
+        expr: BooleanExpression,
+    ) -> None:
+        """Collect the bound terms of the expression, categorizing each by whether is_null/is_not_null and is_nan/is_not_nan are mentioned for it."""
+        boolean_expression_visit(expr, self)
+
+
 def expression_to_pyarrow(expr: BooleanExpression) -> pc.Expression:
     return boolean_expression_visit(expr, _ConvertToArrowExpression())
 
 
+def _expression_to_complementary_pyarrow(expr: BooleanExpression) -> pc.Expression:
+    """Complementary filter conversion function of expression_to_pyarrow.
+
+    We cannot simply use expression_to_pyarrow(Not(expr)) to achieve this complementary effect
+    because ~ in pyarrow.compute.Expression does not handle null: rows where a referenced
+    column is null (or NaN) would be dropped instead of preserved.
+    """
+    collector = _NullNaNUnmentionedTermsCollector()
+    collector.collect(expr)
+
+    # Convert the sets of terms to sorted lists so that the layout of the expression to build is deterministic.
+    null_unmentioned_bound_terms: List[BoundTerm[Any]] = sorted(
+        collector.null_unmentioned_bound_terms, key=lambda term: term.ref().field.name
+    )
+    nan_unmentioned_bound_terms: List[BoundTerm[Any]] = sorted(
+        collector.nan_unmentioned_bound_terms, key=lambda term: term.ref().field.name
+    )
+
+    preserve_expr: BooleanExpression = Not(expr)
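+
+    # Illustrative sketch (assumed input): for an expr equivalent to x > 5 on a nullable float
+    # column x, x is both null- and nan-unmentioned, so the loops below build
+    # Or(Or(Not(x > 5), IsNull(x)), IsNaN(x)), preserving the null and NaN rows.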
+
+    for term in null_unmentioned_bound_terms:
+        preserve_expr = Or(preserve_expr, BoundIsNull(term=term))
+
+    for term in nan_unmentioned_bound_terms:
+        preserve_expr = Or(preserve_expr, BoundIsNaN(term=term))
+
+    return expression_to_pyarrow(preserve_expr)
+
+
 @lru_cache
 def _get_file_format(file_format: FileFormat, **kwargs: Dict[str, Any]) -> ds.FileFormat:
     if file_format == FileFormat.PARQUET:
diff --git a/pyiceberg/table/__init__.py b/pyiceberg/table/__init__.py
index 0cbe4630e4..79af476c91 100644
--- a/pyiceberg/table/__init__.py
+++ b/pyiceberg/table/__init__.py
@@ -58,7 +58,6 @@
     And,
     BooleanExpression,
     EqualTo,
-    Not,
     Or,
     Reference,
 )
@@ -576,7 +575,11 @@ def delete(self, delete_filter: Union[str, BooleanExpression], snapshot_properti
             delete_filter: A boolean expression to delete rows from a table
             snapshot_properties: Custom properties to be added to the snapshot summary
         """
-        from pyiceberg.io.pyarrow import _dataframe_to_data_files, expression_to_pyarrow, project_table
+        from pyiceberg.io.pyarrow import (
+            _dataframe_to_data_files,
+            _expression_to_complementary_pyarrow,
+            project_table,
+        )
 
         if (
             self.table_metadata.properties.get(TableProperties.DELETE_MODE, TableProperties.DELETE_MODE_DEFAULT)
@@ -593,7 +596,7 @@ def delete(self, delete_filter: Union[str, BooleanExpression], snapshot_properti
             # Check if there are any files that require an actual rewrite of a data file
             if delete_snapshot.rewrites_needed is True:
                 bound_delete_filter = bind(self._table.schema(), delete_filter, case_sensitive=True)
-                preserve_row_filter = expression_to_pyarrow(Not(bound_delete_filter))
+                preserve_row_filter = _expression_to_complementary_pyarrow(bound_delete_filter)
 
                 files = self._scan(row_filter=delete_filter).plan_files()
diff --git a/tests/integration/test_deletes.py b/tests/integration/test_deletes.py
index d8fb01c447..c474de296c 100644
--- a/tests/integration/test_deletes.py
+++ b/tests/integration/test_deletes.py
@@ -27,7 +27,7 @@
 from pyiceberg.manifest import ManifestEntryStatus
 from pyiceberg.schema import Schema
 from pyiceberg.table.snapshots import Operation, Summary
-from pyiceberg.types import IntegerType, NestedField
+from pyiceberg.types import FloatType, IntegerType, NestedField
 
 
 def run_spark_commands(spark: SparkSession, sqls: List[str]) -> None:
@@ -105,6 +105,40 @@ def test_partitioned_table_rewrite(spark: SparkSession, session_catalog: RestCat
     assert tbl.scan().to_arrow().to_pydict() == {"number_partitioned": [11, 10], "number": [30, 30]}
 
 
+@pytest.mark.integration
+@pytest.mark.parametrize("format_version", [1, 2])
+def test_rewrite_partitioned_table_with_null(spark: SparkSession, session_catalog: RestCatalog, format_version: int) -> None:
+    identifier = "default.table_partitioned_delete"
+
+    run_spark_commands(
+        spark,
+        [
+            f"DROP TABLE IF EXISTS {identifier}",
+            f"""
+            CREATE TABLE {identifier} (
+                number_partitioned int,
+                number int
+            )
+            USING iceberg
+            PARTITIONED BY (number_partitioned)
+            TBLPROPERTIES('format-version' = {format_version})
+            """,
+            f"""
+            INSERT INTO {identifier} VALUES (10, 20), (10, 30)
+            """,
+            f"""
+            INSERT INTO {identifier} VALUES (11, 20), (11, NULL)
+            """,
+        ],
+    )
+
+    tbl = session_catalog.load_table(identifier)
+    tbl.delete(EqualTo("number", 20))
+
+    # We don't delete a whole partition, so there is only an overwrite
+    assert [snapshot.summary.operation.value for snapshot in tbl.snapshots()] == ["append", "append", "overwrite"]
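+    # The rewritten data file must keep the (11, NULL) row: negating EqualTo("number", 20) with a
+    # plain pyarrow ~ filter would silently drop it, which is the bug this change fixes.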
+    assert tbl.scan().to_arrow().to_pydict() == {"number_partitioned": [11, 10], "number": [None, 30]}
+
+
 @pytest.mark.integration
 @pytest.mark.parametrize("format_version", [1, 2])
 def test_partitioned_table_no_match(spark: SparkSession, session_catalog: RestCatalog, format_version: int) -> None:
@@ -417,3 +451,105 @@ def test_delete_truncate(session_catalog: RestCatalog) -> None:
 
     assert len(entries) == 1
     assert entries[0].status == ManifestEntryStatus.DELETED
+
+
+@pytest.mark.integration
+def test_delete_overwrite_table_with_null(session_catalog: RestCatalog) -> None:
+    arrow_schema = pa.schema([pa.field("ints", pa.int32())])
+    arrow_tbl = pa.Table.from_pylist(
+        [{"ints": 1}, {"ints": 2}, {"ints": None}],
+        schema=arrow_schema,
+    )
+
+    iceberg_schema = Schema(NestedField(1, "ints", IntegerType()))
+
+    tbl_identifier = "default.test_delete_overwrite_with_null"
+
+    try:
+        session_catalog.drop_table(tbl_identifier)
+    except NoSuchTableError:
+        pass
+
+    tbl = session_catalog.create_table(tbl_identifier, iceberg_schema)
+    tbl.append(arrow_tbl)
+
+    assert [snapshot.summary.operation for snapshot in tbl.snapshots()] == [Operation.APPEND]
+
+    arrow_tbl_overwrite = pa.Table.from_pylist(
+        [
+            {"ints": 3},
+            {"ints": 4},
+        ],
+        schema=arrow_schema,
+    )
+    tbl.overwrite(arrow_tbl_overwrite, "ints == 2")  # Should rewrite one file
+
+    assert [snapshot.summary.operation for snapshot in tbl.snapshots()] == [
+        Operation.APPEND,
+        Operation.OVERWRITE,
+        Operation.APPEND,
+    ]
+
+    assert tbl.scan().to_arrow()["ints"].to_pylist() == [3, 4, 1, None]
+
+
+@pytest.mark.integration
+def test_delete_overwrite_table_with_nan(session_catalog: RestCatalog) -> None:
+    arrow_schema = pa.schema([pa.field("floats", pa.float32())])
+
+    # Create an Arrow table with NaN values
+    data = [pa.array([1.0, float("nan"), 2.0], type=pa.float32())]
+    arrow_tbl = pa.Table.from_arrays(
+        data,
+        schema=arrow_schema,
+    )
+
+    iceberg_schema = Schema(NestedField(1, "floats", FloatType()))
+
+    tbl_identifier = "default.test_delete_overwrite_with_nan"
+
+    try:
+        session_catalog.drop_table(tbl_identifier)
+    except NoSuchTableError:
+        pass
+
+    tbl = session_catalog.create_table(tbl_identifier, iceberg_schema)
+    tbl.append(arrow_tbl)
+
+    assert [snapshot.summary.operation for snapshot in tbl.snapshots()] == [Operation.APPEND]
+
+    arrow_tbl_overwrite = pa.Table.from_pylist(
+        [
+            {"floats": 3.0},
+            {"floats": 4.0},
+        ],
+        schema=arrow_schema,
+    )
+    """
+    We want to test that _expression_to_complementary_pyarrow generates a correct complementary filter
+    for selecting the records that remain in the newly overwritten file.
+    Compared with test_delete_overwrite_table_with_null, which tests rows with null cells,
+    NaN testing faces a trickier issue:
+    a filter of (field == value) will not match NaN cells, but (field != value) will.
+    (Interestingly, neither == nor != will match null.)
+
+    This means that with a test case of floats == 2.0 (an equality predicate, as in
+    test_delete_overwrite_table_with_null), the test would pass even without the logic under test
+    in _NullNaNUnmentionedTermsCollector (a helper of _expression_to_complementary_pyarrow
+    that handles negation of the Iceberg is_null/not_null/is_nan/not_nan expressions).
+    Instead, we test with a != filter, so that its negation is ==, which exposes the issue.
+    """
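+    # IEEE 754 semantics as surfaced by pyarrow: NaN != 2.0 is true, NaN == 2.0 is false,
+    # and any comparison against null yields null, so != is the revealing predicate here.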
+    tbl.overwrite(arrow_tbl_overwrite, "floats != 2.0")  # Should rewrite one file
+
+    assert [snapshot.summary.operation for snapshot in tbl.snapshots()] == [
+        Operation.APPEND,
+        Operation.OVERWRITE,
+        Operation.APPEND,
+    ]
+
+    result = tbl.scan().to_arrow()["floats"].to_pylist()
+
+    from math import isnan
+
+    assert any(isnan(e) for e in result)
+    assert 2.0 in result
+    assert 3.0 in result
+    assert 4.0 in result
diff --git a/tests/io/test_pyarrow_visitor.py b/tests/io/test_pyarrow_visitor.py
index 897af1bbbd..f0a2a45816 100644
--- a/tests/io/test_pyarrow_visitor.py
+++ b/tests/io/test_pyarrow_visitor.py
@@ -16,21 +16,35 @@
 # under the License.
 # pylint: disable=protected-access,unused-argument,redefined-outer-name
 import re
+from typing import Any
 
 import pyarrow as pa
 import pytest
 
+from pyiceberg.expressions import (
+    And,
+    BoundEqualTo,
+    BoundGreaterThan,
+    BoundIsNaN,
+    BoundIsNull,
+    BoundReference,
+    Not,
+    Or,
+)
+from pyiceberg.expressions.literals import literal
 from pyiceberg.io.pyarrow import (
     _ConvertToArrowSchema,
     _ConvertToIceberg,
     _ConvertToIcebergWithoutIDs,
+    _expression_to_complementary_pyarrow,
     _HasIds,
+    _NullNaNUnmentionedTermsCollector,
     _pyarrow_schema_ensure_large_types,
     pyarrow_to_schema,
     schema_to_pyarrow,
     visit_pyarrow,
 )
-from pyiceberg.schema import Schema, visit
+from pyiceberg.schema import Accessor, Schema, visit
 from pyiceberg.table.name_mapping import MappedField, NameMapping
 from pyiceberg.types import (
     BinaryType,
@@ -580,3 +594,127 @@ def test_pyarrow_schema_ensure_large_types(pyarrow_schema_nested_without_ids: pa
         ),
     ])
     assert _pyarrow_schema_ensure_large_types(pyarrow_schema_nested_without_ids) == expected_schema
+
+
+@pytest.fixture
+def bound_reference_str() -> BoundReference[Any]:
+    return BoundReference(
+        field=NestedField(1, "string_field", StringType(), required=False), accessor=Accessor(position=0, inner=None)
+    )
+
+
+@pytest.fixture
+def bound_reference_float() -> BoundReference[Any]:
+    return BoundReference(
+        field=NestedField(2, "float_field", FloatType(), required=False), accessor=Accessor(position=1, inner=None)
+    )
+
+
+@pytest.fixture
+def bound_reference_double() -> BoundReference[Any]:
+    return BoundReference(
+        field=NestedField(3, "double_field", DoubleType(), required=False),
+        accessor=Accessor(position=2, inner=None),
+    )
+
+
+@pytest.fixture
+def bound_eq_str_field(bound_reference_str: BoundReference[Any]) -> BoundEqualTo[Any]:
+    return BoundEqualTo(term=bound_reference_str, literal=literal("hello"))
+
+
+@pytest.fixture
+def bound_greater_than_float_field(bound_reference_float: BoundReference[Any]) -> BoundGreaterThan[Any]:
+    return BoundGreaterThan(term=bound_reference_float, literal=literal(100))
+
+
+@pytest.fixture
+def bound_is_nan_float_field(bound_reference_float: BoundReference[Any]) -> BoundIsNaN[Any]:
+    return BoundIsNaN(bound_reference_float)
+
+
+@pytest.fixture
+def bound_eq_double_field(bound_reference_double: BoundReference[Any]) -> BoundEqualTo[Any]:
+    return BoundEqualTo(term=bound_reference_double, literal=literal(False))
+
+
+@pytest.fixture
+def bound_is_null_double_field(bound_reference_double: BoundReference[Any]) -> BoundIsNull[Any]:
+    return BoundIsNull(bound_reference_double)
+
+
+def test_collect_null_nan_unmentioned_terms(
+    bound_eq_str_field: BoundEqualTo[Any], bound_is_nan_float_field: BoundIsNaN[Any], bound_is_null_double_field: BoundIsNull[Any]
+) -> None:
+    bound_expr = And(
+        Or(And(bound_eq_str_field, bound_is_nan_float_field), bound_is_null_double_field), Not(bound_is_nan_float_field)
+    )
+    collector = _NullNaNUnmentionedTermsCollector()
+    collector.collect(bound_expr)
+    assert {t.ref().field.name for t in collector.null_unmentioned_bound_terms} == {
+        "float_field",
+        "string_field",
+    }
+    assert {t.ref().field.name for t in collector.nan_unmentioned_bound_terms} == {
+        "string_field",
+        "double_field",
+    }
+    assert {t.ref().field.name for t in collector.is_null_or_not_bound_terms} == {
+        "double_field",
+    }
+    assert {t.ref().field.name for t in collector.is_nan_or_not_bound_terms} == {"float_field"}
+
+
+def test_collect_null_nan_unmentioned_terms_with_multiple_predicates_on_the_same_term(
+    bound_eq_str_field: BoundEqualTo[Any],
+    bound_greater_than_float_field: BoundGreaterThan[Any],
+    bound_is_nan_float_field: BoundIsNaN[Any],
+    bound_eq_double_field: BoundEqualTo[Any],
+    bound_is_null_double_field: BoundIsNull[Any],
+) -> None:
+    """Test that a single term appearing in multiple places in the expression tree is categorized correctly."""
+    bound_expr = And(
+        Or(
+            And(bound_eq_str_field, bound_greater_than_float_field),
+            And(bound_is_nan_float_field, bound_eq_double_field),
+            bound_greater_than_float_field,
+        ),
+        Not(bound_is_null_double_field),
+    )
+    collector = _NullNaNUnmentionedTermsCollector()
+    collector.collect(bound_expr)
+    assert {t.ref().field.name for t in collector.null_unmentioned_bound_terms} == {
+        "float_field",
+        "string_field",
+    }
+    assert {t.ref().field.name for t in collector.nan_unmentioned_bound_terms} == {
+        "string_field",
+        "double_field",
+    }
+    assert {t.ref().field.name for t in collector.is_null_or_not_bound_terms} == {
+        "double_field",
+    }
+    assert {t.ref().field.name for t in collector.is_nan_or_not_bound_terms} == {"float_field"}
+
+
+def test_expression_to_complementary_pyarrow(
+    bound_eq_str_field: BoundEqualTo[Any],
+    bound_greater_than_float_field: BoundGreaterThan[Any],
+    bound_is_nan_float_field: BoundIsNaN[Any],
+    bound_eq_double_field: BoundEqualTo[Any],
+    bound_is_null_double_field: BoundIsNull[Any],
+) -> None:
+    bound_expr = And(
+        Or(
+            And(bound_eq_str_field, bound_greater_than_float_field),
+            And(bound_is_nan_float_field, bound_eq_double_field),
+            bound_greater_than_float_field,
+        ),
+        Not(bound_is_null_double_field),
+    )
+    result = _expression_to_complementary_pyarrow(bound_expr)
+    # Note: an is_nan predicate on a string column is automatically converted to AlwaysFalse and
+    # removed from the Or, so it does not appear in the pyarrow expression.
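+    # Expected shape: invert(<original expr>), Or'd with is_null for the null-unmentioned terms
+    # (float_field and string_field, sorted by name) and is_nan for the nan-unmentioned double_field.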
+    assert (
+        repr(result)
+        == """<pyarrow.compute.Expression (((invert((((((string_field == "hello") and (float_field > 100)) or (is_nan(float_field) and (double_field == 0))) or (float_field > 100)) and invert(is_null(double_field, {nan_is_null=false})))) or is_null(float_field, {nan_is_null=false})) or is_null(string_field, {nan_is_null=false})) or is_nan(double_field))>"""
+    )