From d2d144d35e6fccb201b25e4039009a3338230b8b Mon Sep 17 00:00:00 2001 From: Lars Reimann Date: Mon, 27 Mar 2023 13:06:26 +0200 Subject: [PATCH 1/2] feat: allow calling `table.correlation_heatmap` with non-numerical columns --- src/safeds/data/tabular/containers/_table.py | 26 +++++++------------ .../_table/test_correlation_heatmap.py | 11 ++++---- 2 files changed, 15 insertions(+), 22 deletions(-) diff --git a/src/safeds/data/tabular/containers/_table.py b/src/safeds/data/tabular/containers/_table.py index b0e4dac20..6795516c2 100644 --- a/src/safeds/data/tabular/containers/_table.py +++ b/src/safeds/data/tabular/containers/_table.py @@ -311,7 +311,7 @@ def get_column(self, column_name: str) -> Column: if self.schema.has_column(column_name): output_column = Column( self._data.iloc[ - :, [self.schema._get_column_index_by_name(column_name)] + :, [self.schema._get_column_index_by_name(column_name)] ].squeeze(), column_name, self.schema.get_type_of_column(column_name), @@ -732,9 +732,9 @@ def get_type_of_column(self, column_name: str) -> ColumnType: def sort_columns( self, query: Callable[[Column, Column], int] = lambda col1, col2: ( - col1.name > col2.name - ) - - (col1.name < col2.name), + col1.name > col2.name + ) + - (col1.name < col2.name), ) -> Table: """ Sort a table with the given lambda function. @@ -975,22 +975,16 @@ def shuffle(self) -> Table: def correlation_heatmap(self) -> None: """ - Plot a correlation heatmap of an entire table. This function can only plot real numerical data. - - Raises - ------- - TypeError - If the table contains non-numerical data or complex data. + Plot a correlation heatmap for all numerical columns of this `Table`. """ - for column in self.to_columns(): - if not column.type.is_numeric(): - raise NonNumericColumnError(column.name) + only_numerical = Table.from_columns(self.list_columns_with_numerical_values()) + sns.heatmap( - data=self._data.corr(), + data=only_numerical._data.corr(), vmin=-1, vmax=1, - xticklabels=self.get_column_names(), - yticklabels=self.get_column_names(), + xticklabels=only_numerical.get_column_names(), + yticklabels=only_numerical.get_column_names(), cmap="vlag", ) plt.tight_layout() diff --git a/tests/safeds/data/tabular/containers/_table/test_correlation_heatmap.py b/tests/safeds/data/tabular/containers/_table/test_correlation_heatmap.py index 8f3b94746..a7d0246d0 100644 --- a/tests/safeds/data/tabular/containers/_table/test_correlation_heatmap.py +++ b/tests/safeds/data/tabular/containers/_table/test_correlation_heatmap.py @@ -1,15 +1,14 @@ import _pytest import matplotlib.pyplot as plt import pandas as pd -import pytest + from safeds.data.tabular.containers import Table -from safeds.exceptions import NonNumericColumnError -def test_correlation_heatmap_non_numeric() -> None: - with pytest.raises(NonNumericColumnError): - table = Table(pd.DataFrame(data={"A": [1, 2, "A"], "B": [1, 2, "A"]})) - table.correlation_heatmap() +def test_correlation_heatmap_non_numeric(monkeypatch: _pytest.monkeypatch) -> None: + monkeypatch.setattr(plt, "show", lambda: None) + table = Table(pd.DataFrame(data={"A": [1, 2, "A"], "B": [1, 2, 3]})) + table.correlation_heatmap() def test_correlation_heatmap(monkeypatch: _pytest.monkeypatch) -> None: From 4f3d66ceff22778f70b410f19ca44e81a9128cc1 Mon Sep 17 00:00:00 2001 From: lars-reimann Date: Mon, 27 Mar 2023 11:10:44 +0000 Subject: [PATCH 2/2] style: apply automated linter fixes --- src/safeds/data/tabular/containers/_table.py | 8 ++++---- .../tabular/containers/_table/test_correlation_heatmap.py | 1 - 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/src/safeds/data/tabular/containers/_table.py b/src/safeds/data/tabular/containers/_table.py index 6795516c2..34bcbaa2b 100644 --- a/src/safeds/data/tabular/containers/_table.py +++ b/src/safeds/data/tabular/containers/_table.py @@ -311,7 +311,7 @@ def get_column(self, column_name: str) -> Column: if self.schema.has_column(column_name): output_column = Column( self._data.iloc[ - :, [self.schema._get_column_index_by_name(column_name)] + :, [self.schema._get_column_index_by_name(column_name)] ].squeeze(), column_name, self.schema.get_type_of_column(column_name), @@ -732,9 +732,9 @@ def get_type_of_column(self, column_name: str) -> ColumnType: def sort_columns( self, query: Callable[[Column, Column], int] = lambda col1, col2: ( - col1.name > col2.name - ) - - (col1.name < col2.name), + col1.name > col2.name + ) + - (col1.name < col2.name), ) -> Table: """ Sort a table with the given lambda function. diff --git a/tests/safeds/data/tabular/containers/_table/test_correlation_heatmap.py b/tests/safeds/data/tabular/containers/_table/test_correlation_heatmap.py index a7d0246d0..747dd3714 100644 --- a/tests/safeds/data/tabular/containers/_table/test_correlation_heatmap.py +++ b/tests/safeds/data/tabular/containers/_table/test_correlation_heatmap.py @@ -1,7 +1,6 @@ import _pytest import matplotlib.pyplot as plt import pandas as pd - from safeds.data.tabular.containers import Table