diff --git a/sqlmesh/core/test/definition.py b/sqlmesh/core/test/definition.py index 1c9807cfa1..8694ec6024 100644 --- a/sqlmesh/core/test/definition.py +++ b/sqlmesh/core/test/definition.py @@ -355,11 +355,12 @@ def _to_hashable(x: t.Any) -> t.Any: for df in _split_df_by_column_pairs(diff) ) else: - from pandas import MultiIndex + from pandas import DataFrame, MultiIndex levels = t.cast(MultiIndex, diff.columns).levels[0] for col in levels: - col_diff = diff[col] + # diff[col] returns a DataFrame when columns is a MultiIndex + col_diff = t.cast(DataFrame, diff[col]) if not col_diff.empty: table = df_to_table( f"[bold red]Column '{col}' mismatch{failed_subtest}[/bold red]",