From d5a55882cdf231c3217ce700912afdd50c1c9d33 Mon Sep 17 00:00:00 2001 From: Themis Valtinos <73662635+themisvaltinos@users.noreply.github.com> Date: Thu, 3 Apr 2025 17:29:07 +0300 Subject: [PATCH 1/4] Fix: Account for array types when showing sample in table diff --- sqlmesh/core/console.py | 45 ++++++++++- sqlmesh/core/table_diff.py | 2 + tests/core/test_table_diff.py | 141 +++++++++++++++++++++++++++++++++- 3 files changed, 183 insertions(+), 5 deletions(-) diff --git a/sqlmesh/core/console.py b/sqlmesh/core/console.py index 8ba41ad3f6..5e341cf186 100644 --- a/sqlmesh/core/console.py +++ b/sqlmesh/core/console.py @@ -8,7 +8,8 @@ import logging import textwrap from pathlib import Path - +import pandas as pd +import numpy as np from hyperscript import h from rich.console import Console as RichConsole from rich.live import Live @@ -1904,10 +1905,37 @@ def show_row_diff( # Create a table with the joined keys and comparison columns column_table = row_diff.joined_sample[keys + [source_column, target_column]] - # Filter out identical-valued rows + def compare_cells(x: t.Any, y: t.Any) -> bool: + """Compare two cells and returns true if they're not equal, handling array objects.""" + if x is None or y is None: + return x != y + + # Convert any array-like object to list for consistent comparison + def to_list(val: t.Any) -> t.Any: + return ( + list(val) + if isinstance(val, (pd.Series, np.ndarray, list, tuple, set)) + else val + ) + + x = to_list(x) + y = to_list(y) + if isinstance(x, list) and isinstance(y, list): + if len(x) != len(y): + return True + return any(a != b for a, b in zip(x, y)) + + return x != y + + # Filter to retain non identical-valued rows column_table = column_table[ - column_table[source_column] != column_table[target_column] + column_table.apply( + lambda row: compare_cells(row[source_column], row[target_column]), + axis=1, + ) ] + + # Rename the column headers for readability column_table = column_table.rename( columns={ source_column: source_name, @@ -1921,7 +1949,16 @@ def show_row_diff( table.add_column(column_name, style=style, header_style=style) for _, row in column_table.iterrows(): - table.add_row(*[str(cell) for cell in row]) + table.add_row( + *[ + str( + round(cell, row_diff.decimals) + if isinstance(cell, float) + else cell + ) + for cell in row + ] + ) self.console.print( f"Column: [underline][bold cyan]{column}[/bold cyan][/underline]", diff --git a/sqlmesh/core/table_diff.py b/sqlmesh/core/table_diff.py index d3e26d173d..391e77dbd2 100644 --- a/sqlmesh/core/table_diff.py +++ b/sqlmesh/core/table_diff.py @@ -70,6 +70,7 @@ class RowDiff(PydanticModel, frozen=True): source_alias: t.Optional[str] = None target_alias: t.Optional[str] = None model_name: t.Optional[str] = None + decimals: int = 3 @property def source_count(self) -> int: @@ -576,5 +577,6 @@ def name(e: exp.Expression) -> str: source_alias=self.source_alias, target_alias=self.target_alias, model_name=self.model_name, + decimals=self.decimals, ) return self._row_diff diff --git a/tests/core/test_table_diff.py b/tests/core/test_table_diff.py index 38452529ee..a7a3b09b27 100644 --- a/tests/core/test_table_diff.py +++ b/tests/core/test_table_diff.py @@ -3,10 +3,49 @@ import pandas as pd from sqlglot import exp from sqlmesh.core import dialect as d +import re +import typing as t +from io import StringIO +from rich.console import Console +from sqlmesh.core.console import TerminalConsole from sqlmesh.core.context import Context from sqlmesh.core.config import AutoCategorizationMode, CategorizerConfig from sqlmesh.core.model import SqlModel, load_sql_based_model from sqlmesh.core.table_diff import TableDiff +import numpy as np + + +def create_test_console() -> t.Tuple[StringIO, TerminalConsole]: + """Creates a console and buffer for validating console output.""" + console_output = StringIO() + console = Console(file=console_output, force_terminal=True) + terminal_console = TerminalConsole(console=console) + return console_output, terminal_console + + +def capture_console_output(method_name: str, **kwargs) -> str: + """Factory function to invoke and capture output a TerminalConsole method. + + Args: + method_name: Name of the TerminalConsole method to call + **kwargs: Arguments to pass to the method + + Returns: + The captured output as a string + """ + console_output, terminal_console = create_test_console() + try: + method = getattr(terminal_console, method_name) + method(**kwargs) + return console_output.getvalue() + finally: + console_output.close() + + +def strip_ansi_codes(text: str) -> str: + """Strip ANSI color codes and styling from text.""" + ansi_escape = re.compile(r"\x1b\[[0-9;]*[a-zA-Z]") + return ansi_escape.sub("", text).strip() @pytest.mark.slow @@ -121,7 +160,7 @@ def test_data_diff_decimals(sushi_context_fixed_date): pd.DataFrame( { "key": [1, 2, 3], - "value": [1.0, 2.0, 3.1234], + "value": [1.0, 2.0, 3.1234321], } ), ) @@ -162,6 +201,32 @@ def test_data_diff_decimals(sushi_context_fixed_date): assert "DEV__value" in aliased_joined_sample assert "PROD__value" in aliased_joined_sample + output = capture_console_output("show_row_diff", row_diff=table_diff.row_diff()) + + # Expected output with box-drawings + expected_output = r""" +Row Counts: +├── FULL MATCH: 2 rows (66.67%) +└── PARTIAL MATCH: 1 rows (33.33%) + +COMMON ROWS column comparison stats: + pct_match +value 66.666667 + + +COMMON ROWS sample data differences: +Column: value +┏━━━━━┳━━━━━━━━┳━━━━━━━━┓ +┃ key ┃ DEV ┃ PROD ┃ +┡━━━━━╇━━━━━━━━╇━━━━━━━━┩ +│ 3.0 │ 3.1233 │ 3.1234 │ +└─────┴────────┴────────┘ +""" + + stripped_output = strip_ansi_codes(output) + stripped_expected = expected_output.strip() + assert stripped_output == stripped_expected + @pytest.mark.slow def test_grain_check(sushi_context_fixed_date): @@ -363,3 +428,77 @@ def test_tables_and_grain_inferred_from_model(sushi_context_fixed_date: Context) _, _, col_names = table_diff.key_columns assert col_names == ["waiter_id", "event_date"] + + +@pytest.mark.slow +def test_data_diff_array(sushi_context_fixed_date): + engine_adapter = sushi_context_fixed_date.engine_adapter + + engine_adapter.ctas( + "table_diff_source", + pd.DataFrame( + { + "key": [1, 2, 3], + "value": [np.array([51.2, 4.5678]), np.array([2.31, 12.2]), np.array([5.0])], + } + ), + ) + + engine_adapter.ctas( + "table_diff_target", + pd.DataFrame( + { + "key": [1, 2, 3], + "value": [ + np.array([51.2, 4.5679]), + np.array([2.31, 12.2, 3.6, 1.9]), + np.array([5.0]), + ], + } + ), + ) + + table_diff = TableDiff( + adapter=engine_adapter, + source="table_diff_source", + target="table_diff_target", + source_alias="dev", + target_alias="prod", + on=["key"], + decimals=4, + ) + + diff = table_diff.row_diff() + aliased_joined_sample = diff.joined_sample.columns + + assert "DEV__value" in aliased_joined_sample + assert "PROD__value" in aliased_joined_sample + assert diff.full_match_count == 1 + assert diff.partial_match_count == 2 + + output = capture_console_output("show_row_diff", row_diff=diff) + + # Expected output with boxes + expected_output = r""" +Row Counts: +├── FULL MATCH: 1 rows (33.33%) +└── PARTIAL MATCH: 2 rows (66.67%) + +COMMON ROWS column comparison stats: + pct_match +value 33.333333 + + +COMMON ROWS sample data differences: +Column: value +┏━━━━━┳━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┓ +┃ key ┃ DEV ┃ PROD ┃ +┡━━━━━╇━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━┩ +│ 1 │ [51.2, 4.5678] │ [51.2, 4.5679] │ +│ 2 │ [2.31, 12.2] │ [2.31, 12.2, 3.6, 1.9] │ +└─────┴────────────────┴────────────────────────┘ +""" + + stripped_output = strip_ansi_codes(output) + stripped_expected = expected_output.strip() + assert stripped_output == stripped_expected From fc7e68f850a5ba43d469bdc14f6b07b9d4a2a26d Mon Sep 17 00:00:00 2001 From: Themis Valtinos <73662635+themisvaltinos@users.noreply.github.com> Date: Thu, 3 Apr 2025 19:47:43 +0300 Subject: [PATCH 2/4] Refactor; extend test for dicts --- sqlmesh/core/console.py | 43 ++++++++++++++++------------------- tests/core/test_table_diff.py | 11 ++++++++- 2 files changed, 30 insertions(+), 24 deletions(-) diff --git a/sqlmesh/core/console.py b/sqlmesh/core/console.py index 5e341cf186..0a9a411805 100644 --- a/sqlmesh/core/console.py +++ b/sqlmesh/core/console.py @@ -1905,32 +1905,10 @@ def show_row_diff( # Create a table with the joined keys and comparison columns column_table = row_diff.joined_sample[keys + [source_column, target_column]] - def compare_cells(x: t.Any, y: t.Any) -> bool: - """Compare two cells and returns true if they're not equal, handling array objects.""" - if x is None or y is None: - return x != y - - # Convert any array-like object to list for consistent comparison - def to_list(val: t.Any) -> t.Any: - return ( - list(val) - if isinstance(val, (pd.Series, np.ndarray, list, tuple, set)) - else val - ) - - x = to_list(x) - y = to_list(y) - if isinstance(x, list) and isinstance(y, list): - if len(x) != len(y): - return True - return any(a != b for a, b in zip(x, y)) - - return x != y - # Filter to retain non identical-valued rows column_table = column_table[ column_table.apply( - lambda row: compare_cells(row[source_column], row[target_column]), + lambda row: _compare_df_cells(row[source_column], row[target_column]), axis=1, ) ] @@ -2064,6 +2042,25 @@ def show_linter_violations( self.log_warning(msg) +def _compare_df_cells(x: t.Any, y: t.Any) -> bool: + """Helper function to compare two cells and returns true if they're not equal, handling array objects.""" + if x is None or y is None: + return x != y + + # Convert any array-like object to list for consistent comparison + def to_list(val: t.Any) -> t.Any: + return list(val) if isinstance(val, (pd.Series, np.ndarray, list, tuple, set)) else val + + x = to_list(x) + y = to_list(y) + if isinstance(x, list) and isinstance(y, list): + if len(x) != len(y): + return True + return any(a != b for a, b in zip(x, y)) + + return x != y + + def add_to_layout_widget(target_widget: LayoutWidget, *widgets: widgets.Widget) -> LayoutWidget: """Helper function to add a widget to a layout widget. diff --git a/tests/core/test_table_diff.py b/tests/core/test_table_diff.py index a7a3b09b27..b7844d7d18 100644 --- a/tests/core/test_table_diff.py +++ b/tests/core/test_table_diff.py @@ -431,7 +431,7 @@ def test_tables_and_grain_inferred_from_model(sushi_context_fixed_date: Context) @pytest.mark.slow -def test_data_diff_array(sushi_context_fixed_date): +def test_data_diff_array_dict(sushi_context_fixed_date): engine_adapter = sushi_context_fixed_date.engine_adapter engine_adapter.ctas( @@ -440,6 +440,7 @@ def test_data_diff_array(sushi_context_fixed_date): { "key": [1, 2, 3], "value": [np.array([51.2, 4.5678]), np.array([2.31, 12.2]), np.array([5.0])], + "dict": [{"key1": 10, "key2": 20, "key3": 30}, {"key1": 10}, {}], } ), ) @@ -454,6 +455,7 @@ def test_data_diff_array(sushi_context_fixed_date): np.array([2.31, 12.2, 3.6, 1.9]), np.array([5.0]), ], + "dict": [{"key1": 10, "key2": 13}, {"key1": 10}, {}], } ), ) @@ -487,6 +489,7 @@ def test_data_diff_array(sushi_context_fixed_date): COMMON ROWS column comparison stats: pct_match value 33.333333 +dict 66.666667 COMMON ROWS sample data differences: @@ -497,6 +500,12 @@ def test_data_diff_array(sushi_context_fixed_date): │ 1 │ [51.2, 4.5678] │ [51.2, 4.5679] │ │ 2 │ [2.31, 12.2] │ [2.31, 12.2, 3.6, 1.9] │ └─────┴────────────────┴────────────────────────┘ +Column: dict +┏━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━┓ +┃ key ┃ DEV ┃ PROD ┃ +┡━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━┩ +│ 1 │ {key1=10, key2=20, key3=30} │ {key1=10, key2=13} │ +└─────┴─────────────────────────────┴────────────────────┘ """ stripped_output = strip_ansi_codes(output) From 734ef9b4a100bc9d6560570763dc4a14a1cd2693 Mon Sep 17 00:00:00 2001 From: Themis Valtinos <73662635+themisvaltinos@users.noreply.github.com> Date: Thu, 3 Apr 2025 21:08:15 +0300 Subject: [PATCH 3/4] Refactor to _cells_match and simplify logic --- sqlmesh/core/console.py | 23 ++++++++--------------- 1 file changed, 8 insertions(+), 15 deletions(-) diff --git a/sqlmesh/core/console.py b/sqlmesh/core/console.py index 0a9a411805..bcb500b837 100644 --- a/sqlmesh/core/console.py +++ b/sqlmesh/core/console.py @@ -1908,7 +1908,7 @@ def show_row_diff( # Filter to retain non identical-valued rows column_table = column_table[ column_table.apply( - lambda row: _compare_df_cells(row[source_column], row[target_column]), + lambda row: not _cells_match(row[source_column], row[target_column]), axis=1, ) ] @@ -2042,23 +2042,16 @@ def show_linter_violations( self.log_warning(msg) -def _compare_df_cells(x: t.Any, y: t.Any) -> bool: - """Helper function to compare two cells and returns true if they're not equal, handling array objects.""" +def _cells_match(x: t.Any, y: t.Any) -> bool: + """Helper function to compare two cells and returns true if they're equal, handling array objects.""" if x is None or y is None: - return x != y + return x == y - # Convert any array-like object to list for consistent comparison - def to_list(val: t.Any) -> t.Any: - return list(val) if isinstance(val, (pd.Series, np.ndarray, list, tuple, set)) else val + # Convert array-like objects to list for consistent comparison + def _normalize(val: t.Any) -> t.Any: + return list(val) if isinstance(val, (pd.Series, np.ndarray)) else val - x = to_list(x) - y = to_list(y) - if isinstance(x, list) and isinstance(y, list): - if len(x) != len(y): - return True - return any(a != b for a, b in zip(x, y)) - - return x != y + return _normalize(x) == _normalize(y) def add_to_layout_widget(target_widget: LayoutWidget, *widgets: widgets.Widget) -> LayoutWidget: From 34aef42b1a05ac079f1bd9858cc2d4e3db0f5f3c Mon Sep 17 00:00:00 2001 From: Themis Valtinos <73662635+themisvaltinos@users.noreply.github.com> Date: Thu, 3 Apr 2025 21:18:15 +0300 Subject: [PATCH 4/4] Refactor unnecessary code --- sqlmesh/core/console.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/sqlmesh/core/console.py b/sqlmesh/core/console.py index bcb500b837..0f1a609eed 100644 --- a/sqlmesh/core/console.py +++ b/sqlmesh/core/console.py @@ -2044,8 +2044,6 @@ def show_linter_violations( def _cells_match(x: t.Any, y: t.Any) -> bool: """Helper function to compare two cells and returns true if they're equal, handling array objects.""" - if x is None or y is None: - return x == y # Convert array-like objects to list for consistent comparison def _normalize(val: t.Any) -> t.Any: