pyiceberg/table/__init__.py (42 changes: 28 additions & 14 deletions)
@@ -804,7 +804,6 @@ def upsert(
except ModuleNotFoundError as e:
raise ModuleNotFoundError("For writes PyArrow needs to be installed") from e

from pyiceberg.io.pyarrow import expression_to_pyarrow
from pyiceberg.table import upsert_util

if join_cols is None:
@@ -835,8 +834,10 @@ def upsert(
format_version=self.table_metadata.format_version,
)

# get list of rows that exist so we don't have to load the entire target table
matched_predicate = upsert_util.create_match_filter(df, join_cols)
# Create a coarse filter for the initial scan to reduce the number of rows read.
# This filter is intentionally less precise but faster to evaluate than exact matching.
# Exact key matching happens downstream in get_rows_to_update() via PyArrow joins.
matched_predicate = upsert_util.create_coarse_match_filter(df, join_cols)
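# Illustrative sketch (an assumption about its shape, not the exact output of
# create_coarse_match_filter): the coarse predicate could be a conjunction of
# independent per-column In(...) terms, e.g. And(In("id", {1, 2}), In("region", {"eu", "us"})),
# which may over-match key combinations but is cheap to push down into the scan.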

# We must use Transaction.table_metadata for the scan. This includes all uncommitted - but relevant - changes.

@@ -854,31 +855,44 @@

batches_to_overwrite = []
overwrite_predicates = []
rows_to_insert = df
# Accumulate matched keys for anti-join insert filtering after the batch loop.
# We only store key columns (not full rows) to minimize memory usage.
matched_target_keys: list[pa.Table] = []

for batch in matched_iceberg_record_batches:
rows = pa.Table.from_batches([batch])

if when_matched_update_all:
# function get_rows_to_update is doing a check on non-key columns to see if any of the values have actually changed
# we don't want to do just a blanket overwrite for matched rows if the actual non-key column data hasn't changed
# this extra step avoids unnecessary IO and writes
# Check non-key columns to see if values have actually changed.
# We don't want to do a blanket overwrite for matched rows if the
# actual non-key column data hasn't changed - this avoids unnecessary IO and writes.
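# For example (hypothetical rows): a source row (id=1, qty=6) whose target counterpart
# is (id=1, qty=5) is returned for update, while a pair that is identical on both sides
# (id=2, qty=7) is skipped entirely.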
rows_to_update = upsert_util.get_rows_to_update(df, rows, join_cols)

if len(rows_to_update) > 0:
# build the match predicate filter
overwrite_mask_predicate = upsert_util.create_match_filter(rows_to_update, join_cols)

batches_to_overwrite.append(rows_to_update)
overwrite_predicates.append(overwrite_mask_predicate)

if when_not_matched_insert_all:
expr_match = upsert_util.create_match_filter(rows, join_cols)
expr_match_bound = bind(self.table_metadata.schema(), expr_match, case_sensitive=case_sensitive)
expr_match_arrow = expression_to_pyarrow(expr_match_bound)
matched_target_keys.append(rows.select(join_cols))

# Filter rows per batch.
rows_to_insert = rows_to_insert.filter(~expr_match_arrow)
# Use anti-join to find rows to insert. This is more efficient than per-batch
# expression filtering because: (1) we build expressions once, not per batch,
# and (2) PyArrow joins are faster than evaluating large Or(...) expressions.
rows_to_insert = df
if when_not_matched_insert_all and matched_target_keys:
# Combine all matched keys and deduplicate
combined_matched_keys = pa.concat_tables(matched_target_keys).group_by(join_cols).aggregate([])
# Cast matched keys to source schema types for join compatibility
source_key_schema = df.select(join_cols).schema
combined_matched_keys = combined_matched_keys.cast(source_key_schema)
# Use anti-join on key columns only (with row indices) to avoid issues with
# struct/list types in non-key columns that PyArrow join doesn't support
row_indices = pa.chunked_array([pa.array(range(len(df)), type=pa.int64())])
source_keys_with_idx = df.select(join_cols).append_column("__row_idx__", row_indices)
not_matched_keys = source_keys_with_idx.join(combined_matched_keys, keys=join_cols, join_type="left anti")
indices_to_keep = not_matched_keys.column("__row_idx__").combine_chunks()
rows_to_insert = df.take(indices_to_keep)
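# Worked illustration (assumed data): if df carries keys {1, 2, 3} and the accumulated
# target keys are {2, 3}, the left-anti join keeps only the row index of key 1, so
# rows_to_insert ends up with exactly the source rows that have no match in the target.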

update_row_cnt = 0
insert_row_cnt = 0