diff --git a/pyiceberg/io/pyarrow.py b/pyiceberg/io/pyarrow.py index 55ecc7ac93..0f7e411237 100644 --- a/pyiceberg/io/pyarrow.py +++ b/pyiceberg/io/pyarrow.py @@ -34,6 +34,7 @@ import operator import os import re +import time import uuid import warnings from abc import ABC, abstractmethod @@ -1575,11 +1576,16 @@ def _task_to_record_batches( format_version: TableVersion = TableProperties.DEFAULT_FORMAT_VERSION, downcast_ns_timestamp_to_us: bool | None = None, ) -> Iterator[pa.RecordBatch]: + task_start = time.perf_counter() + + open_start = time.perf_counter() arrow_format = _get_file_format(task.file.file_format, pre_buffer=True, buffer_size=(ONE_MEGABYTE * 8)) with io.new_input(task.file.file_path).open() as fin: fragment = arrow_format.make_fragment(fin) physical_schema = fragment.physical_schema + open_end = time.perf_counter() + schema_start = time.perf_counter() # For V1 and V2, we only support Timestamp 'us' in Iceberg Schema, therefore it is reasonable to always cast 'ns' timestamp to 'us' on read. 
# For V3 this has to set explicitly to avoid nanosecond timestamp to be down-casted by default downcast_ns_timestamp_to_us = ( @@ -1593,14 +1599,31 @@ def _task_to_record_batches( projected_missing_fields = _get_column_projection_values( task.file, projected_schema, table_schema, partition_spec, file_schema.field_ids ) + schema_end = time.perf_counter() + filter_start = time.perf_counter() pyarrow_filter = None if bound_row_filter is not AlwaysTrue(): + translate_start = time.perf_counter() translated_row_filter = translate_column_names( bound_row_filter, file_schema, case_sensitive=case_sensitive, projected_field_values=projected_missing_fields ) + translate_end = time.perf_counter() + + bind_start = time.perf_counter() bound_file_filter = bind(file_schema, translated_row_filter, case_sensitive=case_sensitive) + bind_end = time.perf_counter() + + to_pyarrow_start = time.perf_counter() pyarrow_filter = expression_to_pyarrow(bound_file_filter, file_schema) + to_pyarrow_end = time.perf_counter() + + logger.info( + "[SCAN TIMING] filter_prep breakdown: (translate: %.4fs, bind: %.4fs, to_pyarrow: %.4fs)", + translate_end - translate_start, + bind_end - bind_start, + to_pyarrow_end - to_pyarrow_start, + ) file_project_schema = prune_columns(file_schema, projected_field_ids, select_full_types=False) @@ -1612,9 +1635,24 @@ def _task_to_record_batches( filter=pyarrow_filter if not positional_deletes else None, columns=[col.name for col in file_project_schema.columns], ) + filter_end = time.perf_counter() + + batch_read_start = time.perf_counter() + batches = list(fragment_scanner.to_batches()) + batch_read_end = time.perf_counter() + + logger.info( + "[SCAN TIMING] _task_to_record_batches %s: (open: %.4fs, schema: %.4fs, filter_prep: %.4fs, batch_read: %.4fs, batches: %d, total: %.4fs)", + task.file.file_path, + open_end - open_start, + schema_end - schema_start, + filter_end - filter_start, + batch_read_end - batch_read_start, + len(batches), + time.perf_counter() - 
task_start, + ) next_index = 0 - batches = fragment_scanner.to_batches() for batch in batches: next_index = next_index + len(batch) current_index = next_index - len(batch) @@ -1650,9 +1688,14 @@ def _task_to_record_batches( def _read_all_delete_files(io: FileIO, tasks: Iterable[FileScanTask]) -> dict[str, list[ChunkedArray]]: + func_start = time.perf_counter() + deletes_per_file: dict[str, list[ChunkedArray]] = {} unique_deletes = set(itertools.chain.from_iterable([task.delete_files for task in tasks])) + collect_end = time.perf_counter() + if len(unique_deletes) > 0: + read_start = time.perf_counter() executor = ExecutorFactory.get_or_create() deletes_per_files: Iterator[dict[str, ChunkedArray]] = executor.map( lambda args: _read_deletes(*args), @@ -1664,6 +1707,20 @@ def _read_all_delete_files(io: FileIO, tasks: Iterable[FileScanTask]) -> dict[st deletes_per_file[file].append(arr) else: deletes_per_file[file] = [arr] + read_end = time.perf_counter() + + logger.info( + "[SCAN TIMING] _read_all_delete_files: %.4fs (collect: %.4fs, read: %.4fs, unique_delete_files: %d)", + time.perf_counter() - func_start, + collect_end - func_start, + read_end - read_start, + len(unique_deletes), + ) + else: + logger.info( + "[SCAN TIMING] _read_all_delete_files: %.4fs (no deletes)", + time.perf_counter() - func_start, + ) return deletes_per_file @@ -1773,7 +1830,21 @@ def to_record_batches(self, tasks: Iterable[FileScanTask]) -> Iterator[pa.Record ResolveError: When a required field cannot be found in the file ValueError: When a field type in the file cannot be projected to the schema type """ - deletes_per_file = _read_all_delete_files(self._io, tasks) + materialize_start = time.perf_counter() + tasks_list = list(tasks) # Force materialization for delete file collection + materialize_end = time.perf_counter() + + delete_read_start = time.perf_counter() + deletes_per_file = _read_all_delete_files(self._io, tasks_list) + delete_read_end = time.perf_counter() + + logger.info( + 
"[SCAN TIMING] to_record_batches setup: (task_materialize: %.4fs, delete_read: %.4fs, tasks: %d, delete_files: %d)", + materialize_end - materialize_start, + delete_read_end - delete_read_start, + len(tasks_list), + len(deletes_per_file), + ) total_row_count = 0 executor = ExecutorFactory.get_or_create() @@ -1785,7 +1856,7 @@ def batches_for_task(task: FileScanTask) -> list[pa.RecordBatch]: return list(self._record_batches_from_scan_tasks_and_deletes([task], deletes_per_file)) limit_reached = False - for batches in executor.map(batches_for_task, tasks): + for batches in executor.map(batches_for_task, tasks_list): for batch in batches: current_batch_size = len(batch) if self._limit is not None and total_row_count + current_batch_size >= self._limit: diff --git a/pyiceberg/table/__init__.py b/pyiceberg/table/__init__.py index b30a1426e7..c68acc1097 100644 --- a/pyiceberg/table/__init__.py +++ b/pyiceberg/table/__init__.py @@ -17,7 +17,9 @@ from __future__ import annotations import itertools +import logging import os +import time import uuid import warnings from abc import ABC, abstractmethod @@ -154,6 +156,8 @@ ALWAYS_TRUE = AlwaysTrue() DOWNCAST_NS_TIMESTAMP_TO_US_ON_WRITE = "downcast-ns-timestamp-to-us-on-write" +logger = logging.getLogger(__name__) + @dataclass() class UpsertResult: @@ -521,6 +525,7 @@ def append(self, df: pa.Table, snapshot_properties: dict[str, str] = EMPTY_DICT, with self._append_snapshot_producer(snapshot_properties, branch=branch) as append_files: # skip writing data files if the dataframe is empty if df.shape[0] > 0: + write_start = time.perf_counter() data_files = list( _dataframe_to_data_files( table_metadata=self.table_metadata, write_uuid=append_files.commit_uuid, df=df, io=self._table.io @@ -528,6 +533,8 @@ def append(self, df: pa.Table, snapshot_properties: dict[str, str] = EMPTY_DICT, ) for data_file in data_files: append_files.append_data_file(data_file) + write_end = time.perf_counter() + logger.info(f"Append data file writing: 
{write_end - write_start:.3f}s ({df.shape[0]} rows)") def dynamic_partition_overwrite( self, df: pa.Table, snapshot_properties: dict[str, str] = EMPTY_DICT, branch: str | None = MAIN_BRANCH @@ -636,21 +643,27 @@ def overwrite( if overwrite_filter != AlwaysFalse(): # Only delete when the filter is != AlwaysFalse + delete_start = time.perf_counter() self.delete( delete_filter=overwrite_filter, case_sensitive=case_sensitive, snapshot_properties=snapshot_properties, branch=branch, ) + delete_end = time.perf_counter() + logger.info(f"Overwrite delete operation: {delete_end - delete_start:.3f}s") with self._append_snapshot_producer(snapshot_properties, branch=branch) as append_files: # skip writing data files if the dataframe is empty if df.shape[0] > 0: + write_start = time.perf_counter() data_files = _dataframe_to_data_files( table_metadata=self.table_metadata, write_uuid=append_files.commit_uuid, df=df, io=self._table.io ) for data_file in data_files: append_files.append_data_file(data_file) + write_end = time.perf_counter() + logger.info(f"Overwrite data file writing: {write_end - write_start:.3f}s ({df.shape[0]} rows)") def delete( self, @@ -799,12 +812,13 @@ def upsert( Returns: An UpsertResult class (contains details of rows updated and inserted) """ + upsert_start = time.perf_counter() + try: import pyarrow as pa # noqa: F401 except ModuleNotFoundError as e: raise ModuleNotFoundError("For writes PyArrow needs to be installed") from e - from pyiceberg.io.pyarrow import expression_to_pyarrow from pyiceberg.table import upsert_util if join_cols is None: @@ -835,8 +849,15 @@ def upsert( format_version=self.table_metadata.format_version, ) + setup_end = time.perf_counter() + logger.info(f"Upsert setup (join cols, validation): {setup_end - upsert_start:.3f}s") + # get list of rows that exist so we don't have to load the entire target table - matched_predicate = upsert_util.create_match_filter(df, join_cols) + # Use coarse filter for initial scan - exact matching 
happens in get_rows_to_update() + matched_predicate = upsert_util.create_coarse_match_filter(df, join_cols) + + coarse_filter_end = time.perf_counter() + logger.info(f"Coarse match filter creation: {coarse_filter_end - setup_end:.3f}s") # We must use Transaction.table_metadata for the scan. This includes all uncommitted - but relevant - changes. @@ -850,20 +871,30 @@ def upsert( if branch in self.table_metadata.refs: matched_iceberg_record_batches_scan = matched_iceberg_record_batches_scan.use_ref(branch) + scan_start = time.perf_counter() matched_iceberg_record_batches = matched_iceberg_record_batches_scan.to_arrow_batch_reader() + scan_end = time.perf_counter() + logger.info(f"Scan setup (to_arrow_batch_reader): {scan_end - scan_start:.3f}s") batches_to_overwrite = [] overwrite_predicates = [] - rows_to_insert = df + matched_target_keys: list[pa.Table] = [] # Accumulate matched keys for insert filtering + + batch_loop_start = time.perf_counter() + batch_count = 0 + total_rows_to_update_time = 0.0 for batch in matched_iceberg_record_batches: + batch_count += 1 rows = pa.Table.from_batches([batch]) if when_matched_update_all: # function get_rows_to_update is doing a check on non-key columns to see if any of the values have actually changed # we don't want to do just a blanket overwrite for matched rows if the actual non-key column data hasn't changed # this extra step avoids unnecessary IO and writes + rows_to_update_start = time.perf_counter() rows_to_update = upsert_util.get_rows_to_update(df, rows, join_cols) + total_rows_to_update_time += time.perf_counter() - rows_to_update_start if len(rows_to_update) > 0: # build the match predicate filter @@ -872,13 +903,34 @@ def upsert( batches_to_overwrite.append(rows_to_update) overwrite_predicates.append(overwrite_mask_predicate) + # Collect matched keys for insert filtering (will use anti-join after loop) if when_not_matched_insert_all: - expr_match = upsert_util.create_match_filter(rows, join_cols) - expr_match_bound 
= bind(self.table_metadata.schema(), expr_match, case_sensitive=case_sensitive) - expr_match_arrow = expression_to_pyarrow(expr_match_bound) + matched_target_keys.append(rows.select(join_cols)) - # Filter rows per batch. - rows_to_insert = rows_to_insert.filter(~expr_match_arrow) + batch_loop_end = time.perf_counter() + logger.info( + f"Batch processing: {batch_loop_end - batch_loop_start:.3f}s " + f"({batch_count} batches, get_rows_to_update total: {total_rows_to_update_time:.3f}s)" + ) + + # Use anti-join to find rows to insert (replaces per-batch expression filtering) + rows_to_insert = df + if when_not_matched_insert_all and matched_target_keys: + filter_start = time.perf_counter() + # Combine all matched keys and deduplicate + combined_matched_keys = pa.concat_tables(matched_target_keys).group_by(join_cols).aggregate([]) + # Cast matched keys to source schema types for join compatibility + source_key_schema = df.select(join_cols).schema + combined_matched_keys = combined_matched_keys.cast(source_key_schema) + # Use anti-join on key columns only (with row indices) to avoid issues with + # struct/list types in non-key columns that PyArrow join doesn't support + row_indices = pa.chunked_array([pa.array(range(len(df)), type=pa.int64())]) + source_keys_with_idx = df.select(join_cols).append_column("__row_idx__", row_indices) + not_matched_keys = source_keys_with_idx.join(combined_matched_keys, keys=join_cols, join_type="left anti") + indices_to_keep = not_matched_keys.column("__row_idx__").combine_chunks() + rows_to_insert = df.take(indices_to_keep) + filter_end = time.perf_counter() + logger.info(f"Insert filtering (anti-join): {filter_end - filter_start:.3f}s ({len(combined_matched_keys)} matched keys)") update_row_cnt = 0 insert_row_cnt = 0 @@ -886,17 +938,26 @@ def upsert( if batches_to_overwrite: rows_to_update = pa.concat_tables(batches_to_overwrite) update_row_cnt = len(rows_to_update) + overwrite_start = time.perf_counter() self.overwrite( rows_to_update, 
overwrite_filter=Or(*overwrite_predicates) if len(overwrite_predicates) > 1 else overwrite_predicates[0], branch=branch, snapshot_properties=snapshot_properties, ) + overwrite_end = time.perf_counter() + logger.info(f"Overwrite: {overwrite_end - overwrite_start:.3f}s ({update_row_cnt} rows)") if when_not_matched_insert_all: insert_row_cnt = len(rows_to_insert) if rows_to_insert: + append_start = time.perf_counter() self.append(rows_to_insert, branch=branch, snapshot_properties=snapshot_properties) + append_end = time.perf_counter() + logger.info(f"Append: {append_end - append_start:.3f}s ({insert_row_cnt} rows)") + + upsert_end = time.perf_counter() + logger.info(f"Total upsert: {upsert_end - upsert_start:.3f}s (updated: {update_row_cnt}, inserted: {insert_row_cnt})") return UpsertResult(rows_updated=update_row_cnt, rows_inserted=insert_row_cnt) @@ -2119,11 +2180,14 @@ def _plan_files_server_side(self) -> Iterable[FileScanTask]: def _plan_files_local(self) -> Iterable[FileScanTask]: """Plan files locally by reading manifests.""" + plan_start = time.perf_counter() + data_entries: list[ManifestEntry] = [] positional_delete_entries = SortedList(key=lambda entry: entry.sequence_number or INITIAL_SEQUENCE_NUMBER) residual_evaluators: dict[int, Callable[[DataFile], ResidualEvaluator]] = KeyDefaultDict(self._build_residual_evaluator) + manifest_scan_start = time.perf_counter() for manifest_entry in chain.from_iterable(self.scan_plan_helper()): data_file = manifest_entry.data_file if data_file.content == DataFileContent.DATA: @@ -2134,20 +2198,41 @@ def _plan_files_local(self) -> Iterable[FileScanTask]: raise ValueError("PyIceberg does not yet support equality deletes: https://github.com/apache/iceberg/issues/6568") else: raise ValueError(f"Unknown DataFileContent ({data_file.content}): {manifest_entry}") - - return [ - FileScanTask( - data_entry.data_file, - delete_files=_match_deletes_to_data_file( - data_entry, - positional_delete_entries, - ), - 
residual=residual_evaluators[data_entry.data_file.spec_id](data_entry.data_file).residual_for( - data_entry.data_file.partition - ), + manifest_scan_end = time.perf_counter() + + task_creation_start = time.perf_counter() + tasks = [] + total_delete_match_time = 0.0 + total_residual_time = 0.0 + + for data_entry in data_entries: + delete_match_start = time.perf_counter() + delete_files = _match_deletes_to_data_file(data_entry, positional_delete_entries) + delete_match_end = time.perf_counter() + total_delete_match_time += delete_match_end - delete_match_start + + residual_start = time.perf_counter() + residual = residual_evaluators[data_entry.data_file.spec_id](data_entry.data_file).residual_for( + data_entry.data_file.partition ) - for data_entry in data_entries - ] + residual_end = time.perf_counter() + total_residual_time += residual_end - residual_start + + tasks.append(FileScanTask(data_entry.data_file, delete_files=delete_files, residual=residual)) + + task_creation_end = time.perf_counter() + + logger.info( + "[SCAN TIMING] _plan_files_local: %.4fs (manifest_scan: %.4fs, task_creation: %.4fs [delete_match: %.4fs, residual: %.4fs], data_files: %d, delete_files: %d)", + time.perf_counter() - plan_start, + manifest_scan_end - manifest_scan_start, + task_creation_end - task_creation_start, + total_delete_match_time, + total_residual_time, + len(data_entries), + len(positional_delete_entries), + ) + return tasks def plan_files(self) -> Iterable[FileScanTask]: """Plans the relevant files by filtering on the PartitionSpecs. 
@@ -2192,10 +2277,48 @@ def to_arrow_batch_reader(self) -> pa.RecordBatchReader: from pyiceberg.io.pyarrow import ArrowScan, schema_to_pyarrow + def count_expression_nodes(expr: BooleanExpression) -> tuple[int, int, int]: + """Count (total_nodes, or_count, and_count) in expression tree.""" + if isinstance(expr, Or): + left_total, left_or, left_and = count_expression_nodes(expr.left) + right_total, right_or, right_and = count_expression_nodes(expr.right) + return (left_total + right_total + 1, left_or + right_or + 1, left_and + right_and) + elif isinstance(expr, And): + left_total, left_or, left_and = count_expression_nodes(expr.left) + right_total, right_or, right_and = count_expression_nodes(expr.right) + return (left_total + right_total + 1, left_or + right_or, left_and + right_and + 1) + else: + return (1, 0, 0) + + func_start = time.perf_counter() + target_schema = schema_to_pyarrow(self.projection()) + + # Log filter complexity + total_nodes, or_count, and_count = count_expression_nodes(self.row_filter) + logger.info( + "[SCAN TIMING] row_filter complexity: (total_nodes: %d, or_count: %d, and_count: %d)", + total_nodes, + or_count, + and_count, + ) + + plan_start = time.perf_counter() + file_tasks = self.plan_files() + plan_end = time.perf_counter() + + scan_start = time.perf_counter() batches = ArrowScan( self.table_metadata, self.io, self.projection(), self.row_filter, self.case_sensitive, self.limit - ).to_record_batches(self.plan_files()) + ).to_record_batches(file_tasks) + scan_end = time.perf_counter() + + logger.info( + "[SCAN TIMING] to_arrow_batch_reader: (plan_files: %.4fs, scan_setup: %.4fs, total_setup: %.4fs)", + plan_end - plan_start, + scan_end - scan_start, + time.perf_counter() - func_start, + ) return pa.RecordBatchReader.from_batches( target_schema, diff --git a/pyiceberg/table/upsert_util.py b/pyiceberg/table/upsert_util.py index 6f32826eb0..a2762d2e27 100644 --- a/pyiceberg/table/upsert_util.py +++ b/pyiceberg/table/upsert_util.py @@ -15,7 
+15,11 @@ # specific language governing permissions and limitations # under the License. import functools +import logging import operator +import time + +logger = logging.getLogger(__name__) import pyarrow as pa from pyarrow import Table as pyarrow_table @@ -23,16 +27,34 @@ from pyiceberg.expressions import ( AlwaysFalse, + AlwaysTrue, + And, BooleanExpression, EqualTo, + GreaterThanOrEqual, In, + LessThanOrEqual, Or, ) +# Threshold for switching from In() predicate to range-based or no filter +# When unique keys exceed this, the In() predicate becomes too expensive to process +LARGE_FILTER_THRESHOLD = 10_000 + def create_match_filter(df: pyarrow_table, join_cols: list[str]) -> BooleanExpression: + """ + Create an Iceberg BooleanExpression filter that exactly matches rows based on join columns. + + For single-column keys, uses an efficient In() predicate. + For composite keys, creates Or(And(...), And(...), ...) for exact row matching. + This function should be used when exact matching is required (e.g., overwrite, insert filtering). 
# Minimum share of the [min, max] interval that the keys must occupy before a
# min/max range predicate is preferred over the fallback strategy.
_DENSE_KEY_RATIO = 0.1


def _is_numeric_type(arrow_type: pa.DataType) -> bool:
    """Return True if ``arrow_type`` is an integer or floating-point Arrow type."""
    return pa.types.is_integer(arrow_type) or pa.types.is_floating(arrow_type)


def _range_stats(values: pa.Array | pa.ChunkedArray, num_values: int) -> tuple:
    """Return ``(min, max, density)`` for ``values`` using a single min/max pass.

    ``density`` is ``num_values / (max - min + 1)``. It is 0 when no bounds could be
    computed (``min``/``max`` come back as None) or when the span is not positive.
    """
    bounds = pc.min_max(values).as_py()
    min_val, max_val = bounds["min"], bounds["max"]
    if min_val is None or max_val is None:
        return min_val, max_val, 0
    value_range = max_val - min_val + 1
    density = num_values / value_range if value_range > 0 else 0
    return min_val, max_val, density


def _range_filter_from_bounds(col_name: str, min_val: object, max_val: object) -> BooleanExpression:
    """Build ``min_val <= col_name <= max_val`` from already-computed bounds."""
    return And(GreaterThanOrEqual(col_name, min_val), LessThanOrEqual(col_name, max_val))


def _create_range_filter(col_name: str, values: pa.Array) -> BooleanExpression:
    """Create a min/max range filter for a numeric column.

    Args:
        col_name: Name of the column the predicate refers to.
        values: Key values whose minimum and maximum bound the range.

    Returns:
        ``And(col >= min, col <= max)``, or ``AlwaysTrue()`` when no bounds can be
        derived from ``values`` (min/max evaluate to None), because a coarse filter
        must never exclude candidate rows.
    """
    min_val, max_val, _ = _range_stats(values, len(values))
    if min_val is None or max_val is None:
        return AlwaysTrue()
    return _range_filter_from_bounds(col_name, min_val, max_val)


def create_coarse_match_filter(df: pyarrow_table, join_cols: list[str]) -> BooleanExpression:
    """
    Create a coarse Iceberg BooleanExpression filter for initial row scanning.

    The returned predicate matches a superset of the rows identified by the join keys
    in ``df``; it should only be used where exact matching happens downstream
    (e.g., in get_rows_to_update() via the join operation).

    For small datasets (< LARGE_FILTER_THRESHOLD unique keys):
    - Single-column keys: uses In() predicate
    - Composite keys: uses AND of In() predicates per column

    For large datasets (>= LARGE_FILTER_THRESHOLD unique keys):
    - Single numeric column with dense values: uses min/max range filter
    - Single column otherwise: returns AlwaysTrue() to skip filtering (full scan)
    - Composite keys: AND of per-column predicates, where a numeric column with many
      dense values gets a range filter and every other column gets In()

    Args:
        df: Source table containing the join key columns.
        join_cols: Names of the columns that form the join key.

    Returns:
        AlwaysFalse() when ``df`` has no keys, otherwise the coarse predicate.
    """
    unique_keys = df.select(join_cols).group_by(join_cols).aggregate([])
    num_unique_keys = len(unique_keys)

    if num_unique_keys == 0:
        return AlwaysFalse()

    # For small datasets, use the standard In() approach
    if num_unique_keys < LARGE_FILTER_THRESHOLD:
        if len(join_cols) == 1:
            return In(join_cols[0], unique_keys[0].to_pylist())
        return functools.reduce(
            operator.and_,
            [In(col, pc.unique(unique_keys[col]).to_pylist()) for col in join_cols],
        )

    # Lazy %-style arguments: the message is only rendered if the record is emitted.
    logger.info(
        "Large dataset detected (%d unique keys >= %d threshold), using optimized filter strategy",
        num_unique_keys,
        LARGE_FILTER_THRESHOLD,
    )

    if len(join_cols) == 1:
        col_name = join_cols[0]
        col_data = unique_keys[col_name]

        if not _is_numeric_type(col_data.type):
            # Non-numeric single column with many values - skip filter
            logger.info(
                "Skipping filter for non-numeric column '%s' with %d values: full scan will be performed",
                col_name,
                num_unique_keys,
            )
            return AlwaysTrue()

        min_val, max_val, density = _range_stats(col_data, num_unique_keys)
        if density > _DENSE_KEY_RATIO:
            logger.info(
                "Using range filter for column '%s': min=%s, max=%s, density=%.2f%%",
                col_name,
                min_val,
                max_val,
                density * 100,
            )
            return _range_filter_from_bounds(col_name, min_val, max_val)

        # A range over sparse keys would match mostly irrelevant rows, so no filter is applied.
        logger.info(
            "Skipping filter (sparse IDs, density=%.2f%%): full scan will be performed",
            density * 100,
        )
        return AlwaysTrue()

    # Composite keys: one predicate per column, combined with AND.
    # join_cols has at least two entries here, so column_filters is never empty.
    column_filters = []
    for col in join_cols:
        unique_values = pc.unique(unique_keys[col])

        if _is_numeric_type(unique_values.type) and len(unique_values) >= LARGE_FILTER_THRESHOLD:
            min_val, max_val, density = _range_stats(unique_values, len(unique_values))
            if density > _DENSE_KEY_RATIO:
                logger.info(
                    "Using range filter for composite key column '%s': density=%.2f%%",
                    col,
                    density * 100,
                )
                column_filters.append(_range_filter_from_bounds(col, min_val, max_val))
                continue

        # Small, non-numeric, or sparse numeric column - use In()
        column_filters.append(In(col, unique_values.to_pylist()))

    return functools.reduce(operator.and_, column_filters)


def has_duplicate_rows(df: pyarrow_table, join_cols: list[str]) -> bool:
    """Check for duplicate rows in a PyArrow table based on the join columns."""
    key_counts = df.select(join_cols).group_by(join_cols).aggregate([([], "count_all")])
    return len(key_counts.filter(pc.field("count_all") > 1)) > 0


def _compare_columns_vectorized(
    source_col: pa.Array | pa.ChunkedArray, target_col: pa.Array | pa.ChunkedArray
) -> pa.Array | pa.ChunkedArray:
    """
    Compare two columns element-wise, returning a boolean array where True means the values differ.

    Null handling:
    - null vs non-null -> True (different)
    - null vs null     -> False (same, no update needed)

    Struct columns are compared recursively, field by field (by position), and also
    differ when exactly one of the two struct values is null. List and map columns
    are compared in Python after a single bulk ``to_pylist()`` conversion.

    Args:
        source_col: Column taken from the source rows.
        target_col: Column taken from the matching target rows; must have the same length.

    Returns:
        A boolean array without nulls, one entry per row.
    """
    col_type = source_col.type

    if (
        pa.types.is_list(col_type)
        or pa.types.is_large_list(col_type)
        or pa.types.is_fixed_size_list(col_type)
        or pa.types.is_map(col_type)
    ):
        # No vectorized comparison is used for list/map values; convert each side
        # once and compare the resulting Python objects.
        pairs = zip(source_col.to_pylist(), target_col.to_pylist(), strict=True)
        return pa.array([src != tgt for src, tgt in pairs], type=pa.bool_())

    # True where exactly one side is null.
    null_diff = pc.xor(pc.is_null(source_col), pc.is_null(target_col))

    if pa.types.is_struct(col_type):
        field_diffs = [
            _compare_columns_vectorized(pc.struct_field(source_col, [i]), pc.struct_field(target_col, [i]))
            for i in range(col_type.num_fields)
        ]
        # Seeding with null_diff also covers the empty-struct case, where only
        # struct-level null differences matter.
        return functools.reduce(pc.or_, field_diffs, null_diff)

    # not_equal yields null when either operand is null; those positions are decided
    # by null_diff, so fill them with False before combining.
    value_diff = pc.fill_null(pc.not_equal(source_col, target_col), False)
    return pc.or_(value_diff, null_diff)
Cannot compare columns with complex types - # See: https://github.com/apache/arrow/issues/35785 SOURCE_INDEX_COLUMN_NAME = "__source_index" TARGET_INDEX_COLUMN_NAME = "__target_index" @@ -88,6 +285,7 @@ def get_rows_to_update(source_table: pa.Table, target_table: pa.Table, join_cols # Step 1: Prepare source index with join keys and a marker index # Cast to target table schema, so we can do the join # See: https://github.com/apache/arrow/issues/37542 + index_start = time.perf_counter() source_index = ( source_table.cast(target_table.schema) .select(join_cols_set) @@ -96,29 +294,58 @@ def get_rows_to_update(source_table: pa.Table, target_table: pa.Table, join_cols # Step 2: Prepare target index with join keys and a marker target_index = target_table.select(join_cols_set).append_column(TARGET_INDEX_COLUMN_NAME, pa.array(range(len(target_table)))) + index_end = time.perf_counter() # Step 3: Perform an inner join to find which rows from source exist in target + join_start = time.perf_counter() matching_indices = source_index.join(target_index, keys=list(join_cols_set), join_type="inner") + join_end = time.perf_counter() + + if len(matching_indices) == 0: + # No matching rows found + logger.info( + f"get_rows_to_update: {time.perf_counter() - func_start:.3f}s " + f"(index prep: {index_end - index_start:.3f}s, join: {join_end - join_start:.3f}s, " + f"matched: 0, to_update: 0)" + ) + return source_table.schema.empty_table() + + # Step 4: Take matched rows in batch (vectorized - single operation) + take_start = time.perf_counter() + source_indices = matching_indices[SOURCE_INDEX_COLUMN_NAME] + target_indices = matching_indices[TARGET_INDEX_COLUMN_NAME] + + matched_source = source_table.take(source_indices) + matched_target = target_table.take(target_indices) + take_end = time.perf_counter() + + # Step 5: Vectorized comparison per column + compare_start = time.perf_counter() + diff_masks = [] + for col in non_key_cols: + source_col = matched_source.column(col) + target_col 
= matched_target.column(col) + col_diff = _compare_columns_vectorized(source_col, target_col) + diff_masks.append(col_diff) + + # Step 6: Combine masks with OR (any column different = needs update) + combined_mask = functools.reduce(pc.or_, diff_masks) + compare_end = time.perf_counter() + + # Step 7: Filter to get indices of rows that need updating + to_update_indices = pc.filter(source_indices, combined_mask) + + func_end = time.perf_counter() + logger.info( + f"get_rows_to_update: {func_end - func_start:.3f}s " + f"(index prep: {index_end - index_start:.3f}s, " + f"join: {join_end - join_start:.3f}s, " + f"take: {take_end - take_start:.3f}s, " + f"compare: {compare_end - compare_start:.3f}s, " + f"matched: {len(matching_indices)}, to_update: {len(to_update_indices)})" + ) - # Step 4: Compare all rows using Python - to_update_indices = [] - for source_idx, target_idx in zip( - matching_indices[SOURCE_INDEX_COLUMN_NAME].to_pylist(), - matching_indices[TARGET_INDEX_COLUMN_NAME].to_pylist(), - strict=True, - ): - source_row = source_table.slice(source_idx, 1) - target_row = target_table.slice(target_idx, 1) - - for key in non_key_cols: - source_val = source_row.column(key)[0].as_py() - target_val = target_row.column(key)[0].as_py() - if source_val != target_val: - to_update_indices.append(source_idx) - break - - # Step 5: Take rows from source table using the indices and cast to target schema - if to_update_indices: + if len(to_update_indices) > 0: return source_table.take(to_update_indices) else: return source_table.schema.empty_table() diff --git a/tests/table/test_upsert.py b/tests/table/test_upsert.py index 9bc61799e4..69c258590c 100644 --- a/tests/table/test_upsert.py +++ b/tests/table/test_upsert.py @@ -885,3 +885,756 @@ def test_upsert_snapshot_properties(catalog: Catalog) -> None: for snapshot in snapshots[initial_snapshot_count:]: assert snapshot.summary is not None assert snapshot.summary.additional_properties.get("test_prop") == "test_value" + + +def 
test_coarse_match_filter_composite_key() -> None: + """ + Test that create_coarse_match_filter produces efficient In() predicates for composite keys. + """ + from pyiceberg.table.upsert_util import create_coarse_match_filter, create_match_filter + from pyiceberg.expressions import Or, And, In + + # Create a table with composite key that has overlapping values + # (1, 'x'), (2, 'y'), (1, 'z') - exact filter should have 3 conditions + # coarse filter should have In(a, [1,2]) AND In(b, ['x','y','z']) + data = [ + {"a": 1, "b": "x", "val": 1}, + {"a": 2, "b": "y", "val": 2}, + {"a": 1, "b": "z", "val": 3}, + ] + schema = pa.schema([pa.field("a", pa.int32()), pa.field("b", pa.string()), pa.field("val", pa.int32())]) + table = pa.Table.from_pylist(data, schema=schema) + + exact_filter = create_match_filter(table, ["a", "b"]) + coarse_filter = create_coarse_match_filter(table, ["a", "b"]) + + # Exact filter is an Or of And conditions + assert isinstance(exact_filter, Or) + + # Coarse filter is an And of In conditions + assert isinstance(coarse_filter, And) + assert "In" in str(coarse_filter) + + +def test_vectorized_comparison_primitives() -> None: + """Test vectorized comparison with primitive types.""" + from pyiceberg.table.upsert_util import _compare_columns_vectorized + + # Test integers + source = pa.array([1, 2, 3, 4]) + target = pa.array([1, 2, 5, 4]) + diff = _compare_columns_vectorized(source, target) + assert diff.to_pylist() == [False, False, True, False] + + # Test strings + source = pa.array(["a", "b", "c"]) + target = pa.array(["a", "x", "c"]) + diff = _compare_columns_vectorized(source, target) + assert diff.to_pylist() == [False, True, False] + + # Test floats + source = pa.array([1.0, 2.5, 3.0]) + target = pa.array([1.0, 2.5, 3.1]) + diff = _compare_columns_vectorized(source, target) + assert diff.to_pylist() == [False, False, True] + + +def test_vectorized_comparison_nulls() -> None: + """Test vectorized comparison handles nulls correctly.""" + from 
pyiceberg.table.upsert_util import _compare_columns_vectorized + + # null vs non-null = different + source = pa.array([1, None, 3]) + target = pa.array([1, 2, 3]) + diff = _compare_columns_vectorized(source, target) + assert diff.to_pylist() == [False, True, False] + + # non-null vs null = different + source = pa.array([1, 2, 3]) + target = pa.array([1, None, 3]) + diff = _compare_columns_vectorized(source, target) + assert diff.to_pylist() == [False, True, False] + + # null vs null = same (no update needed) + source = pa.array([1, None, 3]) + target = pa.array([1, None, 3]) + diff = _compare_columns_vectorized(source, target) + assert diff.to_pylist() == [False, False, False] + + +def test_vectorized_comparison_structs() -> None: + """Test vectorized comparison with nested struct types.""" + from pyiceberg.table.upsert_util import _compare_columns_vectorized + + struct_type = pa.struct([("x", pa.int32()), ("y", pa.string())]) + + # Same structs + source = pa.array([{"x": 1, "y": "a"}, {"x": 2, "y": "b"}], type=struct_type) + target = pa.array([{"x": 1, "y": "a"}, {"x": 2, "y": "b"}], type=struct_type) + diff = _compare_columns_vectorized(source, target) + assert diff.to_pylist() == [False, False] + + # Different struct values + source = pa.array([{"x": 1, "y": "a"}, {"x": 2, "y": "b"}], type=struct_type) + target = pa.array([{"x": 1, "y": "a"}, {"x": 2, "y": "c"}], type=struct_type) + diff = _compare_columns_vectorized(source, target) + assert diff.to_pylist() == [False, True] + + +def test_vectorized_comparison_nested_structs() -> None: + """Test vectorized comparison with deeply nested struct types.""" + from pyiceberg.table.upsert_util import _compare_columns_vectorized + + inner_struct = pa.struct([("val", pa.int32())]) + outer_struct = pa.struct([("inner", inner_struct), ("name", pa.string())]) + + source = pa.array( + [{"inner": {"val": 1}, "name": "a"}, {"inner": {"val": 2}, "name": "b"}], + type=outer_struct, + ) + target = pa.array( + [{"inner": {"val": 
1}, "name": "a"}, {"inner": {"val": 3}, "name": "b"}], + type=outer_struct, + ) + diff = _compare_columns_vectorized(source, target) + assert diff.to_pylist() == [False, True] + + +def test_vectorized_comparison_lists() -> None: + """Test vectorized comparison with list types (falls back to Python comparison).""" + from pyiceberg.table.upsert_util import _compare_columns_vectorized + + list_type = pa.list_(pa.int32()) + + source = pa.array([[1, 2], [3, 4]], type=list_type) + target = pa.array([[1, 2], [3, 5]], type=list_type) + diff = _compare_columns_vectorized(source, target) + assert diff.to_pylist() == [False, True] + + +def test_get_rows_to_update_no_non_key_cols() -> None: + """Test get_rows_to_update when all columns are key columns.""" + from pyiceberg.table.upsert_util import get_rows_to_update + + # All columns are key columns, so no non-key columns to compare + source = pa.Table.from_pydict({"id": [1, 2, 3]}) + target = pa.Table.from_pydict({"id": [1, 2, 3]}) + rows = get_rows_to_update(source, target, ["id"]) + assert len(rows) == 0 + + +def test_upsert_with_list_field(catalog: Catalog) -> None: + """Test upsert with list type as non-key column.""" + from pyiceberg.types import ListType + + identifier = "default.test_upsert_with_list_field" + _drop_table(catalog, identifier) + + schema = Schema( + NestedField(1, "id", IntegerType(), required=True), + NestedField( + 2, + "tags", + ListType(element_id=3, element_type=StringType(), element_required=False), + required=False, + ), + identifier_field_ids=[1], + ) + + tbl = catalog.create_table(identifier, schema=schema) + + arrow_schema = pa.schema( + [ + pa.field("id", pa.int32(), nullable=False), + pa.field("tags", pa.list_(pa.large_string()), nullable=True), + ] + ) + + initial_data = pa.Table.from_pylist( + [ + {"id": 1, "tags": ["a", "b"]}, + {"id": 2, "tags": ["c"]}, + ], + schema=arrow_schema, + ) + tbl.append(initial_data) + + # Update with changed list + update_data = pa.Table.from_pylist( + [ + 
{"id": 1, "tags": ["a", "b"]}, # Same - no update + {"id": 2, "tags": ["c", "d"]}, # Changed - should update + {"id": 3, "tags": ["e"]}, # New - should insert + ], + schema=arrow_schema, + ) + + res = tbl.upsert(update_data, join_cols=["id"]) + assert res.rows_updated == 1 + assert res.rows_inserted == 1 + + +def test_vectorized_comparison_struct_level_nulls() -> None: + """Test vectorized comparison handles struct-level nulls correctly (not just field-level nulls).""" + from pyiceberg.table.upsert_util import _compare_columns_vectorized + + struct_type = pa.struct([("x", pa.int32()), ("y", pa.string())]) + + # null struct vs non-null struct = different + source = pa.array([{"x": 1, "y": "a"}, None, {"x": 3, "y": "c"}], type=struct_type) + target = pa.array([{"x": 1, "y": "a"}, {"x": 2, "y": "b"}, {"x": 3, "y": "c"}], type=struct_type) + diff = _compare_columns_vectorized(source, target) + assert diff.to_pylist() == [False, True, False] + + # non-null struct vs null struct = different + source = pa.array([{"x": 1, "y": "a"}, {"x": 2, "y": "b"}, {"x": 3, "y": "c"}], type=struct_type) + target = pa.array([{"x": 1, "y": "a"}, None, {"x": 3, "y": "c"}], type=struct_type) + diff = _compare_columns_vectorized(source, target) + assert diff.to_pylist() == [False, True, False] + + # null struct vs null struct = same (no update needed) + source = pa.array([{"x": 1, "y": "a"}, None, {"x": 3, "y": "c"}], type=struct_type) + target = pa.array([{"x": 1, "y": "a"}, None, {"x": 3, "y": "c"}], type=struct_type) + diff = _compare_columns_vectorized(source, target) + assert diff.to_pylist() == [False, False, False] + + +def test_vectorized_comparison_empty_struct_with_nulls() -> None: + """Test that empty structs with null values are compared correctly.""" + from pyiceberg.table.upsert_util import _compare_columns_vectorized + + # Empty struct type - edge case where only struct-level null handling matters + empty_struct_type = pa.struct([]) + + # null vs non-null empty struct = 
different + source = pa.array([{}, None, {}], type=empty_struct_type) + target = pa.array([{}, {}, {}], type=empty_struct_type) + diff = _compare_columns_vectorized(source, target) + assert diff.to_pylist() == [False, True, False] + + # null vs null empty struct = same + source = pa.array([None, None], type=empty_struct_type) + target = pa.array([None, None], type=empty_struct_type) + diff = _compare_columns_vectorized(source, target) + assert diff.to_pylist() == [False, False] + + +# ============================================================================ +# Tests for create_coarse_match_filter and _is_numeric_type +# ============================================================================ + + +@pytest.mark.parametrize( + "dtype,expected_numeric", + [ + (pa.int8(), True), + (pa.int16(), True), + (pa.int32(), True), + (pa.int64(), True), + (pa.uint8(), True), + (pa.uint16(), True), + (pa.uint32(), True), + (pa.uint64(), True), + (pa.float16(), True), + (pa.float32(), True), + (pa.float64(), True), + (pa.string(), False), + (pa.binary(), False), + (pa.date32(), False), + (pa.date64(), False), + (pa.timestamp("us"), False), + (pa.timestamp("ns"), False), + (pa.decimal128(10, 2), False), + (pa.decimal256(20, 4), False), + (pa.bool_(), False), + (pa.large_string(), False), + (pa.large_binary(), False), + ], +) +def test_is_numeric_type(dtype: pa.DataType, expected_numeric: bool) -> None: + """Test that _is_numeric_type correctly identifies all numeric types.""" + from pyiceberg.table.upsert_util import _is_numeric_type + + assert _is_numeric_type(dtype) == expected_numeric + + +# ============================================================================ +# Thresholding Tests (Small vs Large Datasets) +# ============================================================================ + + +def test_coarse_match_filter_small_dataset_uses_in_filter() -> None: + """Test that small datasets (< 10,000 unique keys) use In() filter.""" + from pyiceberg.expressions import 
In + + from pyiceberg.table.upsert_util import LARGE_FILTER_THRESHOLD, create_coarse_match_filter + + # Create a dataset with 100 unique keys (well below threshold) + num_keys = 100 + data = {"id": list(range(num_keys)), "value": list(range(num_keys))} + schema = pa.schema([pa.field("id", pa.int64()), pa.field("value", pa.int64())]) + table = pa.Table.from_pydict(data, schema=schema) + + result = create_coarse_match_filter(table, ["id"]) + + assert num_keys < LARGE_FILTER_THRESHOLD + assert isinstance(result, In) + assert result.term.name == "id" + assert len(result.literals) == num_keys + + +def test_coarse_match_filter_threshold_boundary_uses_in_filter() -> None: + """Test that datasets at threshold - 1 (9,999 unique keys) still use In() filter.""" + from pyiceberg.expressions import In + + from pyiceberg.table.upsert_util import LARGE_FILTER_THRESHOLD, create_coarse_match_filter + + # Create a dataset with exactly threshold - 1 unique keys + num_keys = LARGE_FILTER_THRESHOLD - 1 + data = {"id": list(range(num_keys)), "value": list(range(num_keys))} + schema = pa.schema([pa.field("id", pa.int64()), pa.field("value", pa.int64())]) + table = pa.Table.from_pydict(data, schema=schema) + + result = create_coarse_match_filter(table, ["id"]) + + assert isinstance(result, In) + assert result.term.name == "id" + assert len(result.literals) == num_keys + + +def test_coarse_match_filter_above_threshold_uses_optimized_filter() -> None: + """Test that datasets >= 10,000 unique keys use optimized filter strategy.""" + from pyiceberg.expressions import GreaterThanOrEqual, LessThanOrEqual + + from pyiceberg.table.upsert_util import LARGE_FILTER_THRESHOLD, create_coarse_match_filter + + # Create a dense dataset (consecutive IDs) with exactly threshold unique keys + num_keys = LARGE_FILTER_THRESHOLD + data = {"id": list(range(num_keys)), "value": list(range(num_keys))} + schema = pa.schema([pa.field("id", pa.int64()), pa.field("value", pa.int64())]) + table = 
pa.Table.from_pydict(data, schema=schema) + + result = create_coarse_match_filter(table, ["id"]) + + # Dense IDs should use range filter (And of GreaterThanOrEqual and LessThanOrEqual) + assert isinstance(result, And) + assert isinstance(result.left, GreaterThanOrEqual) + assert isinstance(result.right, LessThanOrEqual) + assert result.left.literal.value == 0 + assert result.right.literal.value == num_keys - 1 + + +def test_coarse_match_filter_large_dataset() -> None: + """Test that large datasets (100,000 unique keys) use optimized filter.""" + from pyiceberg.expressions import GreaterThanOrEqual, LessThanOrEqual + + from pyiceberg.table.upsert_util import LARGE_FILTER_THRESHOLD, create_coarse_match_filter + + # Create a dense dataset with 100,000 unique keys + num_keys = 100_000 + data = {"id": list(range(num_keys)), "value": list(range(num_keys))} + schema = pa.schema([pa.field("id", pa.int64()), pa.field("value", pa.int64())]) + table = pa.Table.from_pydict(data, schema=schema) + + result = create_coarse_match_filter(table, ["id"]) + + assert num_keys >= LARGE_FILTER_THRESHOLD + # Dense IDs should use range filter + assert isinstance(result, And) + assert isinstance(result.left, GreaterThanOrEqual) + assert isinstance(result.right, LessThanOrEqual) + + +# ============================================================================ +# Density Calculation Tests +# ============================================================================ + + +def test_coarse_match_filter_dense_ids_use_range_filter() -> None: + """Test that dense IDs (density > 10%) use range filter.""" + from pyiceberg.expressions import GreaterThanOrEqual, LessThanOrEqual + + from pyiceberg.table.upsert_util import LARGE_FILTER_THRESHOLD, create_coarse_match_filter + + # Create dense IDs: all values from 0 to N-1 (100% density) + num_keys = LARGE_FILTER_THRESHOLD + data = {"id": list(range(num_keys)), "value": list(range(num_keys))} + schema = pa.schema([pa.field("id", pa.int64()), 
pa.field("value", pa.int64())]) + table = pa.Table.from_pydict(data, schema=schema) + + result = create_coarse_match_filter(table, ["id"]) + + # Density = 10000 / (9999 - 0 + 1) = 100% + # Should use range filter + assert isinstance(result, And) + assert isinstance(result.left, GreaterThanOrEqual) + assert isinstance(result.right, LessThanOrEqual) + + +def test_coarse_match_filter_moderately_dense_ids_use_range_filter() -> None: + """Test that moderately dense IDs (50% density) use range filter.""" + from pyiceberg.expressions import GreaterThanOrEqual, LessThanOrEqual + + from pyiceberg.table.upsert_util import LARGE_FILTER_THRESHOLD, create_coarse_match_filter + + # Create IDs: 0, 2, 4, 6, ... (every other number) - 50% density + num_keys = LARGE_FILTER_THRESHOLD + data = {"id": list(range(0, num_keys * 2, 2)), "value": list(range(num_keys))} + schema = pa.schema([pa.field("id", pa.int64()), pa.field("value", pa.int64())]) + table = pa.Table.from_pydict(data, schema=schema) + + result = create_coarse_match_filter(table, ["id"]) + + # Density = 10000 / (19998 - 0 + 1) ~= 50% + # Should use range filter since density > 10% + assert isinstance(result, And) + assert isinstance(result.left, GreaterThanOrEqual) + assert isinstance(result.right, LessThanOrEqual) + + +def test_coarse_match_filter_sparse_ids_use_always_true() -> None: + """Test that sparse IDs (density <= 10%) use AlwaysTrue (full scan).""" + from pyiceberg.table.upsert_util import LARGE_FILTER_THRESHOLD, create_coarse_match_filter + + # Create sparse IDs: values spread across a large range + # 10,000 values in range of ~110,000 = ~9% density + num_keys = LARGE_FILTER_THRESHOLD + ids = list(range(0, num_keys * 11, 11)) # 0, 11, 22, 33, ... 
+ data = {"id": ids, "value": list(range(num_keys))} + schema = pa.schema([pa.field("id", pa.int64()), pa.field("value", pa.int64())]) + table = pa.Table.from_pydict(data, schema=schema) + + result = create_coarse_match_filter(table, ["id"]) + + # Density ~= 10000 / ((10000-1)*11 + 1) = 9.09% < 10% + # Should use AlwaysTrue (full scan) + assert isinstance(result, AlwaysTrue) + + +def test_coarse_match_filter_density_boundary_at_10_percent() -> None: + """Test exact 10% boundary density behavior.""" + from pyiceberg.expressions import GreaterThanOrEqual, LessThanOrEqual + + from pyiceberg.table.upsert_util import LARGE_FILTER_THRESHOLD, create_coarse_match_filter + + # Create IDs at exactly ~10% density + # 10,000 values in range of 100,000 = exactly 10% + num_keys = LARGE_FILTER_THRESHOLD + # Generate 10,000 values in range [0, 99999] -> density = 10000/100000 = 10% + # Using every 10th value: 0, 10, 20, ... 99990 + ids = list(range(0, num_keys * 10, 10)) + data = {"id": ids, "value": list(range(num_keys))} + schema = pa.schema([pa.field("id", pa.int64()), pa.field("value", pa.int64())]) + table = pa.Table.from_pydict(data, schema=schema) + + result = create_coarse_match_filter(table, ["id"]) + + # Density = 10000 / ((num_keys-1)*10 + 1) = 10000 / 99991 ~= 10.001% + # Should use range filter since density > 10% (just barely) + assert isinstance(result, And) + assert isinstance(result.left, GreaterThanOrEqual) + assert isinstance(result.right, LessThanOrEqual) + + +def test_coarse_match_filter_very_sparse_ids() -> None: + """Test that very sparse IDs (e.g., 1, 1M, 2M) use AlwaysTrue.""" + from pyiceberg.table.upsert_util import LARGE_FILTER_THRESHOLD, create_coarse_match_filter + + # Create extremely sparse IDs + num_keys = LARGE_FILTER_THRESHOLD + # Values from 0 to (num_keys-1) * 1000, stepping by 1000 + ids = list(range(0, num_keys * 1000, 1000)) + data = {"id": ids, "value": list(range(num_keys))} + schema = pa.schema([pa.field("id", pa.int64()), 
pa.field("value", pa.int64())]) + table = pa.Table.from_pydict(data, schema=schema) + + result = create_coarse_match_filter(table, ["id"]) + + # Density = 10000 / ((10000-1)*1000 + 1) ~= 0.1% + # Should use AlwaysTrue + assert isinstance(result, AlwaysTrue) + + +# ============================================================================ +# Edge Cases +# ============================================================================ + + +def test_coarse_match_filter_empty_dataset_returns_always_false() -> None: + """Test that empty dataset returns AlwaysFalse.""" + from pyiceberg.expressions import AlwaysFalse + + from pyiceberg.table.upsert_util import create_coarse_match_filter + + schema = pa.schema([pa.field("id", pa.int64()), pa.field("value", pa.int64())]) + table = pa.Table.from_pydict({"id": [], "value": []}, schema=schema) + + result = create_coarse_match_filter(table, ["id"]) + + assert isinstance(result, AlwaysFalse) + + +def test_coarse_match_filter_single_value_dataset() -> None: + """Test that single value dataset uses In() or EqualTo() with single value.""" + from pyiceberg.expressions import In + + from pyiceberg.table.upsert_util import create_coarse_match_filter + + schema = pa.schema([pa.field("id", pa.int64()), pa.field("value", pa.int64())]) + table = pa.Table.from_pydict({"id": [42], "value": [1]}, schema=schema) + + result = create_coarse_match_filter(table, ["id"]) + + # PyIceberg may optimize In() with a single value to EqualTo() + if isinstance(result, In): + assert result.term.name == "id" + assert len(result.literals) == 1 + assert result.literals[0].value == 42 + elif isinstance(result, EqualTo): + assert result.term.name == "id" + assert result.literal.value == 42 + else: + pytest.fail(f"Expected In or EqualTo, got {type(result)}") + + +def test_coarse_match_filter_negative_numbers_range() -> None: + """Test that negative number IDs produce correct min/max range.""" + from pyiceberg.expressions import GreaterThanOrEqual, LessThanOrEqual 
+ + from pyiceberg.table.upsert_util import LARGE_FILTER_THRESHOLD, create_coarse_match_filter + + # Create dense negative IDs: -10000 to -1 + num_keys = LARGE_FILTER_THRESHOLD + ids = list(range(-num_keys, 0)) + data = {"id": ids, "value": list(range(num_keys))} + schema = pa.schema([pa.field("id", pa.int64()), pa.field("value", pa.int64())]) + table = pa.Table.from_pydict(data, schema=schema) + + result = create_coarse_match_filter(table, ["id"]) + + # Should use range filter with negative values + assert isinstance(result, And) + assert isinstance(result.left, GreaterThanOrEqual) + assert isinstance(result.right, LessThanOrEqual) + assert result.left.literal.value == -num_keys # min + assert result.right.literal.value == -1 # max + + +def test_coarse_match_filter_mixed_sign_numbers_range() -> None: + """Test that mixed sign IDs (-500 to 500) produce correct range spanning zero.""" + from pyiceberg.expressions import GreaterThanOrEqual, LessThanOrEqual + + from pyiceberg.table.upsert_util import LARGE_FILTER_THRESHOLD, create_coarse_match_filter + + # Create IDs spanning zero: -5000 to 4999 + num_keys = LARGE_FILTER_THRESHOLD + ids = list(range(-num_keys // 2, num_keys // 2)) + data = {"id": ids, "value": list(range(num_keys))} + schema = pa.schema([pa.field("id", pa.int64()), pa.field("value", pa.int64())]) + table = pa.Table.from_pydict(data, schema=schema) + + result = create_coarse_match_filter(table, ["id"]) + + # Should use range filter spanning zero + assert isinstance(result, And) + assert isinstance(result.left, GreaterThanOrEqual) + assert isinstance(result.right, LessThanOrEqual) + assert result.left.literal.value == -num_keys // 2 # min + assert result.right.literal.value == num_keys // 2 - 1 # max + + +def test_coarse_match_filter_float_range_filter() -> None: + """Test that float IDs use range filter correctly.""" + from pyiceberg.expressions import GreaterThanOrEqual, LessThanOrEqual + + from pyiceberg.table.upsert_util import 
LARGE_FILTER_THRESHOLD, create_coarse_match_filter + + # Create dense float IDs + num_keys = LARGE_FILTER_THRESHOLD + ids = [float(i) for i in range(num_keys)] + data = {"id": ids, "value": list(range(num_keys))} + schema = pa.schema([pa.field("id", pa.float64()), pa.field("value", pa.int64())]) + table = pa.Table.from_pydict(data, schema=schema) + + result = create_coarse_match_filter(table, ["id"]) + + # Should use range filter for float column + assert isinstance(result, And) + assert isinstance(result.left, GreaterThanOrEqual) + assert isinstance(result.right, LessThanOrEqual) + assert result.left.literal.value == 0.0 + assert result.right.literal.value == float(num_keys - 1) + + +def test_coarse_match_filter_non_numeric_column_skips_range_filter() -> None: + """Test that non-numeric column with >10k values returns AlwaysTrue.""" + from pyiceberg.table.upsert_util import LARGE_FILTER_THRESHOLD, create_coarse_match_filter + + # Create string IDs (non-numeric) with many unique values + num_keys = LARGE_FILTER_THRESHOLD + ids = [f"id_{i:05d}" for i in range(num_keys)] + data = {"id": ids, "value": list(range(num_keys))} + schema = pa.schema([pa.field("id", pa.string()), pa.field("value", pa.int64())]) + table = pa.Table.from_pydict(data, schema=schema) + + result = create_coarse_match_filter(table, ["id"]) + + # Non-numeric column with large dataset should use AlwaysTrue + assert isinstance(result, AlwaysTrue) + + +# ============================================================================ +# Composite Key Tests +# ============================================================================ + + +def test_coarse_match_filter_composite_key_small_dataset() -> None: + """Test that composite key with small dataset uses And(In(), In()).""" + from pyiceberg.expressions import In + + from pyiceberg.table.upsert_util import create_coarse_match_filter + + # Create a small dataset with composite key + data = { + "a": [1, 2, 3, 1, 2, 3], + "b": ["x", "x", "x", "y", "y", 
"y"], + "value": [1, 2, 3, 4, 5, 6], + } + schema = pa.schema([pa.field("a", pa.int64()), pa.field("b", pa.string()), pa.field("value", pa.int64())]) + table = pa.Table.from_pydict(data, schema=schema) + + result = create_coarse_match_filter(table, ["a", "b"]) + + # Should be And(In(a), In(b)) + assert isinstance(result, And) + # Check that both children are In() filters + assert "In" in str(result) + + +def test_coarse_match_filter_composite_key_large_numeric_column() -> None: + """Test composite key where one column has >10k unique numeric values.""" + from pyiceberg.expressions import GreaterThanOrEqual, In, LessThanOrEqual + + from pyiceberg.table.upsert_util import LARGE_FILTER_THRESHOLD, create_coarse_match_filter + + # Create dataset with one large dense numeric column and one small column + num_keys = LARGE_FILTER_THRESHOLD + data = { + "a": list(range(num_keys)), # 10k unique dense values + "b": ["category_1"] * (num_keys // 2) + ["category_2"] * (num_keys // 2), # 2 unique values + "value": list(range(num_keys)), + } + schema = pa.schema([pa.field("a", pa.int64()), pa.field("b", pa.string()), pa.field("value", pa.int64())]) + table = pa.Table.from_pydict(data, schema=schema) + + result = create_coarse_match_filter(table, ["a", "b"]) + + # Should be And of filters for both columns + assert isinstance(result, And) + # Column 'a' (large, dense, numeric) should use range filter + # Column 'b' (small) should use In() + result_str = str(result) + assert "GreaterThanOrEqual" in result_str or "In" in result_str + + +def test_coarse_match_filter_composite_key_mixed_types() -> None: + """Test composite key with mixed numeric and string columns with large dataset.""" + from pyiceberg.expressions import In + + from pyiceberg.table.upsert_util import LARGE_FILTER_THRESHOLD, create_coarse_match_filter + + # Create dataset with large sparse numeric column and large string column + num_keys = LARGE_FILTER_THRESHOLD + # Sparse numeric IDs + ids = list(range(0, num_keys * 
100, 100)) + # Many unique strings + strings = [f"str_{i}" for i in range(num_keys)] + data = { + "numeric_id": ids, + "string_id": strings, + "value": list(range(num_keys)), + } + schema = pa.schema([pa.field("numeric_id", pa.int64()), pa.field("string_id", pa.string()), pa.field("value", pa.int64())]) + table = pa.Table.from_pydict(data, schema=schema) + + result = create_coarse_match_filter(table, ["numeric_id", "string_id"]) + + # Both columns have large unique values + # numeric_id is sparse (density < 10%), so should use In() + # string_id is non-numeric, so should use In() + assert isinstance(result, And) + + +# ============================================================================ +# Integration Test with Logging Verification +# ============================================================================ + + +def test_coarse_match_filter_logs_warning_for_full_scan(caplog: pytest.LogCaptureFixture) -> None: + """Verify logging when AlwaysTrue is used (indicates full scan will be performed).""" + import logging + + from pyiceberg.table.upsert_util import LARGE_FILTER_THRESHOLD, create_coarse_match_filter + + # Create table with >10k sparse IDs (will trigger AlwaysTrue) + num_keys = LARGE_FILTER_THRESHOLD + # Very sparse IDs + ids = list(range(0, num_keys * 1000, 1000)) + data = {"id": ids, "value": list(range(num_keys))} + schema = pa.schema([pa.field("id", pa.int64()), pa.field("value", pa.int64())]) + table = pa.Table.from_pydict(data, schema=schema) + + with caplog.at_level(logging.INFO, logger="pyiceberg.table.upsert_util"): + result = create_coarse_match_filter(table, ["id"]) + + # Verify AlwaysTrue is returned + assert isinstance(result, AlwaysTrue) + + # Verify log message about skipping filter + assert any("full scan will be performed" in record.message for record in caplog.records) + + +def test_coarse_match_filter_logs_range_filter_usage(caplog: pytest.LogCaptureFixture) -> None: + """Verify logging when range filter is used for dense IDs.""" 
+ import logging + + from pyiceberg.expressions import GreaterThanOrEqual, LessThanOrEqual + + from pyiceberg.table.upsert_util import LARGE_FILTER_THRESHOLD, create_coarse_match_filter + + # Create table with >10k dense IDs (will use range filter) + num_keys = LARGE_FILTER_THRESHOLD + data = {"id": list(range(num_keys)), "value": list(range(num_keys))} + schema = pa.schema([pa.field("id", pa.int64()), pa.field("value", pa.int64())]) + table = pa.Table.from_pydict(data, schema=schema) + + with caplog.at_level(logging.INFO, logger="pyiceberg.table.upsert_util"): + result = create_coarse_match_filter(table, ["id"]) + + # Verify range filter is returned + assert isinstance(result, And) + assert isinstance(result.left, GreaterThanOrEqual) + assert isinstance(result.right, LessThanOrEqual) + + # Verify log message about using range filter + assert any("Using range filter" in record.message for record in caplog.records) + + +def test_coarse_match_filter_logs_large_dataset_detection(caplog: pytest.LogCaptureFixture) -> None: + """Verify logging when large dataset is detected.""" + import logging + + from pyiceberg.table.upsert_util import LARGE_FILTER_THRESHOLD, create_coarse_match_filter + + # Create table with exactly LARGE_FILTER_THRESHOLD keys + num_keys = LARGE_FILTER_THRESHOLD + data = {"id": list(range(num_keys)), "value": list(range(num_keys))} + schema = pa.schema([pa.field("id", pa.int64()), pa.field("value", pa.int64())]) + table = pa.Table.from_pydict(data, schema=schema) + + with caplog.at_level(logging.INFO, logger="pyiceberg.table.upsert_util"): + create_coarse_match_filter(table, ["id"]) + + # Verify log message about large dataset detection + assert any("Large dataset detected" in record.message for record in caplog.records) + assert any(f"{LARGE_FILTER_THRESHOLD}" in record.message for record in caplog.records)