def _create_image_for_plot(fig: plt.Figure) -> Image:
    """
    Render a matplotlib figure as a PNG and wrap the result in an `Image`.

    Parameters
    ----------
    fig:
        The figure to render.

    Returns
    -------
    image:
        The image created from the PNG bytes of the figure.
    """
    buffer = io.BytesIO()
    try:
        fig.savefig(buffer, format="png")
    finally:
        # Close exactly the figure that was passed in. A bare `plt.close()` closes whatever pyplot considers the
        # *current* figure, which is not necessarily `fig`. Closing also prevents the figure from being displayed
        # directly, and doing it in `finally` avoids leaking the figure if saving fails.
        plt.close(fig)
    return Image.from_bytes(buffer.getvalue())
fig.savefig(buffer, format="png") - plt.close() # Prevents the figure from being displayed directly - buffer.seek(0) - return Image.from_bytes(buffer.read()) + return _create_image_for_plot(fig) def plot_lineplot(self, x_column_name: str, y_column_name: str) -> Image: """ @@ -1994,12 +1990,9 @@ def plot_lineplot(self, x_column_name: str, y_column_name: str) -> Image: horizontalalignment="right", ) # rotate the labels of the x Axis to prevent the chance of overlapping of the labels plt.tight_layout() + from safeds._utils._plotting import _create_image_for_plot - buffer = io.BytesIO() - fig.savefig(buffer, format="png") - plt.close() # Prevents the figure from being displayed directly - buffer.seek(0) - return Image.from_bytes(buffer.read()) + return _create_image_for_plot(fig) def plot_scatterplot(self, x_column_name: str, y_column_name: str) -> Image: """ @@ -2052,12 +2045,9 @@ def plot_scatterplot(self, x_column_name: str, y_column_name: str) -> Image: horizontalalignment="right", ) # rotate the labels of the x Axis to prevent the chance of overlapping of the labels plt.tight_layout() + from safeds._utils._plotting import _create_image_for_plot - buffer = io.BytesIO() - fig.savefig(buffer, format="png") - plt.close() # Prevents the figure from being displayed directly - buffer.seek(0) - return Image.from_bytes(buffer.read()) + return _create_image_for_plot(fig) def plot_boxplots(self) -> Image: """ @@ -2099,12 +2089,9 @@ def plot_boxplots(self) -> Image: axes.set_xticks([]) plt.tight_layout() fig = grid.fig + from safeds._utils._plotting import _create_image_for_plot - buffer = io.BytesIO() - fig.savefig(buffer, format="png") - plt.close() # Prevents the figure from being displayed directly - buffer.seek(0) - return Image.from_bytes(buffer.read()) + return _create_image_for_plot(fig) def plot_histograms(self) -> Image: """ @@ -2134,12 +2121,9 @@ def plot_histograms(self) -> Image: axes.set_xticklabels(axes.get_xticklabels(), rotation=45, 
horizontalalignment="right") grid.tight_layout() fig = grid.fig + from safeds._utils._plotting import _create_image_for_plot - buffer = io.BytesIO() - fig.savefig(buffer, format="png") - plt.close() - buffer.seek(0) - return Image.from_bytes(buffer.read()) + return _create_image_for_plot(fig) # ------------------------------------------------------------------------------------------------------------------ # Conversion diff --git a/src/safeds/data/tabular/containers/_time_series.py b/src/safeds/data/tabular/containers/_time_series.py index fee39fd94..c1a1f8fcc 100644 --- a/src/safeds/data/tabular/containers/_time_series.py +++ b/src/safeds/data/tabular/containers/_time_series.py @@ -1,13 +1,10 @@ from __future__ import annotations -import io import sys from typing import TYPE_CHECKING -import matplotlib.pyplot as plt import pandas as pd -from safeds.data.image.containers import Image from safeds.data.tabular.containers import Column, Row, Table, TaggedTable from safeds.exceptions import ( ColumnIsTargetError, @@ -21,6 +18,8 @@ from collections.abc import Callable, Mapping, Sequence from typing import Any + from safeds.data.image.containers import Image + class TimeSeries(TaggedTable): @@ -871,7 +870,7 @@ def transform_column(self, name: str, transformer: Callable[[Row], Any]) -> Time time_name=self.time.name, ) - def plot_lagplot(self, lag: int) -> Image: + def plot_lag_plot(self, lag: int) -> Image: """ Plot a lagplot for the target column. 
def plot_moving_average(
    self,
    window_size: int,
    column_name: str | None = None,
) -> Image:
    """
    Plot a column of the time series together with its moving average.

    Parameters
    ----------
    window_size:
        The size of the rolling window that each average is calculated for.

    column_name:
        The name of the column to plot. If None, the target column is used.

    Returns
    -------
    plot:
        The moving average and the original values of the column, drawn into one plot, as an image.

    Raises
    ------
    UnknownColumnNameError
        If the time series does not contain a column with the given name.

    NonNumericColumnError
        If the selected column contains non-numerical values.

    Examples
    --------
    >>> from safeds.data.tabular.containers import TimeSeries
    >>> table = TimeSeries({"time":[1, 2], "target": [3, 4], "feature":[2,2]}, target_name= "target", time_name="time", feature_names=["feature"], )
    >>> image = table.plot_moving_average(window_size = 2)

    """
    # Fall back to the target column if the caller did not pick one
    if column_name is None:
        column_name = self.target.name

    # Resolve the raw values of the selected column
    if column_name == self.target.name:
        values = self.target._data
    elif column_name in self.column_names:
        values = self._data[column_name]
    else:
        raise UnknownColumnNameError([column_name])

    selected_is_numeric = self.get_column(column_name).type.is_numeric()
    if not selected_is_numeric:
        raise NonNumericColumnError("This time series plotted column contains non-numerical columns.")

    # The moving average is plotted first and the raw values are drawn onto the same axes afterwards; the legend
    # labels below are listed in that same order.
    smoothed = values.rolling(window_size).mean()
    axes = values.plot(ax=smoothed.plot())
    axes.legend(labels=["moving_average", column_name])

    from safeds._utils._plotting import _create_image_for_plot

    return _create_image_for_plot(axes.figure)
