Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 29 additions & 4 deletions sqlmesh/core/console.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@
import logging
import textwrap
from pathlib import Path

import pandas as pd
import numpy as np
from hyperscript import h
from rich.console import Console as RichConsole
from rich.live import Live
Expand Down Expand Up @@ -1904,10 +1905,15 @@ def show_row_diff(
# Create a table with the joined keys and comparison columns
column_table = row_diff.joined_sample[keys + [source_column, target_column]]

# Filter out identical-valued rows
# Keep only the rows whose source and target values differ
column_table = column_table[
column_table[source_column] != column_table[target_column]
column_table.apply(
lambda row: not _cells_match(row[source_column], row[target_column]),
axis=1,
)
]

# Rename the column headers for readability
column_table = column_table.rename(
columns={
source_column: source_name,
Expand All @@ -1921,7 +1927,16 @@ def show_row_diff(
table.add_column(column_name, style=style, header_style=style)

for _, row in column_table.iterrows():
table.add_row(*[str(cell) for cell in row])
table.add_row(
*[
str(
round(cell, row_diff.decimals)
if isinstance(cell, float)
else cell
)
for cell in row
]
)

self.console.print(
f"Column: [underline][bold cyan]{column}[/bold cyan][/underline]",
Expand Down Expand Up @@ -2027,6 +2042,16 @@ def show_linter_violations(
self.log_warning(msg)


def _cells_match(x: t.Any, y: t.Any) -> bool:
"""Helper function to compare two cells and returns true if they're equal, handling array objects."""

# Convert array-like objects to list for consistent comparison
def _normalize(val: t.Any) -> t.Any:
return list(val) if isinstance(val, (pd.Series, np.ndarray)) else val

return _normalize(x) == _normalize(y)


def add_to_layout_widget(target_widget: LayoutWidget, *widgets: widgets.Widget) -> LayoutWidget:
"""Helper function to add a widget to a layout widget.

Expand Down
2 changes: 2 additions & 0 deletions sqlmesh/core/table_diff.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ class RowDiff(PydanticModel, frozen=True):
source_alias: t.Optional[str] = None
target_alias: t.Optional[str] = None
model_name: t.Optional[str] = None
decimals: int = 3

@property
def source_count(self) -> int:
Expand Down Expand Up @@ -576,5 +577,6 @@ def name(e: exp.Expression) -> str:
source_alias=self.source_alias,
target_alias=self.target_alias,
model_name=self.model_name,
decimals=self.decimals,
)
return self._row_diff
150 changes: 149 additions & 1 deletion tests/core/test_table_diff.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,49 @@
import pandas as pd
from sqlglot import exp
from sqlmesh.core import dialect as d
import re
import typing as t
from io import StringIO
from rich.console import Console
from sqlmesh.core.console import TerminalConsole
from sqlmesh.core.context import Context
from sqlmesh.core.config import AutoCategorizationMode, CategorizerConfig
from sqlmesh.core.model import SqlModel, load_sql_based_model
from sqlmesh.core.table_diff import TableDiff
import numpy as np


def create_test_console() -> t.Tuple[StringIO, TerminalConsole]:
    """Build a TerminalConsole that writes into an in-memory buffer.

    Returns:
        A ``(buffer, terminal_console)`` pair; everything the console prints
        ends up in ``buffer``.
    """
    buffer = StringIO()
    rich_console = Console(file=buffer, force_terminal=True)
    return buffer, TerminalConsole(console=rich_console)


def capture_console_output(method_name: str, **kwargs) -> str:
    """Invoke a TerminalConsole method and return everything it printed.

    Args:
        method_name: Name of the TerminalConsole method to call
        **kwargs: Arguments to pass to the method

    Returns:
        The captured output as a string
    """
    buffer, terminal = create_test_console()
    # The buffer is closed on exit whether or not the call succeeds.
    with buffer:
        getattr(terminal, method_name)(**kwargs)
        return buffer.getvalue()


def strip_ansi_codes(text: str) -> str:
    """Remove ANSI escape sequences from ``text`` and trim surrounding whitespace."""
    without_codes = re.sub(r"\x1b\[[0-9;]*[a-zA-Z]", "", text)
    return without_codes.strip()


@pytest.mark.slow
Expand Down Expand Up @@ -121,7 +160,7 @@ def test_data_diff_decimals(sushi_context_fixed_date):
pd.DataFrame(
{
"key": [1, 2, 3],
"value": [1.0, 2.0, 3.1234],
"value": [1.0, 2.0, 3.1234321],
}
),
)
Expand Down Expand Up @@ -162,6 +201,32 @@ def test_data_diff_decimals(sushi_context_fixed_date):
assert "DEV__value" in aliased_joined_sample
assert "PROD__value" in aliased_joined_sample

output = capture_console_output("show_row_diff", row_diff=table_diff.row_diff())

# Expected output with box-drawings
expected_output = r"""
Row Counts:
├── FULL MATCH: 2 rows (66.67%)
└── PARTIAL MATCH: 1 rows (33.33%)

COMMON ROWS column comparison stats:
pct_match
value 66.666667


COMMON ROWS sample data differences:
Column: value
┏━━━━━┳━━━━━━━━┳━━━━━━━━┓
┃ key ┃ DEV ┃ PROD ┃
┡━━━━━╇━━━━━━━━╇━━━━━━━━┩
│ 3.0 │ 3.1233 │ 3.1234 │
└─────┴────────┴────────┘
"""

stripped_output = strip_ansi_codes(output)
stripped_expected = expected_output.strip()
assert stripped_output == stripped_expected


@pytest.mark.slow
def test_grain_check(sushi_context_fixed_date):
Expand Down Expand Up @@ -363,3 +428,86 @@ def test_tables_and_grain_inferred_from_model(sushi_context_fixed_date: Context)

_, _, col_names = table_diff.key_columns
assert col_names == ["waiter_id", "event_date"]


# Checks that a table diff over array-valued and dict-valued columns reports the
# expected match counts and renders the differing rows in the console output.
@pytest.mark.slow
def test_data_diff_array_dict(sushi_context_fixed_date):
engine_adapter = sushi_context_fixed_date.engine_adapter

# Source table: one array column ("value") and one dict column ("dict").
engine_adapter.ctas(
"table_diff_source",
pd.DataFrame(
{
"key": [1, 2, 3],
"value": [np.array([51.2, 4.5678]), np.array([2.31, 12.2]), np.array([5.0])],
"dict": [{"key1": 10, "key2": 20, "key3": 30}, {"key1": 10}, {}],
}
),
)

# Target table: key 1 differs in both columns, key 2 differs only in the
# array column (extra elements), and key 3 is identical to the source.
engine_adapter.ctas(
"table_diff_target",
pd.DataFrame(
{
"key": [1, 2, 3],
"value": [
np.array([51.2, 4.5679]),
np.array([2.31, 12.2, 3.6, 1.9]),
np.array([5.0]),
],
"dict": [{"key1": 10, "key2": 13}, {"key1": 10}, {}],
}
),
)

# Diff the two tables joined on "key", using the "dev"/"prod" aliases.
table_diff = TableDiff(
adapter=engine_adapter,
source="table_diff_source",
target="table_diff_target",
source_alias="dev",
target_alias="prod",
on=["key"],
decimals=4,
)

diff = table_diff.row_diff()
aliased_joined_sample = diff.joined_sample.columns

# The joined sample columns carry the upper-cased alias as a prefix.
assert "DEV__value" in aliased_joined_sample
assert "PROD__value" in aliased_joined_sample
# Only key 3 matches in every column; keys 1 and 2 match partially.
assert diff.full_match_count == 1
assert diff.partial_match_count == 2

output = capture_console_output("show_row_diff", row_diff=diff)

# Expected console rendering, including the box-drawing table borders
expected_output = r"""
Row Counts:
├── FULL MATCH: 1 rows (33.33%)
└── PARTIAL MATCH: 2 rows (66.67%)

COMMON ROWS column comparison stats:
pct_match
value 33.333333
dict 66.666667


COMMON ROWS sample data differences:
Column: value
┏━━━━━┳━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┓
┃ key ┃ DEV ┃ PROD ┃
┡━━━━━╇━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━┩
│ 1 │ [51.2, 4.5678] │ [51.2, 4.5679] │
│ 2 │ [2.31, 12.2] │ [2.31, 12.2, 3.6, 1.9] │
└─────┴────────────────┴────────────────────────┘
Column: dict
┏━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━┓
┃ key ┃ DEV ┃ PROD ┃
┡━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━┩
│ 1 │ {key1=10, key2=20, key3=30} │ {key1=10, key2=13} │
└─────┴─────────────────────────────┴────────────────────┘
"""

# Compare after removing ANSI styling and surrounding whitespace on both sides.
stripped_output = strip_ansi_codes(output)
stripped_expected = expected_output.strip()
assert stripped_output == stripped_expected