Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions python/versus/comparison/_compute.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,8 +263,8 @@ def key_part(identifier: str) -> str:
{h.sql_literal(identifier)} AS table_name,
{select_by}
FROM
{h.ident(handle_left.name)} AS left_tbl
ANTI JOIN {h.ident(handle_right.name)} AS right_tbl
{h.table_ref(handle_left)} AS left_tbl
ANTI JOIN {h.table_ref(handle_right)} AS right_tbl
ON {condition}
"""

Expand Down
4 changes: 2 additions & 2 deletions python/versus/comparison/_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -509,10 +509,10 @@ def compare(
by_columns = h.normalize_column_list(by, "by", allow_empty=False)
connection_supplied = connection is not None
handles = {
clean_ids[0]: h.register_input_view(
clean_ids[0]: h.build_table_handle(
conn, table_a, clean_ids[0], connection_supplied=connection_supplied
),
clean_ids[1]: h.register_input_view(
clean_ids[1]: h.build_table_handle(
conn, table_b, clean_ids[1], connection_supplied=connection_supplied
),
}
Expand Down
194 changes: 138 additions & 56 deletions python/versus/comparison/_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,20 @@
from ._exceptions import ComparisonError

if TYPE_CHECKING: # pragma: no cover
import pandas
import polars

from ._core import Comparison

try:
from typing import TypeAlias
except ImportError: # pragma: no cover - Python < 3.10
from typing_extensions import TypeAlias

_Input: TypeAlias = Union[
duckdb.DuckDBPyRelation, "pandas.DataFrame", "polars.DataFrame"
]


# --------------- Data structures
@dataclass
Expand All @@ -34,6 +46,9 @@ class _TableHandle:
relation: duckdb.DuckDBPyRelation
columns: List[str]
types: Dict[str, str]
source_sql: str
source_is_identifier: bool
row_count: int

def __getattr__(self, name: str) -> Any:
return getattr(self.relation, name)
Expand Down Expand Up @@ -200,7 +215,7 @@ def assert_unique_by(
{cols},
COUNT(*) AS n
FROM
{ident(handle.name)} AS t
{table_ref(handle)} AS t
GROUP BY
{cols}
HAVING
Expand Down Expand Up @@ -300,87 +315,146 @@ def assert_column_allowed(comparison: "Comparison", column: str, func: str) -> N


# --------------- Input registration and metadata
def register_input_view(
def build_table_handle(
conn: VersusConn,
source: Any,
source: _Input,
label: str,
*,
connection_supplied: bool,
) -> _TableHandle:
name = f"__versus_{label}_{uuid.uuid4().hex}"
display = "relation"
base_name = None
relation_source = False
if isinstance(source, duckdb.DuckDBPyRelation):
relation_source = True
base_name = f"{name}_base"
validate_columns(source.columns, label)
source.to_view(base_name, replace=True)
source_ref = ident(base_name)
source_sql = source.sql_query()
display = getattr(source, "alias", "relation")
elif isinstance(source, str):
assert_relation_connection(conn, source, label, connection_supplied)
try:
columns, types = describe_source(conn, source_sql, is_identifier=False)
except duckdb.Error as exc:
raise_relation_connection_error(label, connection_supplied, exc)
row_count = resolve_row_count(conn, source, source_sql, is_identifier=False)
relation = conn.sql(source_sql)
return _TableHandle(
name=name,
display=display,
relation=relation,
columns=columns,
types=types,
source_sql=source_sql,
source_is_identifier=False,
row_count=row_count,
)
if isinstance(source, str):
raise ComparisonError(
"String inputs are not supported. Pass a DuckDB relation or pandas/polars "
"DataFrame."
)
else:
base_name = f"{name}_base"
source_columns = getattr(source, "columns", None)
if source_columns is not None:
validate_columns(list(source_columns), label)
try:
conn.register(base_name, source)
except Exception as exc:
raise ComparisonError(
"Inputs must be DuckDB relations or pandas/polars DataFrames."
) from exc
source_ref = ident(base_name)
display = type(source).__name__

source_columns = getattr(source, "columns", None)
if source_columns is not None:
validate_columns(list(source_columns), label)
try:
conn.execute(
f"CREATE OR REPLACE TEMP VIEW {ident(name)} AS SELECT * FROM {source_ref}"
)
except duckdb.Error as exc:
if relation_source and base_name is not None and base_name in str(exc):
arg_name = f"table_{label}"
if connection_supplied:
hint = (
f"`{arg_name}` appears to be bound to a different DuckDB "
"connection than the one passed to `compare()`. Pass the same "
"connection that created the relations via `connection=...`."
)
else:
hint = (
f"`{arg_name}` appears to be bound to a non-default DuckDB "
"connection. Pass that connection to `compare()` via "
"`connection=...`."
)
raise ComparisonError(hint) from exc
raise
if base_name is not None:
conn.versus.views.append(base_name)
conn.register(name, source)
except Exception as exc:
raise ComparisonError(
"Inputs must be DuckDB relations or pandas/polars DataFrames."
) from exc
conn.versus.views.append(name)

columns, types = describe_view(conn, name)
source_sql = name
columns, types = describe_source(conn, source_sql, is_identifier=True)
row_count = resolve_row_count(conn, source, source_sql, is_identifier=True)
relation = conn.table(name)
return _TableHandle(
name=name,
display=display,
display=type(source).__name__,
relation=relation,
columns=columns,
types=types,
source_sql=source_sql,
source_is_identifier=True,
row_count=row_count,
)


def describe_view(conn: VersusConn, name: str) -> Tuple[List[str], Dict[str, str]]:
rel = run_sql(conn, f"DESCRIBE SELECT * FROM {ident(name)}")
def describe_source(
    conn: VersusConn,
    source_sql: str,
    *,
    is_identifier: bool,
) -> Tuple[List[str], Dict[str, str]]:
    """Describe the columns produced by a source.

    Runs ``DESCRIBE SELECT * FROM <source>`` on ``conn``, where ``<source>``
    is rendered by ``source_ref_for_sql``.

    Returns:
        A pair ``(columns, types)``: the first field of every DESCRIBE row,
        in result order, and a mapping from that first field to the row's
        second field.
    """
    target = source_ref_for_sql(source_sql, is_identifier)
    described = run_sql(conn, f"DESCRIBE SELECT * FROM {target}").fetchall()
    names: List[str] = []
    type_by_name: Dict[str, str] = {}
    # Single pass over the DESCRIBE rows: field 0 is taken as the column
    # name and field 1 as its type.
    for entry in described:
        names.append(entry[0])
        type_by_name[entry[0]] = entry[1]
    return names, type_by_name


def source_ref_for_sql(source_sql: str, is_identifier: bool) -> str:
    """Render a source so it can be placed after ``FROM`` in a query.

    When ``is_identifier`` is true the text is passed through ``ident``;
    otherwise it is wrapped in parentheses so it is used as a subquery.
    """
    if not is_identifier:
        return f"({source_sql})"
    return ident(source_sql)


def resolve_row_count(
    conn: VersusConn,
    source: _Input,
    source_sql: str,
    *,
    is_identifier: bool,
) -> int:
    """Return the number of rows in ``source``.

    The count reported by ``row_count_from_frame`` is used when there is
    one. Otherwise ``SELECT COUNT(*)`` is run on ``conn`` over the source as
    rendered by ``source_ref_for_sql``.
    """
    known = row_count_from_frame(source)
    if known is None:
        # No count available from the object itself: ask the database.
        target = source_ref_for_sql(source_sql, is_identifier)
        result = run_sql(conn, f"SELECT COUNT(*) FROM {target}").fetchone()
        assert result is not None and isinstance(result[0], int)
        return result[0]
    return known


def row_count_from_frame(source: _Input) -> Optional[int]:
    """Return the row count an in-memory frame reports about itself.

    The frame library is recognised from the module name of ``type(source)``
    rather than with ``isinstance``, so neither pandas nor polars has to be
    imported here.

    Args:
        source: The object whose row count is wanted.

    Returns:
        ``shape[0]`` for objects from a ``pandas*`` module and ``height`` for
        objects from a ``polars*`` module, as an ``int``. ``None`` when the
        object comes from neither, or when it does not expose that attribute
        (for example a lazily evaluated frame), so the caller can determine
        the count another way.
    """
    module = type(source).__module__
    if module.startswith("pandas"):
        # Guarded read: not every object from a pandas module is
        # guaranteed to carry a tuple-valued, non-empty ``shape``.
        shape = getattr(source, "shape", None)
        if isinstance(shape, tuple) and shape:
            return int(shape[0])
        return None
    if module.startswith("polars"):
        # Guarded read: objects without ``height`` yield None instead of
        # raising AttributeError.
        height = getattr(source, "height", None)
        return None if height is None else int(height)
    return None


def raise_relation_connection_error(
    label: str,
    connection_supplied: bool,
    exc: Exception,
) -> None:
    """Raise a ``ComparisonError`` that points at a connection mismatch.

    Args:
        label: Table label; the message names the argument ``table_<label>``.
        connection_supplied: Selects which of the two hints is used.
        exc: The underlying exception, chained as the cause.

    Raises:
        ComparisonError: Always.
    """
    param = f"table_{label}"
    if not connection_supplied:
        raise ComparisonError(
            f"`{param}` appears to be bound to a non-default DuckDB "
            "connection. Pass that connection to `compare()` via "
            "`connection=...`."
        ) from exc
    raise ComparisonError(
        f"`{param}` appears to be bound to a different DuckDB "
        "connection than the one passed to `compare()`. Pass the same "
        "connection that created the relations via `connection=...`."
    ) from exc


def assert_relation_connection(
    conn: VersusConn,
    relation: duckdb.DuckDBPyRelation,
    label: str,
    connection_supplied: bool,
) -> None:
    """Check that ``relation`` can be registered on ``conn``.

    The relation is registered under a throwaway, uuid-suffixed name and
    unregistered again once registration has succeeded. Any exception raised
    by the registration is handed to ``raise_relation_connection_error``
    together with ``label`` and ``connection_supplied``.
    """
    throwaway = f"__versus_probe_{uuid.uuid4().hex}"
    try:
        conn.register(throwaway, relation)
    except Exception as err:
        raise_relation_connection_error(label, connection_supplied, err)
        # Registration failed, so the unregister step below is skipped.
        return
    conn.unregister(throwaway)


# --------------- SQL builder helpers
def ident(name: str) -> str:
escaped = name.replace('"', '""')
Expand All @@ -391,6 +465,12 @@ def col(alias: str, column: str) -> str:
return f"{alias}.{ident(column)}"


def table_ref(handle: _TableHandle) -> str:
    """Render a table handle so it can be placed after ``FROM`` or ``JOIN``.

    A handle flagged ``source_is_identifier`` has its ``source_sql`` passed
    through ``ident``; any other handle has its ``source_sql`` wrapped in
    parentheses so it is used as a subquery.
    """
    sql_text = handle.source_sql
    return ident(sql_text) if handle.source_is_identifier else f"({sql_text})"


def select_cols(columns: Sequence[str], alias: Optional[str] = None) -> str:
if not columns:
raise ComparisonError("Column list must be non-empty")
Expand All @@ -414,8 +494,8 @@ def inputs_join_sql(
) -> str:
join_condition_sql = join_condition(by_columns, "a", "b")
return (
f"{ident(handles[table_id[0]].name)} AS a\n"
f" INNER JOIN {ident(handles[table_id[1]].name)} AS b\n"
f"{table_ref(handles[table_id[0]])} AS a\n"
f" INNER JOIN {table_ref(handles[table_id[1]])} AS b\n"
f" ON {join_condition_sql}"
)

Expand Down Expand Up @@ -484,7 +564,7 @@ def fetch_rows_by_keys(
{select_cols_sql}
FROM
({key_sql}) AS keys
JOIN {ident(comparison._handles[table].name)} AS base
JOIN {table_ref(comparison._handles[table])} AS base
ON {join_condition_sql}
"""
return run_sql(comparison.connection, sql)
Expand Down Expand Up @@ -568,6 +648,8 @@ def build_rows_relation(


def table_count(relation: Union[duckdb.DuckDBPyRelation, _TableHandle]) -> int:
    """Return the number of rows in a relation or table handle.

    A ``_TableHandle`` answers from its stored ``row_count``. Anything else
    is counted with ``count("*")`` and the first field of the first result
    row is returned.
    """
    if not isinstance(relation, _TableHandle):
        counted = relation.count("*").fetchall()
        total = counted[0][0]
        assert isinstance(total, int)
        return total
    return relation.row_count
Expand All @@ -580,10 +662,10 @@ def select_zero_from_table(
) -> duckdb.DuckDBPyRelation:
handle = comparison._handles[table]
if columns is None:
sql = f"SELECT * FROM {ident(handle.name)} LIMIT 0"
sql = f"SELECT * FROM {table_ref(handle)} LIMIT 0"
return run_sql(comparison.connection, sql)
if not columns:
raise ComparisonError("Column list must be non-empty")
select_cols_sql = select_cols(columns)
sql = f"SELECT {select_cols_sql} FROM {ident(handle.name)} LIMIT 0"
sql = f"SELECT {select_cols_sql} FROM {table_ref(handle)} LIMIT 0"
return run_sql(comparison.connection, sql)
4 changes: 2 additions & 2 deletions python/versus/comparison/_slices.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,13 +85,13 @@ def slice_unmatched_both(comparison: "Comparison") -> duckdb.DuckDBPyRelation:

def select_for(table_name: str) -> str:
unmatched_keys_sql = build_unmatched_keys_sql(comparison, table_name)
base_table = comparison._handles[table_name].name
base_table = comparison._handles[table_name]
return f"""
SELECT
{h.sql_literal(table_name)} AS table_name,
{select_cols}
FROM
{h.ident(base_table)} AS base
{h.table_ref(base_table)} AS base
JOIN ({unmatched_keys_sql}) AS keys
ON {join_condition}
"""
Expand Down
8 changes: 4 additions & 4 deletions python/versus/comparison/_value_diffs.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,9 +60,9 @@ def _value_diffs_with_diff_table(
{", ".join(select_cols)}
FROM
({key_sql}) AS keys
JOIN {h.ident(comparison._handles[table_a].name)} AS a
JOIN {h.table_ref(comparison._handles[table_a])} AS a
ON {join_a}
JOIN {h.ident(comparison._handles[table_b].name)} AS b
JOIN {h.table_ref(comparison._handles[table_b])} AS b
ON {join_b}
"""
return h.run_sql(comparison.connection, sql)
Expand Down Expand Up @@ -112,9 +112,9 @@ def stack_value_diffs_sql(
{", ".join(select_parts)}
FROM
({key_sql}) AS keys
JOIN {h.ident(comparison._handles[table_a].name)} AS a
JOIN {h.table_ref(comparison._handles[table_a])} AS a
ON {join_a}
JOIN {h.ident(comparison._handles[table_b].name)} AS b
JOIN {h.table_ref(comparison._handles[table_b])} AS b
ON {join_b}
"""

Expand Down
8 changes: 4 additions & 4 deletions python/versus/comparison/_weave.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,9 +92,9 @@ def _weave_diffs_wide_with_keys(
{", ".join(select_parts)}
FROM
({keys}) AS keys
JOIN {h.ident(comparison._handles[table_a].name)} AS a
JOIN {h.table_ref(comparison._handles[table_a])} AS a
ON {join_a}
JOIN {h.ident(comparison._handles[table_b].name)} AS b
JOIN {h.table_ref(comparison._handles[table_b])} AS b
ON {join_b}
"""
return h.run_sql(comparison.connection, sql)
Expand Down Expand Up @@ -153,7 +153,7 @@ def _weave_diffs_long_with_keys(
{select_cols_a}
FROM
keys
JOIN {h.ident(comparison._handles[table_a].name)} AS a
JOIN {h.table_ref(comparison._handles[table_a])} AS a
ON {join_a}
UNION ALL
SELECT
Expand All @@ -162,7 +162,7 @@ def _weave_diffs_long_with_keys(
{select_cols_b}
FROM
keys
JOIN {h.ident(comparison._handles[table_b].name)} AS b
JOIN {h.table_ref(comparison._handles[table_b])} AS b
ON {join_b}
) AS stacked
ORDER BY
Expand Down