b/tests/safeds/data/tabular/containers/_table/_tagged_table/_time_series/__snapshots__/test_plot_moving_average/test_optional_parameter.png new file mode 100644 index 000000000..2ff964a56 Binary files /dev/null and b/tests/safeds/data/tabular/containers/_table/_tagged_table/_time_series/__snapshots__/test_plot_moving_average/test_optional_parameter.png differ diff --git a/tests/safeds/data/tabular/containers/_table/_tagged_table/_time_series/__snapshots__/test_plot_moving_average/test_should_return_table.png b/tests/safeds/data/tabular/containers/_table/_tagged_table/_time_series/__snapshots__/test_plot_moving_average/test_should_return_table.png new file mode 100644 index 000000000..c88e4b05a Binary files /dev/null and b/tests/safeds/data/tabular/containers/_table/_tagged_table/_time_series/__snapshots__/test_plot_moving_average/test_should_return_table.png differ diff --git a/tests/safeds/data/tabular/containers/_table/_tagged_table/_time_series/test_plot_lag.py b/tests/safeds/data/tabular/containers/_table/_tagged_table/_time_series/test_plot_lag_plot.py similarity index 94% rename from tests/safeds/data/tabular/containers/_table/_tagged_table/_time_series/test_plot_lag.py rename to tests/safeds/data/tabular/containers/_table/_tagged_table/_time_series/test_plot_lag_plot.py index cb6c94809..9a8976687 100644 --- a/tests/safeds/data/tabular/containers/_table/_tagged_table/_time_series/test_plot_lag.py +++ b/tests/safeds/data/tabular/containers/_table/_tagged_table/_time_series/test_plot_lag_plot.py @@ -15,7 +15,7 @@ def test_should_return_table(snapshot_png: SnapshotAssertion) -> None: time_name="time", feature_names=None, ) - lag_plot = table.plot_lagplot(lag=1) + lag_plot = table.plot_lag_plot(lag=1) assert lag_plot == snapshot_png @@ -38,4 +38,4 @@ def test_should_raise_if_column_contains_non_numerical_values() -> None: r" non-numerical columns." 
import pytest
from safeds.data.tabular.containers import TimeSeries
from safeds.exceptions import NonNumericColumnError, UnknownColumnNameError
from syrupy import SnapshotAssertion

# Regex matching the message of the NonNumericColumnError raised for a non-numerical plotted column.
_NON_NUMERIC_MESSAGE = (
    r"Tried to do a numerical operation on one or multiple non-numerical columns: \nThis time series plotted"
    r" column contains non-numerical columns."
)


def _build_time_series(feature_1: "list[int] | list[str]", target: "list[int] | list[str]") -> TimeSeries:
    """Create a ten-row time series with the columns "time", "feature_1" and "target"."""
    return TimeSeries(
        {
            "time": list(range(1, 11)),
            "feature_1": feature_1,
            "target": target,
        },
        target_name="target",
        time_name="time",
        feature_names=None,
    )


def _numbers_as_strings() -> "list[str]":
    """Return the strings "1" to "10", used as non-numerical column values."""
    return [str(number) for number in range(1, 11)]


def test_should_return_table(snapshot_png: SnapshotAssertion) -> None:
    """The moving average of the target column is plotted when no column name is given."""
    table = _build_time_series(
        feature_1=list(range(1, 11)),
        target=[1, 2, 3, 4, 3, 2, 1, 2, 3, 4],
    )
    moving_average_plot = table.plot_moving_average(window_size=2)
    assert moving_average_plot == snapshot_png


def test_optional_parameter(snapshot_png: SnapshotAssertion) -> None:
    """The moving average of an explicitly named feature column is plotted."""
    table = _build_time_series(
        feature_1=[1, 2, 3, 4, 5, 4, 3, 2, 1, 0],
        target=[1, 2, 3, 4, 3, 2, 1, 2, 3, 4],
    )
    moving_average_plot = table.plot_moving_average(window_size=2, column_name="feature_1")
    assert moving_average_plot == snapshot_png


def test_should_raise_if_column_contains_non_numerical_values() -> None:
    """A non-numerical target column is rejected."""
    table = _build_time_series(
        feature_1=list(range(1, 11)),
        target=_numbers_as_strings(),
    )
    with pytest.raises(NonNumericColumnError, match=_NON_NUMERIC_MESSAGE):
        table.plot_moving_average(2)


@pytest.mark.parametrize(
    ("time_series", "name", "error", "error_msg"),
    [
        (
            _build_time_series(feature_1=_numbers_as_strings(), target=_numbers_as_strings()),
            "feature_1",
            NonNumericColumnError,
            _NON_NUMERIC_MESSAGE,
        ),
        (
            _build_time_series(feature_1=_numbers_as_strings(), target=_numbers_as_strings()),
            "feature_3",
            UnknownColumnNameError,
            r"Could not find column\(s\) 'feature_3'.",
        ),
    ],
    ids=["feature_not_numerical", "feature_does_not_exist"],
)
def test_should_raise_error_optional_parameter(
    time_series: TimeSeries,
    name: str,
    error: type[Exception],
    error_msg: str,
) -> None:
    """An explicitly named column that is non-numerical or unknown is rejected with the matching error."""
    with pytest.raises(error, match=error_msg):
        time_series.plot_moving_average(2, column_name=name)