From 29da39b3c870a42d488f3dec493128a0763aaf2d Mon Sep 17 00:00:00 2001 From: EnyMan Date: Mon, 19 Jan 2026 14:22:16 +0100 Subject: [PATCH 1/8] feat: Optimize upsert process with coarse match filter and vectorized comparisons --- pyiceberg/table/__init__.py | 3 +- pyiceberg/table/upsert_util.py | 148 +++++++++++++++++++++----- tests/table/test_upsert.py | 186 +++++++++++++++++++++++++++++++++ 3 files changed, 311 insertions(+), 26 deletions(-) diff --git a/pyiceberg/table/__init__.py b/pyiceberg/table/__init__.py index b30a1426e7..34a383a22b 100644 --- a/pyiceberg/table/__init__.py +++ b/pyiceberg/table/__init__.py @@ -836,7 +836,8 @@ def upsert( ) # get list of rows that exist so we don't have to load the entire target table - matched_predicate = upsert_util.create_match_filter(df, join_cols) + # Use coarse filter for initial scan - exact matching happens in get_rows_to_update() + matched_predicate = upsert_util.create_coarse_match_filter(df, join_cols) # We must use Transaction.table_metadata for the scan. This includes all uncommitted - but relevant - changes. diff --git a/pyiceberg/table/upsert_util.py b/pyiceberg/table/upsert_util.py index 6f32826eb0..66ada9e567 100644 --- a/pyiceberg/table/upsert_util.py +++ b/pyiceberg/table/upsert_util.py @@ -16,6 +16,7 @@ # under the License. import functools import operator +from typing import Union import pyarrow as pa from pyarrow import Table as pyarrow_table @@ -31,8 +32,18 @@ def create_match_filter(df: pyarrow_table, join_cols: list[str]) -> BooleanExpression: + """ + Create an Iceberg BooleanExpression filter that exactly matches rows based on join columns. + + For single-column keys, uses an efficient In() predicate. + For composite keys, creates Or(And(...), And(...), ...) for exact row matching. + This function should be used when exact matching is required (e.g., overwrite, insert filtering). 
+ """ unique_keys = df.select(join_cols).group_by(join_cols).aggregate([]) + if len(unique_keys) == 0: + return AlwaysFalse() + if len(join_cols) == 1: return In(join_cols[0], unique_keys[0].to_pylist()) else: @@ -48,17 +59,97 @@ def create_match_filter(df: pyarrow_table, join_cols: list[str]) -> BooleanExpre return Or(*filters) +def create_coarse_match_filter(df: pyarrow_table, join_cols: list[str]) -> BooleanExpression: + """ + Create a coarse Iceberg BooleanExpression filter for initial row scanning. + + For single-column keys, uses an efficient In() predicate (exact match). + For composite keys, uses In() per column as a coarse filter (AND of In() predicates), + which may return false positives but is much more efficient than exact matching. + + This function should only be used for initial scans where exact matching happens + downstream (e.g., in get_rows_to_update() via the join operation). + """ + unique_keys = df.select(join_cols).group_by(join_cols).aggregate([]) + + if len(unique_keys) == 0: + return AlwaysFalse() + + if len(join_cols) == 1: + return In(join_cols[0], unique_keys[0].to_pylist()) + else: + # For composite keys: use In() per column as a coarse filter + # This is more efficient than creating Or(And(...), And(...), ...) 
for each row + # May include false positives, but fine-grained matching happens downstream + column_filters = [] + for col in join_cols: + unique_values = pc.unique(unique_keys[col]).to_pylist() + column_filters.append(In(col, unique_values)) + return functools.reduce(operator.and_, column_filters) + + def has_duplicate_rows(df: pyarrow_table, join_cols: list[str]) -> bool: """Check for duplicate rows in a PyArrow table based on the join columns.""" return len(df.select(join_cols).group_by(join_cols).aggregate([([], "count_all")]).filter(pc.field("count_all") > 1)) > 0 +def _compare_columns_vectorized( + source_col: Union[pa.Array, pa.ChunkedArray], target_col: Union[pa.Array, pa.ChunkedArray] +) -> pa.Array: + """ + Vectorized comparison of two columns, returning a boolean array where True means values differ. + + Handles struct types recursively by comparing each nested field. + Handles null values correctly: null != non-null is True, null == null is True (no update needed). + """ + col_type = source_col.type + + if pa.types.is_struct(col_type): + # PyArrow cannot directly compare struct columns, so we recursively compare each field + diff_masks = [] + for i, field in enumerate(col_type): + src_field = pc.struct_field(source_col, [i]) + tgt_field = pc.struct_field(target_col, [i]) + field_diff = _compare_columns_vectorized(src_field, tgt_field) + diff_masks.append(field_diff) + + if not diff_masks: + # Empty struct - no fields to compare, so no differences + return pa.array([False] * len(source_col), type=pa.bool_()) + + return functools.reduce(pc.or_, diff_masks) + + elif pa.types.is_list(col_type) or pa.types.is_large_list(col_type) or pa.types.is_map(col_type): + # For list/map types, fall back to Python comparison as PyArrow doesn't support vectorized comparison + # This is still faster than the original row-by-row approach since we batch the conversion + source_py = source_col.to_pylist() + target_py = target_col.to_pylist() + return pa.array([s != t for s, 
t in zip(source_py, target_py, strict=True)], type=pa.bool_()) + + else: + # For primitive types, use vectorized not_equal + # Handle nulls: not_equal returns null when comparing with null + # We need: null vs non-null = different (True), null vs null = same (False) + diff = pc.not_equal(source_col, target_col) + source_null = pc.is_null(source_col) + target_null = pc.is_null(target_col) + + # XOR of null masks: True if exactly one is null (meaning they differ) + null_diff = pc.xor(source_null, target_null) + + # Combine: different if values differ OR exactly one is null + # Fill null comparison results with False (both non-null but comparison returned null shouldn't happen, + # but if it does, treat as no difference) + diff_filled = pc.fill_null(diff, False) + return pc.or_(diff_filled, null_diff) + + def get_rows_to_update(source_table: pa.Table, target_table: pa.Table, join_cols: list[str]) -> pa.Table: """ Return a table with rows that need to be updated in the target table based on the join columns. + Uses vectorized PyArrow operations for efficient comparison, avoiding row-by-row Python loops. The table is joined on the identifier columns, and then checked if there are any updated rows. - Those are selected and everything is renamed correctly. """ all_columns = set(source_table.column_names) join_cols_set = set(join_cols) @@ -69,13 +160,13 @@ def get_rows_to_update(source_table: pa.Table, target_table: pa.Table, join_cols raise ValueError("Target table has duplicate rows, aborting upsert") if len(target_table) == 0: - # When the target table is empty, there is nothing to update :) + # When the target table is empty, there is nothing to update + return source_table.schema.empty_table() + + if len(non_key_cols) == 0: + # No non-key columns to compare, all matched rows are "updates" but with no changes return source_table.schema.empty_table() - # We need to compare non_key_cols in Python as PyArrow - # 1. 
Cannot do a join when non-join columns have complex types - # 2. Cannot compare columns with complex types - # See: https://github.com/apache/arrow/issues/35785 SOURCE_INDEX_COLUMN_NAME = "__source_index" TARGET_INDEX_COLUMN_NAME = "__target_index" @@ -100,25 +191,32 @@ def get_rows_to_update(source_table: pa.Table, target_table: pa.Table, join_cols # Step 3: Perform an inner join to find which rows from source exist in target matching_indices = source_index.join(target_index, keys=list(join_cols_set), join_type="inner") - # Step 4: Compare all rows using Python - to_update_indices = [] - for source_idx, target_idx in zip( - matching_indices[SOURCE_INDEX_COLUMN_NAME].to_pylist(), - matching_indices[TARGET_INDEX_COLUMN_NAME].to_pylist(), - strict=True, - ): - source_row = source_table.slice(source_idx, 1) - target_row = target_table.slice(target_idx, 1) - - for key in non_key_cols: - source_val = source_row.column(key)[0].as_py() - target_val = target_row.column(key)[0].as_py() - if source_val != target_val: - to_update_indices.append(source_idx) - break - - # Step 5: Take rows from source table using the indices and cast to target schema - if to_update_indices: + if len(matching_indices) == 0: + # No matching rows found + return source_table.schema.empty_table() + + # Step 4: Take matched rows in batch (vectorized - single operation) + source_indices = matching_indices[SOURCE_INDEX_COLUMN_NAME] + target_indices = matching_indices[TARGET_INDEX_COLUMN_NAME] + + matched_source = source_table.take(source_indices) + matched_target = target_table.take(target_indices) + + # Step 5: Vectorized comparison per column + diff_masks = [] + for col in non_key_cols: + source_col = matched_source.column(col) + target_col = matched_target.column(col) + col_diff = _compare_columns_vectorized(source_col, target_col) + diff_masks.append(col_diff) + + # Step 6: Combine masks with OR (any column different = needs update) + combined_mask = functools.reduce(pc.or_, diff_masks) + + # Step 
7: Filter to get indices of rows that need updating + to_update_indices = pc.filter(source_indices, combined_mask) + + if len(to_update_indices) > 0: return source_table.take(to_update_indices) else: return source_table.schema.empty_table() diff --git a/tests/table/test_upsert.py b/tests/table/test_upsert.py index 9bc61799e4..1682af3b06 100644 --- a/tests/table/test_upsert.py +++ b/tests/table/test_upsert.py @@ -885,3 +885,189 @@ def test_upsert_snapshot_properties(catalog: Catalog) -> None: for snapshot in snapshots[initial_snapshot_count:]: assert snapshot.summary is not None assert snapshot.summary.additional_properties.get("test_prop") == "test_value" + + +def test_coarse_match_filter_composite_key() -> None: + """ + Test that create_coarse_match_filter produces efficient In() predicates for composite keys. + """ + from pyiceberg.table.upsert_util import create_coarse_match_filter, create_match_filter + + # Create a table with composite key that has overlapping values + # (1, 'x'), (2, 'y'), (1, 'z') - exact filter should have 3 conditions + # coarse filter should have In(a, [1,2]) AND In(b, ['x','y','z']) + data = [ + {"a": 1, "b": "x", "val": 1}, + {"a": 2, "b": "y", "val": 2}, + {"a": 1, "b": "z", "val": 3}, + ] + schema = pa.schema([pa.field("a", pa.int32()), pa.field("b", pa.string()), pa.field("val", pa.int32())]) + table = pa.Table.from_pylist(data, schema=schema) + + exact_filter = create_match_filter(table, ["a", "b"]) + coarse_filter = create_coarse_match_filter(table, ["a", "b"]) + + # Exact filter is an Or of And conditions + assert "Or" in str(exact_filter) + + # Coarse filter is an And of In conditions + assert "And" in str(coarse_filter) + assert "In" in str(coarse_filter) + + +def test_vectorized_comparison_primitives() -> None: + """Test vectorized comparison with primitive types.""" + from pyiceberg.table.upsert_util import _compare_columns_vectorized + + # Test integers + source = pa.array([1, 2, 3, 4]) + target = pa.array([1, 2, 5, 4]) + 
diff = _compare_columns_vectorized(source, target) + assert diff.to_pylist() == [False, False, True, False] + + # Test strings + source = pa.array(["a", "b", "c"]) + target = pa.array(["a", "x", "c"]) + diff = _compare_columns_vectorized(source, target) + assert diff.to_pylist() == [False, True, False] + + # Test floats + source = pa.array([1.0, 2.5, 3.0]) + target = pa.array([1.0, 2.5, 3.1]) + diff = _compare_columns_vectorized(source, target) + assert diff.to_pylist() == [False, False, True] + + +def test_vectorized_comparison_nulls() -> None: + """Test vectorized comparison handles nulls correctly.""" + from pyiceberg.table.upsert_util import _compare_columns_vectorized + + # null vs non-null = different + source = pa.array([1, None, 3]) + target = pa.array([1, 2, 3]) + diff = _compare_columns_vectorized(source, target) + assert diff.to_pylist() == [False, True, False] + + # non-null vs null = different + source = pa.array([1, 2, 3]) + target = pa.array([1, None, 3]) + diff = _compare_columns_vectorized(source, target) + assert diff.to_pylist() == [False, True, False] + + # null vs null = same (no update needed) + source = pa.array([1, None, 3]) + target = pa.array([1, None, 3]) + diff = _compare_columns_vectorized(source, target) + assert diff.to_pylist() == [False, False, False] + + +def test_vectorized_comparison_structs() -> None: + """Test vectorized comparison with nested struct types.""" + from pyiceberg.table.upsert_util import _compare_columns_vectorized + + struct_type = pa.struct([("x", pa.int32()), ("y", pa.string())]) + + # Same structs + source = pa.array([{"x": 1, "y": "a"}, {"x": 2, "y": "b"}], type=struct_type) + target = pa.array([{"x": 1, "y": "a"}, {"x": 2, "y": "b"}], type=struct_type) + diff = _compare_columns_vectorized(source, target) + assert diff.to_pylist() == [False, False] + + # Different struct values + source = pa.array([{"x": 1, "y": "a"}, {"x": 2, "y": "b"}], type=struct_type) + target = pa.array([{"x": 1, "y": "a"}, {"x": 2, 
"y": "c"}], type=struct_type) + diff = _compare_columns_vectorized(source, target) + assert diff.to_pylist() == [False, True] + + +def test_vectorized_comparison_nested_structs() -> None: + """Test vectorized comparison with deeply nested struct types.""" + from pyiceberg.table.upsert_util import _compare_columns_vectorized + + inner_struct = pa.struct([("val", pa.int32())]) + outer_struct = pa.struct([("inner", inner_struct), ("name", pa.string())]) + + source = pa.array( + [{"inner": {"val": 1}, "name": "a"}, {"inner": {"val": 2}, "name": "b"}], + type=outer_struct, + ) + target = pa.array( + [{"inner": {"val": 1}, "name": "a"}, {"inner": {"val": 3}, "name": "b"}], + type=outer_struct, + ) + diff = _compare_columns_vectorized(source, target) + assert diff.to_pylist() == [False, True] + + +def test_vectorized_comparison_lists() -> None: + """Test vectorized comparison with list types (falls back to Python comparison).""" + from pyiceberg.table.upsert_util import _compare_columns_vectorized + + list_type = pa.list_(pa.int32()) + + source = pa.array([[1, 2], [3, 4]], type=list_type) + target = pa.array([[1, 2], [3, 5]], type=list_type) + diff = _compare_columns_vectorized(source, target) + assert diff.to_pylist() == [False, True] + + +def test_get_rows_to_update_no_non_key_cols() -> None: + """Test get_rows_to_update when all columns are key columns.""" + from pyiceberg.table.upsert_util import get_rows_to_update + + # All columns are key columns, so no non-key columns to compare + source = pa.Table.from_pydict({"id": [1, 2, 3]}) + target = pa.Table.from_pydict({"id": [1, 2, 3]}) + rows = get_rows_to_update(source, target, ["id"]) + assert len(rows) == 0 + + +def test_upsert_with_list_field(catalog: Catalog) -> None: + """Test upsert with list type as non-key column.""" + from pyiceberg.types import ListType + + identifier = "default.test_upsert_with_list_field" + _drop_table(catalog, identifier) + + schema = Schema( + NestedField(1, "id", IntegerType(), 
required=True), + NestedField( + 2, + "tags", + ListType(element_id=3, element_type=StringType(), element_required=False), + required=False, + ), + identifier_field_ids=[1], + ) + + tbl = catalog.create_table(identifier, schema=schema) + + arrow_schema = pa.schema( + [ + pa.field("id", pa.int32(), nullable=False), + pa.field("tags", pa.list_(pa.large_string()), nullable=True), + ] + ) + + initial_data = pa.Table.from_pylist( + [ + {"id": 1, "tags": ["a", "b"]}, + {"id": 2, "tags": ["c"]}, + ], + schema=arrow_schema, + ) + tbl.append(initial_data) + + # Update with changed list + update_data = pa.Table.from_pylist( + [ + {"id": 1, "tags": ["a", "b"]}, # Same - no update + {"id": 2, "tags": ["c", "d"]}, # Changed - should update + {"id": 3, "tags": ["e"]}, # New - should insert + ], + schema=arrow_schema, + ) + + res = tbl.upsert(update_data, join_cols=["id"]) + assert res.rows_updated == 1 + assert res.rows_inserted == 1 From b31676b9b06ab14ab2442d056c5151750dd8c0c3 Mon Sep 17 00:00:00 2001 From: EnyMan Date: Mon, 19 Jan 2026 16:02:21 +0100 Subject: [PATCH 2/8] feat: Enhance vectorized comparison to handle struct-level nulls and empty structs --- pyiceberg/table/upsert_util.py | 18 ++++++++---- tests/table/test_upsert.py | 50 ++++++++++++++++++++++++++++++++-- 2 files changed, 60 insertions(+), 8 deletions(-) diff --git a/pyiceberg/table/upsert_util.py b/pyiceberg/table/upsert_util.py index 66ada9e567..ef2661e49f 100644 --- a/pyiceberg/table/upsert_util.py +++ b/pyiceberg/table/upsert_util.py @@ -16,7 +16,6 @@ # under the License. 
import functools import operator -from typing import Union import pyarrow as pa from pyarrow import Table as pyarrow_table @@ -94,7 +93,7 @@ def has_duplicate_rows(df: pyarrow_table, join_cols: list[str]) -> bool: def _compare_columns_vectorized( - source_col: Union[pa.Array, pa.ChunkedArray], target_col: Union[pa.Array, pa.ChunkedArray] + source_col: pa.Array | pa.ChunkedArray, target_col: pa.Array | pa.ChunkedArray ) -> pa.Array: """ Vectorized comparison of two columns, returning a boolean array where True means values differ. @@ -105,6 +104,11 @@ def _compare_columns_vectorized( col_type = source_col.type if pa.types.is_struct(col_type): + # Handle struct-level nulls first + source_null = pc.is_null(source_col) + target_null = pc.is_null(target_col) + struct_null_diff = pc.xor(source_null, target_null) # Different if exactly one is null + # PyArrow cannot directly compare struct columns, so we recursively compare each field diff_masks = [] for i, field in enumerate(col_type): @@ -114,12 +118,14 @@ def _compare_columns_vectorized( diff_masks.append(field_diff) if not diff_masks: - # Empty struct - no fields to compare, so no differences - return pa.array([False] * len(source_col), type=pa.bool_()) + # Empty struct - only null differences matter + return struct_null_diff - return functools.reduce(pc.or_, diff_masks) + # Combine field differences with struct-level null differences + field_diff = functools.reduce(pc.or_, diff_masks) + return pc.or_(field_diff, struct_null_diff) - elif pa.types.is_list(col_type) or pa.types.is_large_list(col_type) or pa.types.is_map(col_type): + elif pa.types.is_list(col_type) or pa.types.is_large_list(col_type) or pa.types.is_fixed_size_list(col_type) or pa.types.is_map(col_type): # For list/map types, fall back to Python comparison as PyArrow doesn't support vectorized comparison # This is still faster than the original row-by-row approach since we batch the conversion source_py = source_col.to_pylist() diff --git 
a/tests/table/test_upsert.py b/tests/table/test_upsert.py index 1682af3b06..e4b2fd4377 100644 --- a/tests/table/test_upsert.py +++ b/tests/table/test_upsert.py @@ -892,6 +892,7 @@ def test_coarse_match_filter_composite_key() -> None: Test that create_coarse_match_filter produces efficient In() predicates for composite keys. """ from pyiceberg.table.upsert_util import create_coarse_match_filter, create_match_filter + from pyiceberg.expressions import Or, And, In # Create a table with composite key that has overlapping values # (1, 'x'), (2, 'y'), (1, 'z') - exact filter should have 3 conditions @@ -908,10 +909,10 @@ def test_coarse_match_filter_composite_key() -> None: coarse_filter = create_coarse_match_filter(table, ["a", "b"]) # Exact filter is an Or of And conditions - assert "Or" in str(exact_filter) + assert isinstance(exact_filter, Or) # Coarse filter is an And of In conditions - assert "And" in str(coarse_filter) + assert isinstance(coarse_filter, And) assert "In" in str(coarse_filter) @@ -1071,3 +1072,48 @@ def test_upsert_with_list_field(catalog: Catalog) -> None: res = tbl.upsert(update_data, join_cols=["id"]) assert res.rows_updated == 1 assert res.rows_inserted == 1 + + +def test_vectorized_comparison_struct_level_nulls() -> None: + """Test vectorized comparison handles struct-level nulls correctly (not just field-level nulls).""" + from pyiceberg.table.upsert_util import _compare_columns_vectorized + + struct_type = pa.struct([("x", pa.int32()), ("y", pa.string())]) + + # null struct vs non-null struct = different + source = pa.array([{"x": 1, "y": "a"}, None, {"x": 3, "y": "c"}], type=struct_type) + target = pa.array([{"x": 1, "y": "a"}, {"x": 2, "y": "b"}, {"x": 3, "y": "c"}], type=struct_type) + diff = _compare_columns_vectorized(source, target) + assert diff.to_pylist() == [False, True, False] + + # non-null struct vs null struct = different + source = pa.array([{"x": 1, "y": "a"}, {"x": 2, "y": "b"}, {"x": 3, "y": "c"}], type=struct_type) + target 
= pa.array([{"x": 1, "y": "a"}, None, {"x": 3, "y": "c"}], type=struct_type) + diff = _compare_columns_vectorized(source, target) + assert diff.to_pylist() == [False, True, False] + + # null struct vs null struct = same (no update needed) + source = pa.array([{"x": 1, "y": "a"}, None, {"x": 3, "y": "c"}], type=struct_type) + target = pa.array([{"x": 1, "y": "a"}, None, {"x": 3, "y": "c"}], type=struct_type) + diff = _compare_columns_vectorized(source, target) + assert diff.to_pylist() == [False, False, False] + + +def test_vectorized_comparison_empty_struct_with_nulls() -> None: + """Test that empty structs with null values are compared correctly.""" + from pyiceberg.table.upsert_util import _compare_columns_vectorized + + # Empty struct type - edge case where only struct-level null handling matters + empty_struct_type = pa.struct([]) + + # null vs non-null empty struct = different + source = pa.array([{}, None, {}], type=empty_struct_type) + target = pa.array([{}, {}, {}], type=empty_struct_type) + diff = _compare_columns_vectorized(source, target) + assert diff.to_pylist() == [False, True, False] + + # null vs null empty struct = same + source = pa.array([None, None], type=empty_struct_type) + target = pa.array([None, None], type=empty_struct_type) + diff = _compare_columns_vectorized(source, target) + assert diff.to_pylist() == [False, False] From d3a6d74e850929d3c98d10a079d684fcbd245af9 Mon Sep 17 00:00:00 2001 From: EnyMan Date: Mon, 19 Jan 2026 17:15:43 +0100 Subject: [PATCH 3/8] feat: Add logging for performance metrics in upsert and get_rows_to_update functions --- pyiceberg/table/__init__.py | 46 ++++++++++++++++++++++++++++++++++ pyiceberg/table/upsert_util.py | 31 +++++++++++++++++++++++ 2 files changed, 77 insertions(+) diff --git a/pyiceberg/table/__init__.py b/pyiceberg/table/__init__.py index 34a383a22b..6f6167ba61 100644 --- a/pyiceberg/table/__init__.py +++ b/pyiceberg/table/__init__.py @@ -17,7 +17,9 @@ from __future__ import annotations import 
itertools +import logging import os +import time import uuid import warnings from abc import ABC, abstractmethod @@ -154,6 +156,8 @@ ALWAYS_TRUE = AlwaysTrue() DOWNCAST_NS_TIMESTAMP_TO_US_ON_WRITE = "downcast-ns-timestamp-to-us-on-write" +logger = logging.getLogger(__name__) + @dataclass() class UpsertResult: @@ -521,6 +525,7 @@ def append(self, df: pa.Table, snapshot_properties: dict[str, str] = EMPTY_DICT, with self._append_snapshot_producer(snapshot_properties, branch=branch) as append_files: # skip writing data files if the dataframe is empty if df.shape[0] > 0: + write_start = time.perf_counter() data_files = list( _dataframe_to_data_files( table_metadata=self.table_metadata, write_uuid=append_files.commit_uuid, df=df, io=self._table.io @@ -528,6 +533,8 @@ def append(self, df: pa.Table, snapshot_properties: dict[str, str] = EMPTY_DICT, ) for data_file in data_files: append_files.append_data_file(data_file) + write_end = time.perf_counter() + logger.info(f"Append data file writing: {write_end - write_start:.3f}s ({df.shape[0]} rows)") def dynamic_partition_overwrite( self, df: pa.Table, snapshot_properties: dict[str, str] = EMPTY_DICT, branch: str | None = MAIN_BRANCH @@ -636,21 +643,27 @@ def overwrite( if overwrite_filter != AlwaysFalse(): # Only delete when the filter is != AlwaysFalse + delete_start = time.perf_counter() self.delete( delete_filter=overwrite_filter, case_sensitive=case_sensitive, snapshot_properties=snapshot_properties, branch=branch, ) + delete_end = time.perf_counter() + logger.info(f"Overwrite delete operation: {delete_end - delete_start:.3f}s") with self._append_snapshot_producer(snapshot_properties, branch=branch) as append_files: # skip writing data files if the dataframe is empty if df.shape[0] > 0: + write_start = time.perf_counter() data_files = _dataframe_to_data_files( table_metadata=self.table_metadata, write_uuid=append_files.commit_uuid, df=df, io=self._table.io ) for data_file in data_files: 
append_files.append_data_file(data_file) + write_end = time.perf_counter() + logger.info(f"Overwrite data file writing: {write_end - write_start:.3f}s ({df.shape[0]} rows)") def delete( self, @@ -799,6 +812,8 @@ def upsert( Returns: An UpsertResult class (contains details of rows updated and inserted) """ + upsert_start = time.perf_counter() + try: import pyarrow as pa # noqa: F401 except ModuleNotFoundError as e: @@ -835,10 +850,16 @@ def upsert( format_version=self.table_metadata.format_version, ) + setup_end = time.perf_counter() + logger.info(f"Upsert setup (join cols, validation): {setup_end - upsert_start:.3f}s") + # get list of rows that exist so we don't have to load the entire target table # Use coarse filter for initial scan - exact matching happens in get_rows_to_update() matched_predicate = upsert_util.create_coarse_match_filter(df, join_cols) + coarse_filter_end = time.perf_counter() + logger.info(f"Coarse match filter creation: {coarse_filter_end - setup_end:.3f}s") + # We must use Transaction.table_metadata for the scan. This includes all uncommitted - but relevant - changes. 
matched_iceberg_record_batches_scan = DataScan( @@ -851,20 +872,30 @@ def upsert( if branch in self.table_metadata.refs: matched_iceberg_record_batches_scan = matched_iceberg_record_batches_scan.use_ref(branch) + scan_start = time.perf_counter() matched_iceberg_record_batches = matched_iceberg_record_batches_scan.to_arrow_batch_reader() + scan_end = time.perf_counter() + logger.info(f"Scan setup (to_arrow_batch_reader): {scan_end - scan_start:.3f}s") batches_to_overwrite = [] overwrite_predicates = [] rows_to_insert = df + batch_loop_start = time.perf_counter() + batch_count = 0 + total_rows_to_update_time = 0.0 + for batch in matched_iceberg_record_batches: + batch_count += 1 rows = pa.Table.from_batches([batch]) if when_matched_update_all: # function get_rows_to_update is doing a check on non-key columns to see if any of the values have actually changed # we don't want to do just a blanket overwrite for matched rows if the actual non-key column data hasn't changed # this extra step avoids unnecessary IO and writes + rows_to_update_start = time.perf_counter() rows_to_update = upsert_util.get_rows_to_update(df, rows, join_cols) + total_rows_to_update_time += time.perf_counter() - rows_to_update_start if len(rows_to_update) > 0: # build the match predicate filter @@ -881,23 +912,38 @@ def upsert( # Filter rows per batch. 
rows_to_insert = rows_to_insert.filter(~expr_match_arrow) + batch_loop_end = time.perf_counter() + logger.info( + f"Batch processing: {batch_loop_end - batch_loop_start:.3f}s " + f"({batch_count} batches, get_rows_to_update total: {total_rows_to_update_time:.3f}s)" + ) + update_row_cnt = 0 insert_row_cnt = 0 if batches_to_overwrite: rows_to_update = pa.concat_tables(batches_to_overwrite) update_row_cnt = len(rows_to_update) + overwrite_start = time.perf_counter() self.overwrite( rows_to_update, overwrite_filter=Or(*overwrite_predicates) if len(overwrite_predicates) > 1 else overwrite_predicates[0], branch=branch, snapshot_properties=snapshot_properties, ) + overwrite_end = time.perf_counter() + logger.info(f"Overwrite: {overwrite_end - overwrite_start:.3f}s ({update_row_cnt} rows)") if when_not_matched_insert_all: insert_row_cnt = len(rows_to_insert) if rows_to_insert: + append_start = time.perf_counter() self.append(rows_to_insert, branch=branch, snapshot_properties=snapshot_properties) + append_end = time.perf_counter() + logger.info(f"Append: {append_end - append_start:.3f}s ({insert_row_cnt} rows)") + + upsert_end = time.perf_counter() + logger.info(f"Total upsert: {upsert_end - upsert_start:.3f}s (updated: {update_row_cnt}, inserted: {insert_row_cnt})") return UpsertResult(rows_updated=update_row_cnt, rows_inserted=insert_row_cnt) diff --git a/pyiceberg/table/upsert_util.py b/pyiceberg/table/upsert_util.py index ef2661e49f..71c505321d 100644 --- a/pyiceberg/table/upsert_util.py +++ b/pyiceberg/table/upsert_util.py @@ -15,7 +15,11 @@ # specific language governing permissions and limitations # under the License. 
import functools +import logging import operator +import time + +logger = logging.getLogger(__name__) import pyarrow as pa from pyarrow import Table as pyarrow_table @@ -157,6 +161,8 @@ def get_rows_to_update(source_table: pa.Table, target_table: pa.Table, join_cols Uses vectorized PyArrow operations for efficient comparison, avoiding row-by-row Python loops. The table is joined on the identifier columns, and then checked if there are any updated rows. """ + func_start = time.perf_counter() + all_columns = set(source_table.column_names) join_cols_set = set(join_cols) @@ -167,10 +173,12 @@ def get_rows_to_update(source_table: pa.Table, target_table: pa.Table, join_cols if len(target_table) == 0: # When the target table is empty, there is nothing to update + logger.info(f"get_rows_to_update: {time.perf_counter() - func_start:.3f}s (empty target table)") return source_table.schema.empty_table() if len(non_key_cols) == 0: # No non-key columns to compare, all matched rows are "updates" but with no changes + logger.info(f"get_rows_to_update: {time.perf_counter() - func_start:.3f}s (no non-key columns)") return source_table.schema.empty_table() SOURCE_INDEX_COLUMN_NAME = "__source_index" @@ -185,6 +193,7 @@ def get_rows_to_update(source_table: pa.Table, target_table: pa.Table, join_cols # Step 1: Prepare source index with join keys and a marker index # Cast to target table schema, so we can do the join # See: https://github.com/apache/arrow/issues/37542 + index_start = time.perf_counter() source_index = ( source_table.cast(target_table.schema) .select(join_cols_set) @@ -193,22 +202,33 @@ def get_rows_to_update(source_table: pa.Table, target_table: pa.Table, join_cols # Step 2: Prepare target index with join keys and a marker target_index = target_table.select(join_cols_set).append_column(TARGET_INDEX_COLUMN_NAME, pa.array(range(len(target_table)))) + index_end = time.perf_counter() # Step 3: Perform an inner join to find which rows from source exist in target + join_start 
= time.perf_counter() matching_indices = source_index.join(target_index, keys=list(join_cols_set), join_type="inner") + join_end = time.perf_counter() if len(matching_indices) == 0: # No matching rows found + logger.info( + f"get_rows_to_update: {time.perf_counter() - func_start:.3f}s " + f"(index prep: {index_end - index_start:.3f}s, join: {join_end - join_start:.3f}s, " + f"matched: 0, to_update: 0)" + ) return source_table.schema.empty_table() # Step 4: Take matched rows in batch (vectorized - single operation) + take_start = time.perf_counter() source_indices = matching_indices[SOURCE_INDEX_COLUMN_NAME] target_indices = matching_indices[TARGET_INDEX_COLUMN_NAME] matched_source = source_table.take(source_indices) matched_target = target_table.take(target_indices) + take_end = time.perf_counter() # Step 5: Vectorized comparison per column + compare_start = time.perf_counter() diff_masks = [] for col in non_key_cols: source_col = matched_source.column(col) @@ -218,10 +238,21 @@ def get_rows_to_update(source_table: pa.Table, target_table: pa.Table, join_cols # Step 6: Combine masks with OR (any column different = needs update) combined_mask = functools.reduce(pc.or_, diff_masks) + compare_end = time.perf_counter() # Step 7: Filter to get indices of rows that need updating to_update_indices = pc.filter(source_indices, combined_mask) + func_end = time.perf_counter() + logger.info( + f"get_rows_to_update: {func_end - func_start:.3f}s " + f"(index prep: {index_end - index_start:.3f}s, " + f"join: {join_end - join_start:.3f}s, " + f"take: {take_end - take_start:.3f}s, " + f"compare: {compare_end - compare_start:.3f}s, " + f"matched: {len(matching_indices)}, to_update: {len(to_update_indices)})" + ) + if len(to_update_indices) > 0: return source_table.take(to_update_indices) else: From 39d0954e593718d292d3c8d861c37bfa1bb6eab2 Mon Sep 17 00:00:00 2001 From: EnyMan Date: Mon, 19 Jan 2026 20:52:46 +0100 Subject: [PATCH 4/8] feat: Optimize insert filtering in upsert process 
using anti-join for matched keys --- pyiceberg/table/__init__.py | 30 ++++++++++++++++++++++-------- 1 file changed, 22 insertions(+), 8 deletions(-) diff --git a/pyiceberg/table/__init__.py b/pyiceberg/table/__init__.py index 6f6167ba61..19f3cfa1e9 100644 --- a/pyiceberg/table/__init__.py +++ b/pyiceberg/table/__init__.py @@ -819,7 +819,6 @@ def upsert( except ModuleNotFoundError as e: raise ModuleNotFoundError("For writes PyArrow needs to be installed") from e - from pyiceberg.io.pyarrow import expression_to_pyarrow from pyiceberg.table import upsert_util if join_cols is None: @@ -879,7 +878,7 @@ def upsert( batches_to_overwrite = [] overwrite_predicates = [] - rows_to_insert = df + matched_target_keys: list[pa.Table] = [] # Accumulate matched keys for insert filtering batch_loop_start = time.perf_counter() batch_count = 0 @@ -904,13 +903,9 @@ def upsert( batches_to_overwrite.append(rows_to_update) overwrite_predicates.append(overwrite_mask_predicate) + # Collect matched keys for insert filtering (will use anti-join after loop) if when_not_matched_insert_all: - expr_match = upsert_util.create_match_filter(rows, join_cols) - expr_match_bound = bind(self.table_metadata.schema(), expr_match, case_sensitive=case_sensitive) - expr_match_arrow = expression_to_pyarrow(expr_match_bound) - - # Filter rows per batch. 
- rows_to_insert = rows_to_insert.filter(~expr_match_arrow) + matched_target_keys.append(rows.select(join_cols)) batch_loop_end = time.perf_counter() logger.info( @@ -918,6 +913,25 @@ def upsert( f"({batch_count} batches, get_rows_to_update total: {total_rows_to_update_time:.3f}s)" ) + # Use anti-join to find rows to insert (replaces per-batch expression filtering) + rows_to_insert = df + if when_not_matched_insert_all and matched_target_keys: + filter_start = time.perf_counter() + # Combine all matched keys and deduplicate + combined_matched_keys = pa.concat_tables(matched_target_keys).group_by(join_cols).aggregate([]) + # Cast matched keys to source schema types for join compatibility + source_key_schema = df.select(join_cols).schema + combined_matched_keys = combined_matched_keys.cast(source_key_schema) + # Use anti-join on key columns only (with row indices) to avoid issues with + # struct/list types in non-key columns that PyArrow join doesn't support + row_indices = pa.chunked_array([pa.array(range(len(df)), type=pa.int64())]) + source_keys_with_idx = df.select(join_cols).append_column("__row_idx__", row_indices) + not_matched_keys = source_keys_with_idx.join(combined_matched_keys, keys=join_cols, join_type="left anti") + indices_to_keep = not_matched_keys.column("__row_idx__").combine_chunks() + rows_to_insert = df.take(indices_to_keep) + filter_end = time.perf_counter() + logger.info(f"Insert filtering (anti-join): {filter_end - filter_start:.3f}s ({len(combined_matched_keys)} matched keys)") + update_row_cnt = 0 insert_row_cnt = 0 From 0b89237f997f15c0f850fa29bb7ef09b216956e0 Mon Sep 17 00:00:00 2001 From: EnyMan Date: Tue, 20 Jan 2026 15:04:55 +0100 Subject: [PATCH 5/8] Add more logging --- pyiceberg/io/pyarrow.py | 62 ++++++++++- pyiceberg/table/__init__.py | 208 +++++++++++++++++++++++++++++++++++- 2 files changed, 265 insertions(+), 5 deletions(-) diff --git a/pyiceberg/io/pyarrow.py b/pyiceberg/io/pyarrow.py index 55ecc7ac93..84ca3a25d4 100644 --- 
a/pyiceberg/io/pyarrow.py +++ b/pyiceberg/io/pyarrow.py @@ -34,6 +34,7 @@ import operator import os import re +import time import uuid import warnings from abc import ABC, abstractmethod @@ -1575,11 +1576,16 @@ def _task_to_record_batches( format_version: TableVersion = TableProperties.DEFAULT_FORMAT_VERSION, downcast_ns_timestamp_to_us: bool | None = None, ) -> Iterator[pa.RecordBatch]: + task_start = time.perf_counter() + + open_start = time.perf_counter() arrow_format = _get_file_format(task.file.file_format, pre_buffer=True, buffer_size=(ONE_MEGABYTE * 8)) with io.new_input(task.file.file_path).open() as fin: fragment = arrow_format.make_fragment(fin) physical_schema = fragment.physical_schema + open_end = time.perf_counter() + schema_start = time.perf_counter() # For V1 and V2, we only support Timestamp 'us' in Iceberg Schema, therefore it is reasonable to always cast 'ns' timestamp to 'us' on read. # For V3 this has to set explicitly to avoid nanosecond timestamp to be down-casted by default downcast_ns_timestamp_to_us = ( @@ -1593,7 +1599,9 @@ def _task_to_record_batches( projected_missing_fields = _get_column_projection_values( task.file, projected_schema, table_schema, partition_spec, file_schema.field_ids ) + schema_end = time.perf_counter() + filter_start = time.perf_counter() pyarrow_filter = None if bound_row_filter is not AlwaysTrue(): translated_row_filter = translate_column_names( @@ -1612,9 +1620,24 @@ def _task_to_record_batches( filter=pyarrow_filter if not positional_deletes else None, columns=[col.name for col in file_project_schema.columns], ) + filter_end = time.perf_counter() + + batch_read_start = time.perf_counter() + batches = list(fragment_scanner.to_batches()) + batch_read_end = time.perf_counter() + + logger.info( + "[SCAN TIMING] _task_to_record_batches %s: (open: %.4fs, schema: %.4fs, filter_prep: %.4fs, batch_read: %.4fs, batches: %d, total: %.4fs)", + task.file.file_path, + open_end - open_start, + schema_end - schema_start, + 
filter_end - filter_start, + batch_read_end - batch_read_start, + len(batches), + time.perf_counter() - task_start, + ) next_index = 0 - batches = fragment_scanner.to_batches() for batch in batches: next_index = next_index + len(batch) current_index = next_index - len(batch) @@ -1650,9 +1673,14 @@ def _task_to_record_batches( def _read_all_delete_files(io: FileIO, tasks: Iterable[FileScanTask]) -> dict[str, list[ChunkedArray]]: + func_start = time.perf_counter() + deletes_per_file: dict[str, list[ChunkedArray]] = {} unique_deletes = set(itertools.chain.from_iterable([task.delete_files for task in tasks])) + collect_end = time.perf_counter() + if len(unique_deletes) > 0: + read_start = time.perf_counter() executor = ExecutorFactory.get_or_create() deletes_per_files: Iterator[dict[str, ChunkedArray]] = executor.map( lambda args: _read_deletes(*args), @@ -1664,6 +1692,20 @@ def _read_all_delete_files(io: FileIO, tasks: Iterable[FileScanTask]) -> dict[st deletes_per_file[file].append(arr) else: deletes_per_file[file] = [arr] + read_end = time.perf_counter() + + logger.info( + "[SCAN TIMING] _read_all_delete_files: %.4fs (collect: %.4fs, read: %.4fs, unique_delete_files: %d)", + time.perf_counter() - func_start, + collect_end - func_start, + read_end - read_start, + len(unique_deletes), + ) + else: + logger.info( + "[SCAN TIMING] _read_all_delete_files: %.4fs (no deletes)", + time.perf_counter() - func_start, + ) return deletes_per_file @@ -1773,7 +1815,21 @@ def to_record_batches(self, tasks: Iterable[FileScanTask]) -> Iterator[pa.Record ResolveError: When a required field cannot be found in the file ValueError: When a field type in the file cannot be projected to the schema type """ - deletes_per_file = _read_all_delete_files(self._io, tasks) + materialize_start = time.perf_counter() + tasks_list = list(tasks) # Force materialization for delete file collection + materialize_end = time.perf_counter() + + delete_read_start = time.perf_counter() + deletes_per_file = 
_read_all_delete_files(self._io, tasks_list) + delete_read_end = time.perf_counter() + + logger.info( + "[SCAN TIMING] to_record_batches setup: (task_materialize: %.4fs, delete_read: %.4fs, tasks: %d, delete_files: %d)", + materialize_end - materialize_start, + delete_read_end - delete_read_start, + len(tasks_list), + len(deletes_per_file), + ) total_row_count = 0 executor = ExecutorFactory.get_or_create() @@ -1785,7 +1841,7 @@ def batches_for_task(task: FileScanTask) -> list[pa.RecordBatch]: return list(self._record_batches_from_scan_tasks_and_deletes([task], deletes_per_file)) limit_reached = False - for batches in executor.map(batches_for_task, tasks): + for batches in executor.map(batches_for_task, tasks_list): for batch in batches: current_batch_size = len(batch) if self._limit is not None and total_row_count + current_batch_size >= self._limit: diff --git a/pyiceberg/table/__init__.py b/pyiceberg/table/__init__.py index 19f3cfa1e9..3eef488a81 100644 --- a/pyiceberg/table/__init__.py +++ b/pyiceberg/table/__init__.py @@ -766,6 +766,178 @@ def delete( if not delete_snapshot.files_affected and not delete_snapshot.rewrites_needed: warnings.warn("Delete operation did not match any records", stacklevel=2) + def upsert_unoptimized( + self, + df: pa.Table, + join_cols: list[str] | None = None, + when_matched_update_all: bool = True, + when_not_matched_insert_all: bool = True, + case_sensitive: bool = True, + branch: str | None = MAIN_BRANCH, + snapshot_properties: dict[str, str] = EMPTY_DICT, + ) -> UpsertResult: + """Shorthand API for performing an upsert to an iceberg table. + + Args: + + df: The input dataframe to upsert with the table's data. + join_cols: Columns to join on, if not provided, it will use the identifier-field-ids. 
+ when_matched_update_all: Bool indicating to update rows that are matched but require an update due to a value in a non-key column changing + when_not_matched_insert_all: Bool indicating new rows to be inserted that do not match any existing rows in the table + case_sensitive: Bool indicating if the match should be case-sensitive + branch: Branch Reference to run the upsert operation + snapshot_properties: Custom properties to be added to the snapshot summary + + To learn more about the identifier-field-ids: https://iceberg.apache.org/spec/#identifier-field-ids + + Example Use Cases: + Case 1: Both Parameters = True (Full Upsert) + Existing row found → Update it + New row found → Insert it + + Case 2: when_matched_update_all = False, when_not_matched_insert_all = True + Existing row found → Do nothing (no updates) + New row found → Insert it + + Case 3: when_matched_update_all = True, when_not_matched_insert_all = False + Existing row found → Update it + New row found → Do nothing (no inserts) + + Case 4: Both Parameters = False (No Merge Effect) + Existing row found → Do nothing + New row found → Do nothing + (Function effectively does nothing) + + + Returns: + An UpsertResult class (contains details of rows updated and inserted) + """ + upsert_start = time.perf_counter() + try: + import pyarrow as pa # noqa: F401 + except ModuleNotFoundError as e: + raise ModuleNotFoundError("For writes PyArrow needs to be installed") from e + + from pyiceberg.io.pyarrow import expression_to_pyarrow + from pyiceberg.table import upsert_util + + if join_cols is None: + join_cols = [] + for field_id in self.table_metadata.schema().identifier_field_ids: + col = self.table_metadata.schema().find_column_name(field_id) + if col is not None: + join_cols.append(col) + else: + raise ValueError(f"Field-ID could not be found: {join_cols}") + + if len(join_cols) == 0: + raise ValueError("Join columns could not be found, please set identifier-field-ids or pass in explicitly.") + + if not 
when_matched_update_all and not when_not_matched_insert_all: + raise ValueError("no upsert options selected...exiting") + + if upsert_util.has_duplicate_rows(df, join_cols): + raise ValueError("Duplicate rows found in source dataset based on the key columns. No upsert executed") + + from pyiceberg.io.pyarrow import _check_pyarrow_schema_compatible + + downcast_ns_timestamp_to_us = Config().get_bool(DOWNCAST_NS_TIMESTAMP_TO_US_ON_WRITE) or False + _check_pyarrow_schema_compatible( + self.table_metadata.schema(), + provided_schema=df.schema, + downcast_ns_timestamp_to_us=downcast_ns_timestamp_to_us, + format_version=self.table_metadata.format_version, + ) + + # get list of rows that exist so we don't have to load the entire target table + t0 = time.perf_counter() + matched_predicate = upsert_util.create_match_filter(df, join_cols) + logger.info(f"[UPSERT TIMING] create_match_filter (initial): {time.perf_counter() - t0:.4f}s") + + # We must use Transaction.table_metadata for the scan. This includes all uncommitted - but relevant - changes. 
+ + t0 = time.perf_counter() + matched_iceberg_record_batches_scan = DataScan( + table_metadata=self.table_metadata, + io=self._table.io, + row_filter=matched_predicate, + case_sensitive=case_sensitive, + ) + + if branch in self.table_metadata.refs: + matched_iceberg_record_batches_scan = matched_iceberg_record_batches_scan.use_ref(branch) + logger.info(f"[UPSERT TIMING] datascan_setup: {time.perf_counter() - t0:.4f}s") + + t0 = time.perf_counter() + matched_iceberg_record_batches = matched_iceberg_record_batches_scan.to_arrow_batch_reader() + logger.info(f"[UPSERT TIMING] to_arrow_batch_reader: {time.perf_counter() - t0:.4f}s") + + batches_to_overwrite = [] + overwrite_predicates = [] + rows_to_insert = df + + batch_loop_start = time.perf_counter() + batch_idx = 0 + for batch in matched_iceberg_record_batches: + rows = pa.Table.from_batches([batch]) + + if when_matched_update_all: + # function get_rows_to_update is doing a check on non-key columns to see if any of the values have actually changed + # we don't want to do just a blanket overwrite for matched rows if the actual non-key column data hasn't changed + # this extra step avoids unnecessary IO and writes + t0 = time.perf_counter() + rows_to_update = upsert_util.get_rows_to_update(df, rows, join_cols) + logger.info(f"[UPSERT TIMING] batch {batch_idx}: get_rows_to_update: {time.perf_counter() - t0:.4f}s") + + if len(rows_to_update) > 0: + # build the match predicate filter + t0 = time.perf_counter() + overwrite_mask_predicate = upsert_util.create_match_filter(rows_to_update, join_cols) + logger.info(f"[UPSERT TIMING] batch {batch_idx}: create_match_filter (overwrite): {time.perf_counter() - t0:.4f}s") + + batches_to_overwrite.append(rows_to_update) + overwrite_predicates.append(overwrite_mask_predicate) + + if when_not_matched_insert_all: + t0 = time.perf_counter() + expr_match = upsert_util.create_match_filter(rows, join_cols) + logger.info(f"[UPSERT TIMING] batch {batch_idx}: create_match_filter (insert): 
{time.perf_counter() - t0:.4f}s") + expr_match_bound = bind(self.table_metadata.schema(), expr_match, case_sensitive=case_sensitive) + expr_match_arrow = expression_to_pyarrow(expr_match_bound) + + # Filter rows per batch. + t0 = time.perf_counter() + rows_to_insert = rows_to_insert.filter(~expr_match_arrow) + logger.info(f"[UPSERT TIMING] batch {batch_idx}: filter_rows_to_insert: {time.perf_counter() - t0:.4f}s") + + batch_idx += 1 + logger.info(f"[UPSERT TIMING] batch_loop_total: {time.perf_counter() - batch_loop_start:.4f}s ({batch_idx} iterations)") + + update_row_cnt = 0 + insert_row_cnt = 0 + + if batches_to_overwrite: + rows_to_update = pa.concat_tables(batches_to_overwrite) + update_row_cnt = len(rows_to_update) + t0 = time.perf_counter() + self.overwrite( + rows_to_update, + overwrite_filter=Or(*overwrite_predicates) if len(overwrite_predicates) > 1 else overwrite_predicates[0], + branch=branch, + snapshot_properties=snapshot_properties, + ) + logger.info(f"[UPSERT TIMING] overwrite_operation: {time.perf_counter() - t0:.4f}s") + + if when_not_matched_insert_all: + insert_row_cnt = len(rows_to_insert) + if rows_to_insert: + t0 = time.perf_counter() + self.append(rows_to_insert, branch=branch, snapshot_properties=snapshot_properties) + logger.info(f"[UPSERT TIMING] append_operation: {time.perf_counter() - t0:.4f}s") + + logger.info(f"[UPSERT TIMING] upsert_total: {time.perf_counter() - upsert_start:.4f}s") + return UpsertResult(rows_updated=update_row_cnt, rows_inserted=insert_row_cnt) + def upsert( self, df: pa.Table, @@ -2180,11 +2352,14 @@ def _plan_files_server_side(self) -> Iterable[FileScanTask]: def _plan_files_local(self) -> Iterable[FileScanTask]: """Plan files locally by reading manifests.""" + plan_start = time.perf_counter() + data_entries: list[ManifestEntry] = [] positional_delete_entries = SortedList(key=lambda entry: entry.sequence_number or INITIAL_SEQUENCE_NUMBER) residual_evaluators: dict[int, Callable[[DataFile], ResidualEvaluator]] = 
KeyDefaultDict(self._build_residual_evaluator) + manifest_scan_start = time.perf_counter() for manifest_entry in chain.from_iterable(self.scan_plan_helper()): data_file = manifest_entry.data_file if data_file.content == DataFileContent.DATA: @@ -2195,8 +2370,10 @@ def _plan_files_local(self) -> Iterable[FileScanTask]: raise ValueError("PyIceberg does not yet support equality deletes: https://github.com/apache/iceberg/issues/6568") else: raise ValueError(f"Unknown DataFileContent ({data_file.content}): {manifest_entry}") + manifest_scan_end = time.perf_counter() - return [ + task_creation_start = time.perf_counter() + tasks = [ FileScanTask( data_entry.data_file, delete_files=_match_deletes_to_data_file( @@ -2209,6 +2386,17 @@ def _plan_files_local(self) -> Iterable[FileScanTask]: ) for data_entry in data_entries ] + task_creation_end = time.perf_counter() + + logger.info( + "[SCAN TIMING] _plan_files_local: %.4fs (manifest_scan: %.4fs, task_creation: %.4fs, data_files: %d, delete_files: %d)", + time.perf_counter() - plan_start, + manifest_scan_end - manifest_scan_start, + task_creation_end - task_creation_start, + len(data_entries), + len(positional_delete_entries), + ) + return tasks def plan_files(self) -> Iterable[FileScanTask]: """Plans the relevant files by filtering on the PartitionSpecs. 
@@ -2253,10 +2441,26 @@ def to_arrow_batch_reader(self) -> pa.RecordBatchReader: from pyiceberg.io.pyarrow import ArrowScan, schema_to_pyarrow + func_start = time.perf_counter() + target_schema = schema_to_pyarrow(self.projection()) + + plan_start = time.perf_counter() + file_tasks = self.plan_files() + plan_end = time.perf_counter() + + scan_start = time.perf_counter() batches = ArrowScan( self.table_metadata, self.io, self.projection(), self.row_filter, self.case_sensitive, self.limit - ).to_record_batches(self.plan_files()) + ).to_record_batches(file_tasks) + scan_end = time.perf_counter() + + logger.info( + "[SCAN TIMING] to_arrow_batch_reader: (plan_files: %.4fs, scan_setup: %.4fs, total_setup: %.4fs)", + plan_end - plan_start, + scan_end - scan_start, + time.perf_counter() - func_start, + ) return pa.RecordBatchReader.from_batches( target_schema, From 88038d319390f697006b3e1c3ddb7125b75275d4 Mon Sep 17 00:00:00 2001 From: EnyMan Date: Tue, 20 Jan 2026 15:28:10 +0100 Subject: [PATCH 6/8] Add even more logging --- pyiceberg/io/pyarrow.py | 15 ++++++++++ pyiceberg/table/__init__.py | 56 ++++++++++++++++++++++++++++--------- 2 files changed, 58 insertions(+), 13 deletions(-) diff --git a/pyiceberg/io/pyarrow.py b/pyiceberg/io/pyarrow.py index 84ca3a25d4..0f7e411237 100644 --- a/pyiceberg/io/pyarrow.py +++ b/pyiceberg/io/pyarrow.py @@ -1604,11 +1604,26 @@ def _task_to_record_batches( filter_start = time.perf_counter() pyarrow_filter = None if bound_row_filter is not AlwaysTrue(): + translate_start = time.perf_counter() translated_row_filter = translate_column_names( bound_row_filter, file_schema, case_sensitive=case_sensitive, projected_field_values=projected_missing_fields ) + translate_end = time.perf_counter() + + bind_start = time.perf_counter() bound_file_filter = bind(file_schema, translated_row_filter, case_sensitive=case_sensitive) + bind_end = time.perf_counter() + + to_pyarrow_start = time.perf_counter() pyarrow_filter = 
expression_to_pyarrow(bound_file_filter, file_schema) + to_pyarrow_end = time.perf_counter() + + logger.info( + "[SCAN TIMING] filter_prep breakdown: (translate: %.4fs, bind: %.4fs, to_pyarrow: %.4fs)", + translate_end - translate_start, + bind_end - bind_start, + to_pyarrow_end - to_pyarrow_start, + ) file_project_schema = prune_columns(file_schema, projected_field_ids, select_full_types=False) diff --git a/pyiceberg/table/__init__.py b/pyiceberg/table/__init__.py index 3eef488a81..5d516e7af2 100644 --- a/pyiceberg/table/__init__.py +++ b/pyiceberg/table/__init__.py @@ -2373,26 +2373,34 @@ def _plan_files_local(self) -> Iterable[FileScanTask]: manifest_scan_end = time.perf_counter() task_creation_start = time.perf_counter() - tasks = [ - FileScanTask( - data_entry.data_file, - delete_files=_match_deletes_to_data_file( - data_entry, - positional_delete_entries, - ), - residual=residual_evaluators[data_entry.data_file.spec_id](data_entry.data_file).residual_for( - data_entry.data_file.partition - ), + tasks = [] + total_delete_match_time = 0.0 + total_residual_time = 0.0 + + for data_entry in data_entries: + delete_match_start = time.perf_counter() + delete_files = _match_deletes_to_data_file(data_entry, positional_delete_entries) + delete_match_end = time.perf_counter() + total_delete_match_time += delete_match_end - delete_match_start + + residual_start = time.perf_counter() + residual = residual_evaluators[data_entry.data_file.spec_id](data_entry.data_file).residual_for( + data_entry.data_file.partition ) - for data_entry in data_entries - ] + residual_end = time.perf_counter() + total_residual_time += residual_end - residual_start + + tasks.append(FileScanTask(data_entry.data_file, delete_files=delete_files, residual=residual)) + task_creation_end = time.perf_counter() logger.info( - "[SCAN TIMING] _plan_files_local: %.4fs (manifest_scan: %.4fs, task_creation: %.4fs, data_files: %d, delete_files: %d)", + "[SCAN TIMING] _plan_files_local: %.4fs (manifest_scan: 
%.4fs, task_creation: %.4fs [delete_match: %.4fs, residual: %.4fs], data_files: %d, delete_files: %d)", time.perf_counter() - plan_start, manifest_scan_end - manifest_scan_start, task_creation_end - task_creation_start, + total_delete_match_time, + total_residual_time, len(data_entries), len(positional_delete_entries), ) @@ -2441,10 +2449,32 @@ def to_arrow_batch_reader(self) -> pa.RecordBatchReader: from pyiceberg.io.pyarrow import ArrowScan, schema_to_pyarrow + def count_expression_nodes(expr: BooleanExpression) -> tuple[int, int, int]: + """Count (total_nodes, or_count, and_count) in expression tree.""" + if isinstance(expr, Or): + left_total, left_or, left_and = count_expression_nodes(expr.left) + right_total, right_or, right_and = count_expression_nodes(expr.right) + return (left_total + right_total + 1, left_or + right_or + 1, left_and + right_and) + elif isinstance(expr, And): + left_total, left_or, left_and = count_expression_nodes(expr.left) + right_total, right_or, right_and = count_expression_nodes(expr.right) + return (left_total + right_total + 1, left_or + right_or, left_and + right_and + 1) + else: + return (1, 0, 0) + func_start = time.perf_counter() target_schema = schema_to_pyarrow(self.projection()) + # Log filter complexity + total_nodes, or_count, and_count = count_expression_nodes(self.row_filter) + logger.info( + "[SCAN TIMING] row_filter complexity: (total_nodes: %d, or_count: %d, and_count: %d)", + total_nodes, + or_count, + and_count, + ) + plan_start = time.perf_counter() file_tasks = self.plan_files() plan_end = time.perf_counter() From 3ae5edd8c543ba7a6a1b258db0e70b33c3874e1f Mon Sep 17 00:00:00 2001 From: EnyMan Date: Tue, 20 Jan 2026 15:46:57 +0100 Subject: [PATCH 7/8] Filter skimming --- pyiceberg/table/__init__.py | 172 --------------------------------- pyiceberg/table/upsert_util.py | 112 +++++++++++++++++++-- 2 files changed, 102 insertions(+), 182 deletions(-) diff --git a/pyiceberg/table/__init__.py b/pyiceberg/table/__init__.py 
index 5d516e7af2..c68acc1097 100644 --- a/pyiceberg/table/__init__.py +++ b/pyiceberg/table/__init__.py @@ -766,178 +766,6 @@ def delete( if not delete_snapshot.files_affected and not delete_snapshot.rewrites_needed: warnings.warn("Delete operation did not match any records", stacklevel=2) - def upsert_unoptimized( - self, - df: pa.Table, - join_cols: list[str] | None = None, - when_matched_update_all: bool = True, - when_not_matched_insert_all: bool = True, - case_sensitive: bool = True, - branch: str | None = MAIN_BRANCH, - snapshot_properties: dict[str, str] = EMPTY_DICT, - ) -> UpsertResult: - """Shorthand API for performing an upsert to an iceberg table. - - Args: - - df: The input dataframe to upsert with the table's data. - join_cols: Columns to join on, if not provided, it will use the identifier-field-ids. - when_matched_update_all: Bool indicating to update rows that are matched but require an update due to a value in a non-key column changing - when_not_matched_insert_all: Bool indicating new rows to be inserted that do not match any existing rows in the table - case_sensitive: Bool indicating if the match should be case-sensitive - branch: Branch Reference to run the upsert operation - snapshot_properties: Custom properties to be added to the snapshot summary - - To learn more about the identifier-field-ids: https://iceberg.apache.org/spec/#identifier-field-ids - - Example Use Cases: - Case 1: Both Parameters = True (Full Upsert) - Existing row found → Update it - New row found → Insert it - - Case 2: when_matched_update_all = False, when_not_matched_insert_all = True - Existing row found → Do nothing (no updates) - New row found → Insert it - - Case 3: when_matched_update_all = True, when_not_matched_insert_all = False - Existing row found → Update it - New row found → Do nothing (no inserts) - - Case 4: Both Parameters = False (No Merge Effect) - Existing row found → Do nothing - New row found → Do nothing - (Function effectively does nothing) - - - 
Returns: - An UpsertResult class (contains details of rows updated and inserted) - """ - upsert_start = time.perf_counter() - try: - import pyarrow as pa # noqa: F401 - except ModuleNotFoundError as e: - raise ModuleNotFoundError("For writes PyArrow needs to be installed") from e - - from pyiceberg.io.pyarrow import expression_to_pyarrow - from pyiceberg.table import upsert_util - - if join_cols is None: - join_cols = [] - for field_id in self.table_metadata.schema().identifier_field_ids: - col = self.table_metadata.schema().find_column_name(field_id) - if col is not None: - join_cols.append(col) - else: - raise ValueError(f"Field-ID could not be found: {join_cols}") - - if len(join_cols) == 0: - raise ValueError("Join columns could not be found, please set identifier-field-ids or pass in explicitly.") - - if not when_matched_update_all and not when_not_matched_insert_all: - raise ValueError("no upsert options selected...exiting") - - if upsert_util.has_duplicate_rows(df, join_cols): - raise ValueError("Duplicate rows found in source dataset based on the key columns. No upsert executed") - - from pyiceberg.io.pyarrow import _check_pyarrow_schema_compatible - - downcast_ns_timestamp_to_us = Config().get_bool(DOWNCAST_NS_TIMESTAMP_TO_US_ON_WRITE) or False - _check_pyarrow_schema_compatible( - self.table_metadata.schema(), - provided_schema=df.schema, - downcast_ns_timestamp_to_us=downcast_ns_timestamp_to_us, - format_version=self.table_metadata.format_version, - ) - - # get list of rows that exist so we don't have to load the entire target table - t0 = time.perf_counter() - matched_predicate = upsert_util.create_match_filter(df, join_cols) - logger.info(f"[UPSERT TIMING] create_match_filter (initial): {time.perf_counter() - t0:.4f}s") - - # We must use Transaction.table_metadata for the scan. This includes all uncommitted - but relevant - changes. 
- - t0 = time.perf_counter() - matched_iceberg_record_batches_scan = DataScan( - table_metadata=self.table_metadata, - io=self._table.io, - row_filter=matched_predicate, - case_sensitive=case_sensitive, - ) - - if branch in self.table_metadata.refs: - matched_iceberg_record_batches_scan = matched_iceberg_record_batches_scan.use_ref(branch) - logger.info(f"[UPSERT TIMING] datascan_setup: {time.perf_counter() - t0:.4f}s") - - t0 = time.perf_counter() - matched_iceberg_record_batches = matched_iceberg_record_batches_scan.to_arrow_batch_reader() - logger.info(f"[UPSERT TIMING] to_arrow_batch_reader: {time.perf_counter() - t0:.4f}s") - - batches_to_overwrite = [] - overwrite_predicates = [] - rows_to_insert = df - - batch_loop_start = time.perf_counter() - batch_idx = 0 - for batch in matched_iceberg_record_batches: - rows = pa.Table.from_batches([batch]) - - if when_matched_update_all: - # function get_rows_to_update is doing a check on non-key columns to see if any of the values have actually changed - # we don't want to do just a blanket overwrite for matched rows if the actual non-key column data hasn't changed - # this extra step avoids unnecessary IO and writes - t0 = time.perf_counter() - rows_to_update = upsert_util.get_rows_to_update(df, rows, join_cols) - logger.info(f"[UPSERT TIMING] batch {batch_idx}: get_rows_to_update: {time.perf_counter() - t0:.4f}s") - - if len(rows_to_update) > 0: - # build the match predicate filter - t0 = time.perf_counter() - overwrite_mask_predicate = upsert_util.create_match_filter(rows_to_update, join_cols) - logger.info(f"[UPSERT TIMING] batch {batch_idx}: create_match_filter (overwrite): {time.perf_counter() - t0:.4f}s") - - batches_to_overwrite.append(rows_to_update) - overwrite_predicates.append(overwrite_mask_predicate) - - if when_not_matched_insert_all: - t0 = time.perf_counter() - expr_match = upsert_util.create_match_filter(rows, join_cols) - logger.info(f"[UPSERT TIMING] batch {batch_idx}: create_match_filter (insert): 
{time.perf_counter() - t0:.4f}s") - expr_match_bound = bind(self.table_metadata.schema(), expr_match, case_sensitive=case_sensitive) - expr_match_arrow = expression_to_pyarrow(expr_match_bound) - - # Filter rows per batch. - t0 = time.perf_counter() - rows_to_insert = rows_to_insert.filter(~expr_match_arrow) - logger.info(f"[UPSERT TIMING] batch {batch_idx}: filter_rows_to_insert: {time.perf_counter() - t0:.4f}s") - - batch_idx += 1 - logger.info(f"[UPSERT TIMING] batch_loop_total: {time.perf_counter() - batch_loop_start:.4f}s ({batch_idx} iterations)") - - update_row_cnt = 0 - insert_row_cnt = 0 - - if batches_to_overwrite: - rows_to_update = pa.concat_tables(batches_to_overwrite) - update_row_cnt = len(rows_to_update) - t0 = time.perf_counter() - self.overwrite( - rows_to_update, - overwrite_filter=Or(*overwrite_predicates) if len(overwrite_predicates) > 1 else overwrite_predicates[0], - branch=branch, - snapshot_properties=snapshot_properties, - ) - logger.info(f"[UPSERT TIMING] overwrite_operation: {time.perf_counter() - t0:.4f}s") - - if when_not_matched_insert_all: - insert_row_cnt = len(rows_to_insert) - if rows_to_insert: - t0 = time.perf_counter() - self.append(rows_to_insert, branch=branch, snapshot_properties=snapshot_properties) - logger.info(f"[UPSERT TIMING] append_operation: {time.perf_counter() - t0:.4f}s") - - logger.info(f"[UPSERT TIMING] upsert_total: {time.perf_counter() - upsert_start:.4f}s") - return UpsertResult(rows_updated=update_row_cnt, rows_inserted=insert_row_cnt) - def upsert( self, df: pa.Table, diff --git a/pyiceberg/table/upsert_util.py b/pyiceberg/table/upsert_util.py index 71c505321d..a2762d2e27 100644 --- a/pyiceberg/table/upsert_util.py +++ b/pyiceberg/table/upsert_util.py @@ -27,12 +27,20 @@ from pyiceberg.expressions import ( AlwaysFalse, + AlwaysTrue, + And, BooleanExpression, EqualTo, + GreaterThanOrEqual, In, + LessThanOrEqual, Or, ) +# Threshold for switching from In() predicate to range-based or no filter +# When unique 
keys exceed this, the In() predicate becomes too expensive to process +LARGE_FILTER_THRESHOLD = 10_000 + def create_match_filter(df: pyarrow_table, join_cols: list[str]) -> BooleanExpression: """ @@ -62,32 +70,116 @@ def create_match_filter(df: pyarrow_table, join_cols: list[str]) -> BooleanExpre return Or(*filters) +def _is_numeric_type(arrow_type: pa.DataType) -> bool: + """Check if a PyArrow type is numeric (suitable for range filtering).""" + return pa.types.is_integer(arrow_type) or pa.types.is_floating(arrow_type) + + +def _create_range_filter(col_name: str, values: pa.Array) -> BooleanExpression: + """Create a min/max range filter for a numeric column.""" + min_val = pc.min(values).as_py() + max_val = pc.max(values).as_py() + return And(GreaterThanOrEqual(col_name, min_val), LessThanOrEqual(col_name, max_val)) + + def create_coarse_match_filter(df: pyarrow_table, join_cols: list[str]) -> BooleanExpression: """ Create a coarse Iceberg BooleanExpression filter for initial row scanning. - For single-column keys, uses an efficient In() predicate (exact match). - For composite keys, uses In() per column as a coarse filter (AND of In() predicates), - which may return false positives but is much more efficient than exact matching. + For small datasets (< LARGE_FILTER_THRESHOLD unique keys): + - Single-column keys: uses In() predicate + - Composite keys: uses AND of In() predicates per column + + For large datasets (>= LARGE_FILTER_THRESHOLD unique keys): + - Single numeric column with dense IDs: uses min/max range filter + - Otherwise: returns AlwaysTrue() to skip filtering (full scan) This function should only be used for initial scans where exact matching happens downstream (e.g., in get_rows_to_update() via the join operation). 
""" unique_keys = df.select(join_cols).group_by(join_cols).aggregate([]) + num_unique_keys = len(unique_keys) - if len(unique_keys) == 0: + if num_unique_keys == 0: return AlwaysFalse() + # For small datasets, use the standard In() approach + if num_unique_keys < LARGE_FILTER_THRESHOLD: + if len(join_cols) == 1: + return In(join_cols[0], unique_keys[0].to_pylist()) + else: + column_filters = [] + for col in join_cols: + unique_values = pc.unique(unique_keys[col]).to_pylist() + column_filters.append(In(col, unique_values)) + return functools.reduce(operator.and_, column_filters) + + # For large datasets, use optimized strategies + logger.info( + f"Large dataset detected ({num_unique_keys} unique keys >= {LARGE_FILTER_THRESHOLD} threshold), " + "using optimized filter strategy" + ) + if len(join_cols) == 1: - return In(join_cols[0], unique_keys[0].to_pylist()) + col_name = join_cols[0] + col_data = unique_keys[col_name] + col_type = col_data.type + + # For numeric columns, check if range filter is efficient (dense IDs) + if _is_numeric_type(col_type): + min_val = pc.min(col_data).as_py() + max_val = pc.max(col_data).as_py() + value_range = max_val - min_val + 1 + density = num_unique_keys / value_range if value_range > 0 else 0 + + # If IDs are dense (>10% coverage of the range), use range filter + # Otherwise, range filter would read too much irrelevant data + if density > 0.1: + logger.info( + f"Using range filter for column '{col_name}': " + f"min={min_val}, max={max_val}, density={density:.2%}" + ) + return _create_range_filter(col_name, col_data) + else: + logger.info( + f"Skipping filter (sparse IDs, density={density:.2%}): " + f"full scan will be performed" + ) + return AlwaysTrue() + else: + # Non-numeric single column with many values - skip filter + logger.info( + f"Skipping filter for non-numeric column '{col_name}' with {num_unique_keys} values: " + "full scan will be performed" + ) + return AlwaysTrue() else: - # For composite keys: use In() per column 
as a coarse filter - # This is more efficient than creating Or(And(...), And(...), ...) for each row - # May include false positives, but fine-grained matching happens downstream + # Composite keys with many values - use range filters for numeric columns where possible column_filters = [] for col in join_cols: - unique_values = pc.unique(unique_keys[col]).to_pylist() - column_filters.append(In(col, unique_values)) + col_data = unique_keys[col] + col_type = col_data.type + unique_values = pc.unique(col_data) + + if _is_numeric_type(col_type) and len(unique_values) >= LARGE_FILTER_THRESHOLD: + # Use range filter for large numeric columns + min_val = pc.min(unique_values).as_py() + max_val = pc.max(unique_values).as_py() + value_range = max_val - min_val + 1 + density = len(unique_values) / value_range if value_range > 0 else 0 + + if density > 0.1: + logger.info(f"Using range filter for composite key column '{col}': density={density:.2%}") + column_filters.append(_create_range_filter(col, unique_values)) + else: + # Sparse numeric column - still use In() as it's part of composite key + column_filters.append(In(col, unique_values.to_pylist())) + else: + # Small column or non-numeric - use In() + column_filters.append(In(col, unique_values.to_pylist())) + + if len(column_filters) == 0: + return AlwaysTrue() return functools.reduce(operator.and_, column_filters) From 8355945cace49643bec7c26752838d6f44d5cce2 Mon Sep 17 00:00:00 2001 From: EnyMan Date: Tue, 20 Jan 2026 21:49:57 +0100 Subject: [PATCH 8/8] Add more upsert tests --- tests/table/test_upsert.py | 521 +++++++++++++++++++++++++++++++++++++ 1 file changed, 521 insertions(+) diff --git a/tests/table/test_upsert.py b/tests/table/test_upsert.py index e4b2fd4377..69c258590c 100644 --- a/tests/table/test_upsert.py +++ b/tests/table/test_upsert.py @@ -1117,3 +1117,524 @@ def test_vectorized_comparison_empty_struct_with_nulls() -> None: target = pa.array([None, None], type=empty_struct_type) diff = 
_compare_columns_vectorized(source, target) assert diff.to_pylist() == [False, False] + + +# ============================================================================ +# Tests for create_coarse_match_filter and _is_numeric_type +# ============================================================================ + + +@pytest.mark.parametrize( + "dtype,expected_numeric", + [ + (pa.int8(), True), + (pa.int16(), True), + (pa.int32(), True), + (pa.int64(), True), + (pa.uint8(), True), + (pa.uint16(), True), + (pa.uint32(), True), + (pa.uint64(), True), + (pa.float16(), True), + (pa.float32(), True), + (pa.float64(), True), + (pa.string(), False), + (pa.binary(), False), + (pa.date32(), False), + (pa.date64(), False), + (pa.timestamp("us"), False), + (pa.timestamp("ns"), False), + (pa.decimal128(10, 2), False), + (pa.decimal256(20, 4), False), + (pa.bool_(), False), + (pa.large_string(), False), + (pa.large_binary(), False), + ], +) +def test_is_numeric_type(dtype: pa.DataType, expected_numeric: bool) -> None: + """Test that _is_numeric_type correctly identifies all numeric types.""" + from pyiceberg.table.upsert_util import _is_numeric_type + + assert _is_numeric_type(dtype) == expected_numeric + + +# ============================================================================ +# Thresholding Tests (Small vs Large Datasets) +# ============================================================================ + + +def test_coarse_match_filter_small_dataset_uses_in_filter() -> None: + """Test that small datasets (< 10,000 unique keys) use In() filter.""" + from pyiceberg.expressions import In + + from pyiceberg.table.upsert_util import LARGE_FILTER_THRESHOLD, create_coarse_match_filter + + # Create a dataset with 100 unique keys (well below threshold) + num_keys = 100 + data = {"id": list(range(num_keys)), "value": list(range(num_keys))} + schema = pa.schema([pa.field("id", pa.int64()), pa.field("value", pa.int64())]) + table = pa.Table.from_pydict(data, schema=schema) + + result 
= create_coarse_match_filter(table, ["id"]) + + assert num_keys < LARGE_FILTER_THRESHOLD + assert isinstance(result, In) + assert result.term.name == "id" + assert len(result.literals) == num_keys + + +def test_coarse_match_filter_threshold_boundary_uses_in_filter() -> None: + """Test that datasets at threshold - 1 (9,999 unique keys) still use In() filter.""" + from pyiceberg.expressions import In + + from pyiceberg.table.upsert_util import LARGE_FILTER_THRESHOLD, create_coarse_match_filter + + # Create a dataset with exactly threshold - 1 unique keys + num_keys = LARGE_FILTER_THRESHOLD - 1 + data = {"id": list(range(num_keys)), "value": list(range(num_keys))} + schema = pa.schema([pa.field("id", pa.int64()), pa.field("value", pa.int64())]) + table = pa.Table.from_pydict(data, schema=schema) + + result = create_coarse_match_filter(table, ["id"]) + + assert isinstance(result, In) + assert result.term.name == "id" + assert len(result.literals) == num_keys + + +def test_coarse_match_filter_above_threshold_uses_optimized_filter() -> None: + """Test that datasets >= 10,000 unique keys use optimized filter strategy.""" + from pyiceberg.expressions import GreaterThanOrEqual, LessThanOrEqual + + from pyiceberg.table.upsert_util import LARGE_FILTER_THRESHOLD, create_coarse_match_filter + + # Create a dense dataset (consecutive IDs) with exactly threshold unique keys + num_keys = LARGE_FILTER_THRESHOLD + data = {"id": list(range(num_keys)), "value": list(range(num_keys))} + schema = pa.schema([pa.field("id", pa.int64()), pa.field("value", pa.int64())]) + table = pa.Table.from_pydict(data, schema=schema) + + result = create_coarse_match_filter(table, ["id"]) + + # Dense IDs should use range filter (And of GreaterThanOrEqual and LessThanOrEqual) + assert isinstance(result, And) + assert isinstance(result.left, GreaterThanOrEqual) + assert isinstance(result.right, LessThanOrEqual) + assert result.left.literal.value == 0 + assert result.right.literal.value == num_keys - 1 + + 
+def test_coarse_match_filter_large_dataset() -> None: + """Test that large datasets (100,000 unique keys) use optimized filter.""" + from pyiceberg.expressions import GreaterThanOrEqual, LessThanOrEqual + + from pyiceberg.table.upsert_util import LARGE_FILTER_THRESHOLD, create_coarse_match_filter + + # Create a dense dataset with 100,000 unique keys + num_keys = 100_000 + data = {"id": list(range(num_keys)), "value": list(range(num_keys))} + schema = pa.schema([pa.field("id", pa.int64()), pa.field("value", pa.int64())]) + table = pa.Table.from_pydict(data, schema=schema) + + result = create_coarse_match_filter(table, ["id"]) + + assert num_keys >= LARGE_FILTER_THRESHOLD + # Dense IDs should use range filter + assert isinstance(result, And) + assert isinstance(result.left, GreaterThanOrEqual) + assert isinstance(result.right, LessThanOrEqual) + + +# ============================================================================ +# Density Calculation Tests +# ============================================================================ + + +def test_coarse_match_filter_dense_ids_use_range_filter() -> None: + """Test that dense IDs (density > 10%) use range filter.""" + from pyiceberg.expressions import GreaterThanOrEqual, LessThanOrEqual + + from pyiceberg.table.upsert_util import LARGE_FILTER_THRESHOLD, create_coarse_match_filter + + # Create dense IDs: all values from 0 to N-1 (100% density) + num_keys = LARGE_FILTER_THRESHOLD + data = {"id": list(range(num_keys)), "value": list(range(num_keys))} + schema = pa.schema([pa.field("id", pa.int64()), pa.field("value", pa.int64())]) + table = pa.Table.from_pydict(data, schema=schema) + + result = create_coarse_match_filter(table, ["id"]) + + # Density = 10000 / (9999 - 0 + 1) = 100% + # Should use range filter + assert isinstance(result, And) + assert isinstance(result.left, GreaterThanOrEqual) + assert isinstance(result.right, LessThanOrEqual) + + +def test_coarse_match_filter_moderately_dense_ids_use_range_filter() -> 
None: + """Test that moderately dense IDs (50% density) use range filter.""" + from pyiceberg.expressions import GreaterThanOrEqual, LessThanOrEqual + + from pyiceberg.table.upsert_util import LARGE_FILTER_THRESHOLD, create_coarse_match_filter + + # Create IDs: 0, 2, 4, 6, ... (every other number) - 50% density + num_keys = LARGE_FILTER_THRESHOLD + data = {"id": list(range(0, num_keys * 2, 2)), "value": list(range(num_keys))} + schema = pa.schema([pa.field("id", pa.int64()), pa.field("value", pa.int64())]) + table = pa.Table.from_pydict(data, schema=schema) + + result = create_coarse_match_filter(table, ["id"]) + + # Density = 10000 / (19998 - 0 + 1) ~= 50% + # Should use range filter since density > 10% + assert isinstance(result, And) + assert isinstance(result.left, GreaterThanOrEqual) + assert isinstance(result.right, LessThanOrEqual) + + +def test_coarse_match_filter_sparse_ids_use_always_true() -> None: + """Test that sparse IDs (density <= 10%) use AlwaysTrue (full scan).""" + from pyiceberg.table.upsert_util import LARGE_FILTER_THRESHOLD, create_coarse_match_filter + + # Create sparse IDs: values spread across a large range + # 10,000 values in range of ~110,000 = ~9% density + num_keys = LARGE_FILTER_THRESHOLD + ids = list(range(0, num_keys * 11, 11)) # 0, 11, 22, 33, ... 
+ data = {"id": ids, "value": list(range(num_keys))} + schema = pa.schema([pa.field("id", pa.int64()), pa.field("value", pa.int64())]) + table = pa.Table.from_pydict(data, schema=schema) + + result = create_coarse_match_filter(table, ["id"]) + + # Density ~= 10000 / ((10000-1)*11 + 1) = 9.09% < 10% + # Should use AlwaysTrue (full scan) + assert isinstance(result, AlwaysTrue) + + +def test_coarse_match_filter_density_boundary_at_10_percent() -> None: + """Test exact 10% boundary density behavior.""" + from pyiceberg.expressions import GreaterThanOrEqual, LessThanOrEqual + + from pyiceberg.table.upsert_util import LARGE_FILTER_THRESHOLD, create_coarse_match_filter + + # Create IDs at exactly ~10% density + # 10,000 values in range of 100,000 = exactly 10% + num_keys = LARGE_FILTER_THRESHOLD + # Generate 10,000 values in range [0, 99999] -> density = 10000/100000 = 10% + # Using every 10th value: 0, 10, 20, ... 99990 + ids = list(range(0, num_keys * 10, 10)) + data = {"id": ids, "value": list(range(num_keys))} + schema = pa.schema([pa.field("id", pa.int64()), pa.field("value", pa.int64())]) + table = pa.Table.from_pydict(data, schema=schema) + + result = create_coarse_match_filter(table, ["id"]) + + # Density = 10000 / ((num_keys-1)*10 + 1) = 10000 / 99991 ~= 10.001% + # Should use range filter since density > 10% (just barely) + assert isinstance(result, And) + assert isinstance(result.left, GreaterThanOrEqual) + assert isinstance(result.right, LessThanOrEqual) + + +def test_coarse_match_filter_very_sparse_ids() -> None: + """Test that very sparse IDs (e.g., 1, 1M, 2M) use AlwaysTrue.""" + from pyiceberg.table.upsert_util import LARGE_FILTER_THRESHOLD, create_coarse_match_filter + + # Create extremely sparse IDs + num_keys = LARGE_FILTER_THRESHOLD + # Values from 0 to (num_keys-1) * 1000, stepping by 1000 + ids = list(range(0, num_keys * 1000, 1000)) + data = {"id": ids, "value": list(range(num_keys))} + schema = pa.schema([pa.field("id", pa.int64()), 
pa.field("value", pa.int64())]) + table = pa.Table.from_pydict(data, schema=schema) + + result = create_coarse_match_filter(table, ["id"]) + + # Density = 10000 / ((10000-1)*1000 + 1) ~= 0.1% + # Should use AlwaysTrue + assert isinstance(result, AlwaysTrue) + + +# ============================================================================ +# Edge Cases +# ============================================================================ + + +def test_coarse_match_filter_empty_dataset_returns_always_false() -> None: + """Test that empty dataset returns AlwaysFalse.""" + from pyiceberg.expressions import AlwaysFalse + + from pyiceberg.table.upsert_util import create_coarse_match_filter + + schema = pa.schema([pa.field("id", pa.int64()), pa.field("value", pa.int64())]) + table = pa.Table.from_pydict({"id": [], "value": []}, schema=schema) + + result = create_coarse_match_filter(table, ["id"]) + + assert isinstance(result, AlwaysFalse) + + +def test_coarse_match_filter_single_value_dataset() -> None: + """Test that single value dataset uses In() or EqualTo() with single value.""" + from pyiceberg.expressions import In + + from pyiceberg.table.upsert_util import create_coarse_match_filter + + schema = pa.schema([pa.field("id", pa.int64()), pa.field("value", pa.int64())]) + table = pa.Table.from_pydict({"id": [42], "value": [1]}, schema=schema) + + result = create_coarse_match_filter(table, ["id"]) + + # PyIceberg may optimize In() with a single value to EqualTo() + if isinstance(result, In): + assert result.term.name == "id" + assert len(result.literals) == 1 + assert result.literals[0].value == 42 + elif isinstance(result, EqualTo): + assert result.term.name == "id" + assert result.literal.value == 42 + else: + pytest.fail(f"Expected In or EqualTo, got {type(result)}") + + +def test_coarse_match_filter_negative_numbers_range() -> None: + """Test that negative number IDs produce correct min/max range.""" + from pyiceberg.expressions import GreaterThanOrEqual, LessThanOrEqual 
+
+    from pyiceberg.table.upsert_util import LARGE_FILTER_THRESHOLD, create_coarse_match_filter
+
+    # Create dense negative IDs: -10000 to -1
+    num_keys = LARGE_FILTER_THRESHOLD
+    ids = list(range(-num_keys, 0))
+    data = {"id": ids, "value": list(range(num_keys))}
+    schema = pa.schema([pa.field("id", pa.int64()), pa.field("value", pa.int64())])
+    table = pa.Table.from_pydict(data, schema=schema)
+
+    result = create_coarse_match_filter(table, ["id"])
+
+    # Should use range filter with negative values
+    assert isinstance(result, And)
+    assert isinstance(result.left, GreaterThanOrEqual)
+    assert isinstance(result.right, LessThanOrEqual)
+    assert result.left.literal.value == -num_keys  # min
+    assert result.right.literal.value == -1  # max
+
+
+def test_coarse_match_filter_mixed_sign_numbers_range() -> None:
+    """Test that mixed sign IDs (-5000 to 4999) produce correct range spanning zero."""
+    from pyiceberg.expressions import GreaterThanOrEqual, LessThanOrEqual
+
+    from pyiceberg.table.upsert_util import LARGE_FILTER_THRESHOLD, create_coarse_match_filter
+
+    # Create IDs spanning zero: -5000 to 4999
+    num_keys = LARGE_FILTER_THRESHOLD
+    ids = list(range(-num_keys // 2, num_keys // 2))
+    data = {"id": ids, "value": list(range(num_keys))}
+    schema = pa.schema([pa.field("id", pa.int64()), pa.field("value", pa.int64())])
+    table = pa.Table.from_pydict(data, schema=schema)
+
+    result = create_coarse_match_filter(table, ["id"])
+
+    # Should use range filter spanning zero
+    assert isinstance(result, And)
+    assert isinstance(result.left, GreaterThanOrEqual)
+    assert isinstance(result.right, LessThanOrEqual)
+    assert result.left.literal.value == -num_keys // 2  # min
+    assert result.right.literal.value == num_keys // 2 - 1  # max
+
+
+def test_coarse_match_filter_float_range_filter() -> None:
+    """Test that float IDs use range filter correctly."""
+    from pyiceberg.expressions import GreaterThanOrEqual, LessThanOrEqual
+
+    from pyiceberg.table.upsert_util import 
LARGE_FILTER_THRESHOLD, create_coarse_match_filter + + # Create dense float IDs + num_keys = LARGE_FILTER_THRESHOLD + ids = [float(i) for i in range(num_keys)] + data = {"id": ids, "value": list(range(num_keys))} + schema = pa.schema([pa.field("id", pa.float64()), pa.field("value", pa.int64())]) + table = pa.Table.from_pydict(data, schema=schema) + + result = create_coarse_match_filter(table, ["id"]) + + # Should use range filter for float column + assert isinstance(result, And) + assert isinstance(result.left, GreaterThanOrEqual) + assert isinstance(result.right, LessThanOrEqual) + assert result.left.literal.value == 0.0 + assert result.right.literal.value == float(num_keys - 1) + + +def test_coarse_match_filter_non_numeric_column_skips_range_filter() -> None: + """Test that non-numeric column with >10k values returns AlwaysTrue.""" + from pyiceberg.table.upsert_util import LARGE_FILTER_THRESHOLD, create_coarse_match_filter + + # Create string IDs (non-numeric) with many unique values + num_keys = LARGE_FILTER_THRESHOLD + ids = [f"id_{i:05d}" for i in range(num_keys)] + data = {"id": ids, "value": list(range(num_keys))} + schema = pa.schema([pa.field("id", pa.string()), pa.field("value", pa.int64())]) + table = pa.Table.from_pydict(data, schema=schema) + + result = create_coarse_match_filter(table, ["id"]) + + # Non-numeric column with large dataset should use AlwaysTrue + assert isinstance(result, AlwaysTrue) + + +# ============================================================================ +# Composite Key Tests +# ============================================================================ + + +def test_coarse_match_filter_composite_key_small_dataset() -> None: + """Test that composite key with small dataset uses And(In(), In()).""" + from pyiceberg.expressions import In + + from pyiceberg.table.upsert_util import create_coarse_match_filter + + # Create a small dataset with composite key + data = { + "a": [1, 2, 3, 1, 2, 3], + "b": ["x", "x", "x", "y", "y", 
"y"], + "value": [1, 2, 3, 4, 5, 6], + } + schema = pa.schema([pa.field("a", pa.int64()), pa.field("b", pa.string()), pa.field("value", pa.int64())]) + table = pa.Table.from_pydict(data, schema=schema) + + result = create_coarse_match_filter(table, ["a", "b"]) + + # Should be And(In(a), In(b)) + assert isinstance(result, And) + # Check that both children are In() filters + assert "In" in str(result) + + +def test_coarse_match_filter_composite_key_large_numeric_column() -> None: + """Test composite key where one column has >10k unique numeric values.""" + from pyiceberg.expressions import GreaterThanOrEqual, In, LessThanOrEqual + + from pyiceberg.table.upsert_util import LARGE_FILTER_THRESHOLD, create_coarse_match_filter + + # Create dataset with one large dense numeric column and one small column + num_keys = LARGE_FILTER_THRESHOLD + data = { + "a": list(range(num_keys)), # 10k unique dense values + "b": ["category_1"] * (num_keys // 2) + ["category_2"] * (num_keys // 2), # 2 unique values + "value": list(range(num_keys)), + } + schema = pa.schema([pa.field("a", pa.int64()), pa.field("b", pa.string()), pa.field("value", pa.int64())]) + table = pa.Table.from_pydict(data, schema=schema) + + result = create_coarse_match_filter(table, ["a", "b"]) + + # Should be And of filters for both columns + assert isinstance(result, And) + # Column 'a' (large, dense, numeric) should use range filter + # Column 'b' (small) should use In() + result_str = str(result) + assert "GreaterThanOrEqual" in result_str or "In" in result_str + + +def test_coarse_match_filter_composite_key_mixed_types() -> None: + """Test composite key with mixed numeric and string columns with large dataset.""" + from pyiceberg.expressions import In + + from pyiceberg.table.upsert_util import LARGE_FILTER_THRESHOLD, create_coarse_match_filter + + # Create dataset with large sparse numeric column and large string column + num_keys = LARGE_FILTER_THRESHOLD + # Sparse numeric IDs + ids = list(range(0, num_keys * 
100, 100)) + # Many unique strings + strings = [f"str_{i}" for i in range(num_keys)] + data = { + "numeric_id": ids, + "string_id": strings, + "value": list(range(num_keys)), + } + schema = pa.schema([pa.field("numeric_id", pa.int64()), pa.field("string_id", pa.string()), pa.field("value", pa.int64())]) + table = pa.Table.from_pydict(data, schema=schema) + + result = create_coarse_match_filter(table, ["numeric_id", "string_id"]) + + # Both columns have large unique values + # numeric_id is sparse (density < 10%), so should use In() + # string_id is non-numeric, so should use In() + assert isinstance(result, And) + + +# ============================================================================ +# Integration Test with Logging Verification +# ============================================================================ + + +def test_coarse_match_filter_logs_warning_for_full_scan(caplog: pytest.LogCaptureFixture) -> None: + """Verify logging when AlwaysTrue is used (indicates full scan will be performed).""" + import logging + + from pyiceberg.table.upsert_util import LARGE_FILTER_THRESHOLD, create_coarse_match_filter + + # Create table with >10k sparse IDs (will trigger AlwaysTrue) + num_keys = LARGE_FILTER_THRESHOLD + # Very sparse IDs + ids = list(range(0, num_keys * 1000, 1000)) + data = {"id": ids, "value": list(range(num_keys))} + schema = pa.schema([pa.field("id", pa.int64()), pa.field("value", pa.int64())]) + table = pa.Table.from_pydict(data, schema=schema) + + with caplog.at_level(logging.INFO, logger="pyiceberg.table.upsert_util"): + result = create_coarse_match_filter(table, ["id"]) + + # Verify AlwaysTrue is returned + assert isinstance(result, AlwaysTrue) + + # Verify log message about skipping filter + assert any("full scan will be performed" in record.message for record in caplog.records) + + +def test_coarse_match_filter_logs_range_filter_usage(caplog: pytest.LogCaptureFixture) -> None: + """Verify logging when range filter is used for dense IDs.""" 
+ import logging + + from pyiceberg.expressions import GreaterThanOrEqual, LessThanOrEqual + + from pyiceberg.table.upsert_util import LARGE_FILTER_THRESHOLD, create_coarse_match_filter + + # Create table with >10k dense IDs (will use range filter) + num_keys = LARGE_FILTER_THRESHOLD + data = {"id": list(range(num_keys)), "value": list(range(num_keys))} + schema = pa.schema([pa.field("id", pa.int64()), pa.field("value", pa.int64())]) + table = pa.Table.from_pydict(data, schema=schema) + + with caplog.at_level(logging.INFO, logger="pyiceberg.table.upsert_util"): + result = create_coarse_match_filter(table, ["id"]) + + # Verify range filter is returned + assert isinstance(result, And) + assert isinstance(result.left, GreaterThanOrEqual) + assert isinstance(result.right, LessThanOrEqual) + + # Verify log message about using range filter + assert any("Using range filter" in record.message for record in caplog.records) + + +def test_coarse_match_filter_logs_large_dataset_detection(caplog: pytest.LogCaptureFixture) -> None: + """Verify logging when large dataset is detected.""" + import logging + + from pyiceberg.table.upsert_util import LARGE_FILTER_THRESHOLD, create_coarse_match_filter + + # Create table with exactly LARGE_FILTER_THRESHOLD keys + num_keys = LARGE_FILTER_THRESHOLD + data = {"id": list(range(num_keys)), "value": list(range(num_keys))} + schema = pa.schema([pa.field("id", pa.int64()), pa.field("value", pa.int64())]) + table = pa.Table.from_pydict(data, schema=schema) + + with caplog.at_level(logging.INFO, logger="pyiceberg.table.upsert_util"): + create_coarse_match_filter(table, ["id"]) + + # Verify log message about large dataset detection + assert any("Large dataset detected" in record.message for record in caplog.records) + assert any(f"{LARGE_FILTER_THRESHOLD}" in record.message for record in caplog.records)