diff --git a/sqlmesh/core/console.py b/sqlmesh/core/console.py index 8ba41ad3f6..0f1a609eed 100644 --- a/sqlmesh/core/console.py +++ b/sqlmesh/core/console.py @@ -8,7 +8,8 @@ import logging import textwrap from pathlib import Path - +import pandas as pd +import numpy as np from hyperscript import h from rich.console import Console as RichConsole from rich.live import Live @@ -1904,10 +1905,15 @@ def show_row_diff( # Create a table with the joined keys and comparison columns column_table = row_diff.joined_sample[keys + [source_column, target_column]] - # Filter out identical-valued rows + # Filter to retain non identical-valued rows column_table = column_table[ - column_table[source_column] != column_table[target_column] + column_table.apply( + lambda row: not _cells_match(row[source_column], row[target_column]), + axis=1, + ) ] + + # Rename the column headers for readability column_table = column_table.rename( columns={ source_column: source_name, @@ -1921,7 +1927,16 @@ def show_row_diff( table.add_column(column_name, style=style, header_style=style) for _, row in column_table.iterrows(): - table.add_row(*[str(cell) for cell in row]) + table.add_row( + *[ + str( + round(cell, row_diff.decimals) + if isinstance(cell, float) + else cell + ) + for cell in row + ] + ) self.console.print( f"Column: [underline][bold cyan]{column}[/bold cyan][/underline]", @@ -2027,6 +2042,16 @@ def show_linter_violations( self.log_warning(msg) +def _cells_match(x: t.Any, y: t.Any) -> bool: + """Helper function to compare two cells and returns true if they're equal, handling array objects.""" + + # Convert array-like objects to list for consistent comparison + def _normalize(val: t.Any) -> t.Any: + return list(val) if isinstance(val, (pd.Series, np.ndarray)) else val + + return _normalize(x) == _normalize(y) + + def add_to_layout_widget(target_widget: LayoutWidget, *widgets: widgets.Widget) -> LayoutWidget: """Helper function to add a widget to a layout widget. 
# ANSI CSI escape sequence: ESC [ , then parameter bytes (0x30-0x3F),
# intermediate bytes (0x20-0x2F) and one final byte (0x40-0x7E). This covers
# colour/style codes as well as private-mode sequences such as ESC[?25l.
_ANSI_CSI_RE = re.compile(r"\x1b\[[0-?]*[ -/]*[@-~]")


# The return annotation is quoted so it is not evaluated at definition time.
def create_test_console() -> "t.Tuple[StringIO, TerminalConsole]":
    """Create a TerminalConsole that writes into an in-memory buffer.

    Returns:
        A ``(buffer, terminal_console)`` pair. Everything the console prints
        can be read back from ``buffer``.
    """
    console_output = StringIO()
    # force_terminal=True asks Rich to treat the buffer as a terminal, so the
    # captured text may contain ANSI escape codes (see strip_ansi_codes).
    console = Console(file=console_output, force_terminal=True)
    terminal_console = TerminalConsole(console=console)
    return console_output, terminal_console


def capture_console_output(method_name: str, **kwargs: t.Any) -> str:
    """Invoke a TerminalConsole method and return the output it produced.

    Args:
        method_name: Name of the TerminalConsole method to call.
        **kwargs: Keyword arguments forwarded to that method.

    Returns:
        The captured output as a string.

    Raises:
        AttributeError: If the console has no attribute named ``method_name``.
    """
    console_output, terminal_console = create_test_console()
    # The buffer is closed on exit whether or not the console method raises.
    with console_output:
        getattr(terminal_console, method_name)(**kwargs)
        return console_output.getvalue()


def strip_ansi_codes(text: str) -> str:
    """Remove ANSI CSI escape sequences from ``text`` and trim surrounding whitespace."""
    return _ANSI_CSI_RE.sub("", text).strip()
"table_diff_source", + pd.DataFrame( + { + "key": [1, 2, 3], + "value": [np.array([51.2, 4.5678]), np.array([2.31, 12.2]), np.array([5.0])], + "dict": [{"key1": 10, "key2": 20, "key3": 30}, {"key1": 10}, {}], + } + ), + ) + + engine_adapter.ctas( + "table_diff_target", + pd.DataFrame( + { + "key": [1, 2, 3], + "value": [ + np.array([51.2, 4.5679]), + np.array([2.31, 12.2, 3.6, 1.9]), + np.array([5.0]), + ], + "dict": [{"key1": 10, "key2": 13}, {"key1": 10}, {}], + } + ), + ) + + table_diff = TableDiff( + adapter=engine_adapter, + source="table_diff_source", + target="table_diff_target", + source_alias="dev", + target_alias="prod", + on=["key"], + decimals=4, + ) + + diff = table_diff.row_diff() + aliased_joined_sample = diff.joined_sample.columns + + assert "DEV__value" in aliased_joined_sample + assert "PROD__value" in aliased_joined_sample + assert diff.full_match_count == 1 + assert diff.partial_match_count == 2 + + output = capture_console_output("show_row_diff", row_diff=diff) + + # Expected output with boxes + expected_output = r""" +Row Counts: +├── FULL MATCH: 1 rows (33.33%) +└── PARTIAL MATCH: 2 rows (66.67%) + +COMMON ROWS column comparison stats: + pct_match +value 33.333333 +dict 66.666667 + + +COMMON ROWS sample data differences: +Column: value +┏━━━━━┳━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┓ +┃ key ┃ DEV ┃ PROD ┃ +┡━━━━━╇━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━┩ +│ 1 │ [51.2, 4.5678] │ [51.2, 4.5679] │ +│ 2 │ [2.31, 12.2] │ [2.31, 12.2, 3.6, 1.9] │ +└─────┴────────────────┴────────────────────────┘ +Column: dict +┏━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━┓ +┃ key ┃ DEV ┃ PROD ┃ +┡━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━┩ +│ 1 │ {key1=10, key2=20, key3=30} │ {key1=10, key2=13} │ +└─────┴─────────────────────────────┴────────────────────┘ +""" + + stripped_output = strip_ansi_codes(output) + stripped_expected = expected_output.strip() + assert stripped_output == stripped_expected