diff --git a/src/safeds/data/tabular/containers/_table.py b/src/safeds/data/tabular/containers/_table.py index b0e4dac20..34bcbaa2b 100644 --- a/src/safeds/data/tabular/containers/_table.py +++ b/src/safeds/data/tabular/containers/_table.py @@ -975,22 +975,16 @@ def shuffle(self) -> Table: def correlation_heatmap(self) -> None: """ - Plot a correlation heatmap of an entire table. This function can only plot real numerical data. - - Raises - ------- - TypeError - If the table contains non-numerical data or complex data. + Plot a correlation heatmap for all numerical columns of this `Table`. """ - for column in self.to_columns(): - if not column.type.is_numeric(): - raise NonNumericColumnError(column.name) + only_numerical = Table.from_columns(self.list_columns_with_numerical_values()) + sns.heatmap( - data=self._data.corr(), + data=only_numerical._data.corr(), vmin=-1, vmax=1, - xticklabels=self.get_column_names(), - yticklabels=self.get_column_names(), + xticklabels=only_numerical.get_column_names(), + yticklabels=only_numerical.get_column_names(), cmap="vlag", ) plt.tight_layout() diff --git a/tests/safeds/data/tabular/containers/_table/test_correlation_heatmap.py b/tests/safeds/data/tabular/containers/_table/test_correlation_heatmap.py index 8f3b94746..747dd3714 100644 --- a/tests/safeds/data/tabular/containers/_table/test_correlation_heatmap.py +++ b/tests/safeds/data/tabular/containers/_table/test_correlation_heatmap.py @@ -1,15 +1,13 @@ import _pytest import matplotlib.pyplot as plt import pandas as pd -import pytest from safeds.data.tabular.containers import Table -from safeds.exceptions import NonNumericColumnError -def test_correlation_heatmap_non_numeric() -> None: - with pytest.raises(NonNumericColumnError): - table = Table(pd.DataFrame(data={"A": [1, 2, "A"], "B": [1, 2, "A"]})) - table.correlation_heatmap() +def test_correlation_heatmap_non_numeric(monkeypatch: _pytest.monkeypatch) -> None: + monkeypatch.setattr(plt, "show", lambda: None) + table = Table(pd.DataFrame(data={"A": [1, 2, "A"], "B": [1, 2, 3]})) + table.correlation_heatmap() def test_correlation_heatmap(monkeypatch: _pytest.monkeypatch) -> None: