Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion python/versus/comparison/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from ._core import Comparison, compare
from ._exceptions import ComparisonError
from .api import compare
from .comparison import Comparison

__all__ = [
"Comparison",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,32 +4,35 @@

import duckdb

from . import _helpers as h
from . import _relations as r
from . import _sql as q
from . import _summary as s
from . import _types as t


def build_tables_frame(
    conn: t.VersusConn,
    handles: Mapping[str, t._TableHandle],
    table_id: Tuple[str, str],
    materialize: bool,
) -> duckdb.DuckDBPyRelation:
    """Build the per-table summary relation: name, row count, column count.

    Args:
        conn: Connection forwarded unchanged to ``s.build_rows_relation``.
        handles: Table handles keyed by table identifier.
        table_id: The two table identifiers; rows are emitted in this order.
        materialize: Forwarded unchanged to ``s.build_rows_relation``.

    Returns:
        Whatever ``s.build_rows_relation`` produces for one row per entry of
        ``table_id`` with the columns ``table_name`` (VARCHAR), ``nrow``
        (BIGINT) and ``ncol`` (BIGINT).

    Raises:
        KeyError: Presumably, when an identifier in ``table_id`` is absent
            from ``handles`` (standard ``Mapping`` lookup behaviour).
    """

    def row_for(identifier: str) -> Tuple[str, int, int]:
        # One output row: (table name, row count, number of columns).
        handle = handles[identifier]
        return identifier, r.table_count(handle), len(handle.columns)

    rows = [row_for(identifier) for identifier in table_id]
    schema = [
        ("table_name", "VARCHAR"),
        ("nrow", "BIGINT"),
        ("ncol", "BIGINT"),
    ]
    return s.build_rows_relation(conn, rows, schema, materialize)


def build_by_frame(
conn: h.VersusConn,
conn: t.VersusConn,
by_columns: List[str],
handles: Mapping[str, h._TableHandle],
handles: Mapping[str, t._TableHandle],
table_id: Tuple[str, str],
materialize: bool,
) -> duckdb.DuckDBPyRelation:
Expand All @@ -47,12 +50,12 @@ def build_by_frame(
(f"type_{first}", "VARCHAR"),
(f"type_{second}", "VARCHAR"),
]
return h.build_rows_relation(conn, rows, schema, materialize)
return s.build_rows_relation(conn, rows, schema, materialize)


def build_unmatched_cols(
conn: h.VersusConn,
handles: Mapping[str, h._TableHandle],
conn: t.VersusConn,
handles: Mapping[str, t._TableHandle],
table_id: Tuple[str, str],
materialize: bool,
) -> duckdb.DuckDBPyRelation:
Expand All @@ -71,17 +74,17 @@ def build_unmatched_cols(
("column", "VARCHAR"),
("type", "VARCHAR"),
]
return h.build_rows_relation(conn, rows, schema, materialize)
return s.build_rows_relation(conn, rows, schema, materialize)


def build_intersection_frame(
value_columns: List[str],
handles: Mapping[str, h._TableHandle],
handles: Mapping[str, t._TableHandle],
table_id: Tuple[str, str],
by_columns: List[str],
allow_both_na: bool,
diff_table: Optional[duckdb.DuckDBPyRelation],
conn: h.VersusConn,
conn: t.VersusConn,
materialize: bool,
) -> Tuple[duckdb.DuckDBPyRelation, Optional[Dict[str, int]]]:
if diff_table is None:
Expand All @@ -100,7 +103,7 @@ def build_intersection_frame(


def _build_empty_intersection_relation(
conn: h.VersusConn,
conn: t.VersusConn,
table_id: Tuple[str, str],
materialize: bool,
) -> Tuple[duckdb.DuckDBPyRelation, Optional[Dict[str, int]]]:
Expand All @@ -111,16 +114,16 @@ def _build_empty_intersection_relation(
(f"type_{first}", "VARCHAR"),
(f"type_{second}", "VARCHAR"),
]
relation = h.build_rows_relation(conn, [], schema, materialize)
relation = s.build_rows_relation(conn, [], schema, materialize)
return relation, {} if materialize else None


def _build_intersection_frame_with_table(
value_columns: List[str],
handles: Mapping[str, h._TableHandle],
handles: Mapping[str, t._TableHandle],
table_id: Tuple[str, str],
diff_table: duckdb.DuckDBPyRelation,
conn: h.VersusConn,
conn: t.VersusConn,
materialize: bool,
) -> Tuple[duckdb.DuckDBPyRelation, Optional[Dict[str, int]]]:
first, second = table_id
Expand All @@ -131,18 +134,18 @@ def diff_alias(column: str) -> str:
return f"n_diffs_{column}"

count_columns = ",\n ".join(
f"COUNT(*) FILTER (WHERE diffs.{h.ident(column)}) "
f"AS {h.ident(diff_alias(column))}"
f"COUNT(*) FILTER (WHERE diffs.{q.ident(column)}) "
f"AS {q.ident(diff_alias(column))}"
for column in value_columns
)

def select_for(column: str) -> str:
return f"""
SELECT
{h.sql_literal(column)} AS {h.ident("column")},
counts.{h.ident(diff_alias(column))} AS {h.ident("n_diffs")},
{h.sql_literal(handles[first].types[column])} AS {h.ident(f"type_{first}")},
{h.sql_literal(handles[second].types[column])} AS {h.ident(f"type_{second}")}
{q.sql_literal(column)} AS {q.ident("column")},
counts.{q.ident(diff_alias(column))} AS {q.ident("n_diffs")},
{q.sql_literal(handles[first].types[column])} AS {q.ident(f"type_{first}")},
{q.sql_literal(handles[second].types[column])} AS {q.ident(f"type_{second}")}
FROM
counts
"""
Expand All @@ -157,42 +160,42 @@ def select_for(column: str) -> str:
)
{" UNION ALL ".join(select_for(column) for column in value_columns)}
"""
relation = h.finalize_relation(conn, sql, materialize)
relation = s.finalize_relation(conn, sql, materialize)
if not materialize:
return relation, None
return relation, h.diff_lookup_from_intersection(relation)
return relation, r.diff_lookup_from_intersection(relation)


def _build_intersection_frame_inline(
value_columns: List[str],
handles: Mapping[str, h._TableHandle],
handles: Mapping[str, t._TableHandle],
table_id: Tuple[str, str],
by_columns: List[str],
allow_both_na: bool,
conn: h.VersusConn,
conn: t.VersusConn,
materialize: bool,
) -> Tuple[duckdb.DuckDBPyRelation, Optional[Dict[str, int]]]:
if not value_columns:
return _build_empty_intersection_relation(conn, table_id, materialize)
first, second = table_id
join_sql = h.inputs_join_sql(handles, table_id, by_columns)
join_sql = q.inputs_join_sql(handles, table_id, by_columns)

def diff_alias(column: str) -> str:
return f"n_diffs_{column}"

count_columns = ",\n ".join(
f"COUNT(*) FILTER (WHERE {h.diff_predicate(column, allow_both_na, 'a', 'b')}) "
f"AS {h.ident(diff_alias(column))}"
f"COUNT(*) FILTER (WHERE {q.diff_predicate(column, allow_both_na, 'a', 'b')}) "
f"AS {q.ident(diff_alias(column))}"
for column in value_columns
)

def select_for(column: str) -> str:
return f"""
SELECT
{h.sql_literal(column)} AS {h.ident("column")},
counts.{h.ident(diff_alias(column))} AS {h.ident("n_diffs")},
{h.sql_literal(handles[first].types[column])} AS {h.ident(f"type_{first}")},
{h.sql_literal(handles[second].types[column])} AS {h.ident(f"type_{second}")}
{q.sql_literal(column)} AS {q.ident("column")},
counts.{q.ident(diff_alias(column))} AS {q.ident("n_diffs")},
{q.sql_literal(handles[first].types[column])} AS {q.ident(f"type_{first}")},
{q.sql_literal(handles[second].types[column])} AS {q.ident(f"type_{second}")}
FROM
counts
"""
Expand All @@ -206,31 +209,31 @@ def select_for(column: str) -> str:
)
{" UNION ALL ".join(select_for(column) for column in value_columns)}
"""
relation = h.finalize_relation(conn, sql, materialize)
relation = s.finalize_relation(conn, sql, materialize)
if not materialize:
return relation, None
return relation, h.diff_lookup_from_intersection(relation)
return relation, r.diff_lookup_from_intersection(relation)


def compute_diff_table(
conn: h.VersusConn,
handles: Mapping[str, h._TableHandle],
conn: t.VersusConn,
handles: Mapping[str, t._TableHandle],
table_id: Tuple[str, str],
by_columns: List[str],
value_columns: List[str],
allow_both_na: bool,
) -> duckdb.DuckDBPyRelation:
if not value_columns:
schema = [(column, handles[table_id[0]].types[column]) for column in by_columns]
return h.build_rows_relation(conn, [], schema, materialize=True)
join_sql = h.inputs_join_sql(handles, table_id, by_columns)
select_by = h.select_cols(by_columns, alias="a")
return s.build_rows_relation(conn, [], schema, materialize=True)
join_sql = q.inputs_join_sql(handles, table_id, by_columns)
select_by = q.select_cols(by_columns, alias="a")
diff_expressions = [
(column, h.diff_predicate(column, allow_both_na, "a", "b"))
(column, q.diff_predicate(column, allow_both_na, "a", "b"))
for column in value_columns
]
diff_flags = ",\n ".join(
f"{expression} AS {h.ident(column)}" for column, expression in diff_expressions
f"{expression} AS {q.ident(column)}" for column, expression in diff_expressions
)
predicate = " OR ".join(expression for _, expression in diff_expressions)
sql = f"""
Expand All @@ -242,12 +245,12 @@ def compute_diff_table(
WHERE
{predicate}
"""
return h.finalize_relation(conn, sql, materialize=True)
return s.finalize_relation(conn, sql, materialize=True)


def compute_unmatched_keys(
conn: h.VersusConn,
handles: Mapping[str, h._TableHandle],
conn: t.VersusConn,
handles: Mapping[str, t._TableHandle],
table_id: Tuple[str, str],
by_columns: List[str],
materialize: bool,
Expand All @@ -256,33 +259,33 @@ def key_part(identifier: str) -> str:
other = table_id[1] if identifier == table_id[0] else table_id[0]
handle_left = handles[identifier]
handle_right = handles[other]
select_by = h.select_cols(by_columns, alias="left_tbl")
condition = h.join_condition(by_columns, "left_tbl", "right_tbl")
select_by = q.select_cols(by_columns, alias="left_tbl")
condition = q.join_condition(by_columns, "left_tbl", "right_tbl")
return f"""
SELECT
{h.sql_literal(identifier)} AS table_name,
{q.sql_literal(identifier)} AS table_name,
{select_by}
FROM
{h.table_ref(handle_left)} AS left_tbl
ANTI JOIN {h.table_ref(handle_right)} AS right_tbl
{q.table_ref(handle_left)} AS left_tbl
ANTI JOIN {q.table_ref(handle_right)} AS right_tbl
ON {condition}
"""

keys_parts = [key_part(identifier) for identifier in table_id]
unmatched_keys_sql = " UNION ALL ".join(keys_parts)
return h.finalize_relation(conn, unmatched_keys_sql, materialize)
return s.finalize_relation(conn, unmatched_keys_sql, materialize)


def compute_unmatched_rows_summary(
conn: h.VersusConn,
conn: t.VersusConn,
unmatched_keys: duckdb.DuckDBPyRelation,
table_id: Tuple[str, str],
materialize: bool,
) -> Tuple[duckdb.DuckDBPyRelation, Optional[Dict[str, int]]]:
unmatched_keys_sql = unmatched_keys.sql_query()
table_col = h.ident("table_name")
count_col = h.ident("n_unmatched")
base_sql = h.rows_relation_sql(
table_col = q.ident("table_name")
count_col = q.ident("n_unmatched")
base_sql = s.rows_relation_sql(
[(table_id[0],), (table_id[1],)], [("table_name", "VARCHAR")]
)
counts_sql = f"""
Expand All @@ -296,8 +299,8 @@ def compute_unmatched_rows_summary(
"""
order_case = (
f"CASE base.{table_col} "
f"WHEN {h.sql_literal(table_id[0])} THEN 0 "
f"WHEN {h.sql_literal(table_id[1])} THEN 1 "
f"WHEN {q.sql_literal(table_id[0])} THEN 0 "
f"WHEN {q.sql_literal(table_id[1])} THEN 1 "
"ELSE 2 END"
)
sql = f"""
Expand All @@ -311,7 +314,7 @@ def compute_unmatched_rows_summary(
ORDER BY
{order_case}
"""
relation = h.finalize_relation(conn, sql, materialize)
relation = s.finalize_relation(conn, sql, materialize)
if not materialize:
return relation, None
return relation, h.unmatched_lookup_from_rows(relation)
return relation, r.unmatched_lookup_from_rows(relation)
Loading