diff --git a/python/versus/comparison/_compute.py b/python/versus/comparison/_compute.py index 49214e0..7efb59c 100644 --- a/python/versus/comparison/_compute.py +++ b/python/versus/comparison/_compute.py @@ -263,8 +263,8 @@ def key_part(identifier: str) -> str: {h.sql_literal(identifier)} AS table_name, {select_by} FROM - {h.ident(handle_left.name)} AS left_tbl - ANTI JOIN {h.ident(handle_right.name)} AS right_tbl + {h.table_ref(handle_left)} AS left_tbl + ANTI JOIN {h.table_ref(handle_right)} AS right_tbl ON {condition} """ diff --git a/python/versus/comparison/_core.py b/python/versus/comparison/_core.py index 3b12f24..4d2d506 100644 --- a/python/versus/comparison/_core.py +++ b/python/versus/comparison/_core.py @@ -509,10 +509,10 @@ def compare( by_columns = h.normalize_column_list(by, "by", allow_empty=False) connection_supplied = connection is not None handles = { - clean_ids[0]: h.register_input_view( + clean_ids[0]: h.build_table_handle( conn, table_a, clean_ids[0], connection_supplied=connection_supplied ), - clean_ids[1]: h.register_input_view( + clean_ids[1]: h.build_table_handle( conn, table_b, clean_ids[1], connection_supplied=connection_supplied ), } diff --git a/python/versus/comparison/_helpers.py b/python/versus/comparison/_helpers.py index 44c8f28..9705ae1 100644 --- a/python/versus/comparison/_helpers.py +++ b/python/versus/comparison/_helpers.py @@ -23,8 +23,20 @@ from ._exceptions import ComparisonError if TYPE_CHECKING: # pragma: no cover + import pandas + import polars + from ._core import Comparison +try: + from typing import TypeAlias +except ImportError: # pragma: no cover - Python < 3.10 + from typing_extensions import TypeAlias + +_Input: TypeAlias = Union[ + duckdb.DuckDBPyRelation, "pandas.DataFrame", "polars.DataFrame" +] + # --------------- Data structures @dataclass @@ -34,6 +46,9 @@ class _TableHandle: relation: duckdb.DuckDBPyRelation columns: List[str] types: Dict[str, str] + source_sql: str + source_is_identifier: bool + row_count: int def __getattr__(self, name: str) -> Any: return getattr(self.relation, name) @@ -200,7 +215,7 @@ def assert_unique_by( {cols}, COUNT(*) AS n FROM - {ident(handle.name)} AS t + {table_ref(handle)} AS t GROUP BY {cols} HAVING @@ -300,87 +315,146 @@ def assert_column_allowed(comparison: "Comparison", column: str, func: str) -> N # --------------- Input registration and metadata -def register_input_view( +def build_table_handle( conn: VersusConn, - source: Any, + source: _Input, label: str, *, connection_supplied: bool, ) -> _TableHandle: name = f"__versus_{label}_{uuid.uuid4().hex}" display = "relation" - base_name = None - relation_source = False if isinstance(source, duckdb.DuckDBPyRelation): - relation_source = True - base_name = f"{name}_base" validate_columns(source.columns, label) - source.to_view(base_name, replace=True) - source_ref = ident(base_name) + source_sql = source.sql_query() display = getattr(source, "alias", "relation") - elif isinstance(source, str): + assert_relation_connection(conn, source, label, connection_supplied) + try: + columns, types = describe_source(conn, source_sql, is_identifier=False) + except duckdb.Error as exc: + raise_relation_connection_error(label, connection_supplied, exc) + row_count = resolve_row_count(conn, source, source_sql, is_identifier=False) + relation = conn.sql(source_sql) + return _TableHandle( + name=name, + display=display, + relation=relation, + columns=columns, + types=types, + source_sql=source_sql, + source_is_identifier=False, + row_count=row_count, + ) + if isinstance(source, str): raise ComparisonError( "String inputs are not supported. Pass a DuckDB relation or pandas/polars " "DataFrame." ) - else: - base_name = f"{name}_base" - source_columns = getattr(source, "columns", None) - if source_columns is not None: - validate_columns(list(source_columns), label) - try: - conn.register(base_name, source) - except Exception as exc: - raise ComparisonError( - "Inputs must be DuckDB relations or pandas/polars DataFrames." - ) from exc - source_ref = ident(base_name) - display = type(source).__name__ - + source_columns = getattr(source, "columns", None) + if source_columns is not None: + validate_columns(list(source_columns), label) try: - conn.execute( - f"CREATE OR REPLACE TEMP VIEW {ident(name)} AS SELECT * FROM {source_ref}" - ) - except duckdb.Error as exc: - if relation_source and base_name is not None and base_name in str(exc): - arg_name = f"table_{label}" - if connection_supplied: - hint = ( - f"`{arg_name}` appears to be bound to a different DuckDB " - "connection than the one passed to `compare()`. Pass the same " - "connection that created the relations via `connection=...`." - ) - else: - hint = ( - f"`{arg_name}` appears to be bound to a non-default DuckDB " - "connection. Pass that connection to `compare()` via " - "`connection=...`." - ) - raise ComparisonError(hint) from exc - raise - if base_name is not None: - conn.versus.views.append(base_name) + conn.register(name, source) + except Exception as exc: + raise ComparisonError( + "Inputs must be DuckDB relations or pandas/polars DataFrames." + ) from exc conn.versus.views.append(name) - - columns, types = describe_view(conn, name) + source_sql = name + columns, types = describe_source(conn, source_sql, is_identifier=True) + row_count = resolve_row_count(conn, source, source_sql, is_identifier=True) relation = conn.table(name) return _TableHandle( name=name, - display=display, + display=type(source).__name__, relation=relation, columns=columns, types=types, + source_sql=source_sql, + source_is_identifier=True, + row_count=row_count, ) -def describe_view(conn: VersusConn, name: str) -> Tuple[List[str], Dict[str, str]]: - rel = run_sql(conn, f"DESCRIBE SELECT * FROM {ident(name)}") +def describe_source( + conn: VersusConn, + source_sql: str, + *, + is_identifier: bool, +) -> Tuple[List[str], Dict[str, str]]: + source_ref = source_ref_for_sql(source_sql, is_identifier) + rel = run_sql(conn, f"DESCRIBE SELECT * FROM {source_ref}") rows = rel.fetchall() columns = [row[0] for row in rows] types = {row[0]: row[1] for row in rows} return columns, types +def source_ref_for_sql(source_sql: str, is_identifier: bool) -> str: + return ident(source_sql) if is_identifier else f"({source_sql})" + + +def resolve_row_count( + conn: VersusConn, + source: _Input, + source_sql: str, + *, + is_identifier: bool, +) -> int: + frame_row_count = row_count_from_frame(source) + if frame_row_count is not None: + return frame_row_count + source_ref = source_ref_for_sql(source_sql, is_identifier) + row = run_sql(conn, f"SELECT COUNT(*) FROM {source_ref}").fetchone() + assert row is not None and isinstance(row[0], int) + return row[0] + + +def row_count_from_frame(source: _Input) -> Optional[int]: + module = type(source).__module__ + if module.startswith("pandas"): + return int(cast("pandas.DataFrame", source).shape[0]) + if module.startswith("polars"): + return int(cast("polars.DataFrame", source).height) + return None + + +def raise_relation_connection_error( + label: str, + connection_supplied: bool, + exc: Exception, +) -> None: + arg_name = f"table_{label}" + if connection_supplied: + hint = ( + f"`{arg_name}` appears to be bound to a different DuckDB " + "connection than the one passed to `compare()`. Pass the same " + "connection that created the relations via `connection=...`." + ) + else: + hint = ( + f"`{arg_name}` appears to be bound to a non-default DuckDB " + "connection. Pass that connection to `compare()` via " + "`connection=...`." + ) + raise ComparisonError(hint) from exc + + +def assert_relation_connection( + conn: VersusConn, + relation: duckdb.DuckDBPyRelation, + label: str, + connection_supplied: bool, +) -> None: + probe_name = f"__versus_probe_{uuid.uuid4().hex}" + try: + conn.register(probe_name, relation) + except Exception as exc: + raise_relation_connection_error(label, connection_supplied, exc) + else: + conn.unregister(probe_name) + + # --------------- SQL builder helpers def ident(name: str) -> str: escaped = name.replace('"', '""') @@ -391,6 +465,12 @@ def col(alias: str, column: str) -> str: return f"{alias}.{ident(column)}" +def table_ref(handle: _TableHandle) -> str: + if handle.source_is_identifier: + return ident(handle.source_sql) + return f"({handle.source_sql})" + + def select_cols(columns: Sequence[str], alias: Optional[str] = None) -> str: if not columns: raise ComparisonError("Column list must be non-empty") @@ -414,8 +494,8 @@ def inputs_join_sql( ) -> str: join_condition_sql = join_condition(by_columns, "a", "b") return ( - f"{ident(handles[table_id[0]].name)} AS a\n" - f" INNER JOIN {ident(handles[table_id[1]].name)} AS b\n" + f"{table_ref(handles[table_id[0]])} AS a\n" + f" INNER JOIN {table_ref(handles[table_id[1]])} AS b\n" f" ON {join_condition_sql}" ) @@ -484,7 +564,7 @@ def fetch_rows_by_keys( {select_cols_sql} FROM ({key_sql}) AS keys - JOIN {ident(comparison._handles[table].name)} AS base + JOIN {table_ref(comparison._handles[table])} AS base ON {join_condition_sql} """ return run_sql(comparison.connection, sql) @@ -568,6 +648,8 @@ def build_rows_relation( def table_count(relation: Union[duckdb.DuckDBPyRelation, _TableHandle]) -> int: + if isinstance(relation, _TableHandle): + return relation.row_count row = relation.count("*").fetchall()[0] assert isinstance(row[0], int) return row[0] @@ -580,10 +662,10 @@ def select_zero_from_table( ) -> duckdb.DuckDBPyRelation: handle = comparison._handles[table] if columns is None: - sql = f"SELECT * FROM {ident(handle.name)} LIMIT 0" + sql = f"SELECT * FROM {table_ref(handle)} LIMIT 0" return run_sql(comparison.connection, sql) if not columns: raise ComparisonError("Column list must be non-empty") select_cols_sql = select_cols(columns) - sql = f"SELECT {select_cols_sql} FROM {ident(handle.name)} LIMIT 0" + sql = f"SELECT {select_cols_sql} FROM {table_ref(handle)} LIMIT 0" return run_sql(comparison.connection, sql) diff --git a/python/versus/comparison/_slices.py b/python/versus/comparison/_slices.py index bf9580c..74e443a 100644 --- a/python/versus/comparison/_slices.py +++ b/python/versus/comparison/_slices.py @@ -85,13 +85,13 @@ def slice_unmatched_both(comparison: "Comparison") -> duckdb.DuckDBPyRelation: def select_for(table_name: str) -> str: unmatched_keys_sql = build_unmatched_keys_sql(comparison, table_name) - base_table = comparison._handles[table_name].name + base_table = comparison._handles[table_name] return f""" SELECT {h.sql_literal(table_name)} AS table_name, {select_cols} FROM - {h.ident(base_table)} AS base + {h.table_ref(base_table)} AS base JOIN ({unmatched_keys_sql}) AS keys ON {join_condition} """ diff --git a/python/versus/comparison/_value_diffs.py b/python/versus/comparison/_value_diffs.py index 61f8f2c..b44756c 100644 --- a/python/versus/comparison/_value_diffs.py +++ b/python/versus/comparison/_value_diffs.py @@ -60,9 +60,9 @@ def _value_diffs_with_diff_table( {", ".join(select_cols)} FROM ({key_sql}) AS keys - JOIN {h.ident(comparison._handles[table_a].name)} AS a + JOIN {h.table_ref(comparison._handles[table_a])} AS a ON {join_a} - JOIN {h.ident(comparison._handles[table_b].name)} AS b + JOIN {h.table_ref(comparison._handles[table_b])} AS b ON {join_b} """ return h.run_sql(comparison.connection, sql) @@ -112,9 +112,9 @@ def stack_value_diffs_sql( {", ".join(select_parts)} FROM ({key_sql}) AS keys - JOIN {h.ident(comparison._handles[table_a].name)} AS a + JOIN {h.table_ref(comparison._handles[table_a])} AS a ON {join_a} - JOIN {h.ident(comparison._handles[table_b].name)} AS b + JOIN {h.table_ref(comparison._handles[table_b])} AS b ON {join_b} """ diff --git a/python/versus/comparison/_weave.py b/python/versus/comparison/_weave.py index 4598106..9135194 100644 --- a/python/versus/comparison/_weave.py +++ b/python/versus/comparison/_weave.py @@ -92,9 +92,9 @@ def _weave_diffs_wide_with_keys( {", ".join(select_parts)} FROM ({keys}) AS keys - JOIN {h.ident(comparison._handles[table_a].name)} AS a + JOIN {h.table_ref(comparison._handles[table_a])} AS a ON {join_a} - JOIN {h.ident(comparison._handles[table_b].name)} AS b + JOIN {h.table_ref(comparison._handles[table_b])} AS b ON {join_b} """ return h.run_sql(comparison.connection, sql) @@ -153,7 +153,7 @@ def _weave_diffs_long_with_keys( {select_cols_a} FROM keys - JOIN {h.ident(comparison._handles[table_a].name)} AS a + JOIN {h.table_ref(comparison._handles[table_a])} AS a ON {join_a} UNION ALL SELECT @@ -162,7 +162,7 @@ def _weave_diffs_long_with_keys( {select_cols_b} FROM keys - JOIN {h.ident(comparison._handles[table_b].name)} AS b + JOIN {h.table_ref(comparison._handles[table_b])} AS b ON {join_b} ) AS stacked ORDER BY