diff --git a/src/safeds/ml/classical/classification/_decision_tree_classifier.py b/src/safeds/ml/classical/classification/_decision_tree_classifier.py index 9821573d6..caa20b030 100644 --- a/src/safeds/ml/classical/classification/_decision_tree_classifier.py +++ b/src/safeds/ml/classical/classification/_decision_tree_classifier.py @@ -3,6 +3,8 @@ from typing import TYPE_CHECKING from safeds._utils import _structural_hash +from safeds.data.image.containers import Image +from safeds.exceptions._ml import ModelNotFittedError from safeds.ml.classical._bases import _DecisionTreeBase from ._classifier import Classifier @@ -71,3 +73,41 @@ def _get_sklearn_model(self) -> ClassifierMixin: max_depth=self._max_depth, min_samples_leaf=self._min_sample_count_in_leaves, ) + + # ------------------------------------------------------------------------------------------------------------------ + # Plot + # ------------------------------------------------------------------------------------------------------------------ + + def plot(self) -> Image: + """ + Get the image of the decision tree. + + Returns + ------- + plot: + The decision tree figure as an image. + + Raises + ------ + ModelNotFittedError: + If model is not fitted. + """ + if not self.is_fitted: + raise ModelNotFittedError + + from io import BytesIO + + import matplotlib.pyplot as plt + from sklearn.tree import plot_tree + + plot_tree(self._wrapped_model) + + # save plot fig bytes in buffer + with BytesIO() as buffer: + plt.savefig(buffer) + image = buffer.getvalue() + + # prevent forced plot from sklearn showing + plt.close() + + return Image.from_bytes(image) diff --git a/src/safeds/ml/classical/regression/_decision_tree_regressor.py b/src/safeds/ml/classical/regression/_decision_tree_regressor.py index 24ed8565c..19ea07400 100644 --- a/src/safeds/ml/classical/regression/_decision_tree_regressor.py +++ b/src/safeds/ml/classical/regression/_decision_tree_regressor.py @@ -3,6 +3,8 @@ from typing import TYPE_CHECKING from safeds._utils import _structural_hash +from safeds.data.image.containers import Image +from safeds.exceptions._ml import ModelNotFittedError from safeds.ml.classical._bases import _DecisionTreeBase from ._regressor import Regressor @@ -71,3 +73,41 @@ def _get_sklearn_model(self) -> RegressorMixin: max_depth=self._max_depth, min_samples_leaf=self._min_sample_count_in_leaves, ) + + # ------------------------------------------------------------------------------------------------------------------ + # Plot + # ------------------------------------------------------------------------------------------------------------------ + + def plot(self) -> Image: + """ + Get the image of the decision tree. + + Returns + ------- + plot: + The decision tree figure as an image. + + Raises + ------ + ModelNotFittedError: + If model is not fitted. + """ + if not self.is_fitted: + raise ModelNotFittedError + + from io import BytesIO + + import matplotlib.pyplot as plt + from sklearn.tree import plot_tree + + plot_tree(self._wrapped_model) + + # save plot fig bytes in buffer + with BytesIO() as buffer: + plt.savefig(buffer) + image = buffer.getvalue() + + # prevent forced plot from sklearn showing + plt.close() + + return Image.from_bytes(image) diff --git a/tests/safeds/ml/classical/classification/__snapshots__/test_decision_tree/TestPlot.test_should_check_that_plot_image_is_same_as_snapshot.png b/tests/safeds/ml/classical/classification/__snapshots__/test_decision_tree/TestPlot.test_should_check_that_plot_image_is_same_as_snapshot.png new file mode 100644 index 000000000..70c1274b1 Binary files /dev/null and b/tests/safeds/ml/classical/classification/__snapshots__/test_decision_tree/TestPlot.test_should_check_that_plot_image_is_same_as_snapshot.png differ diff --git a/tests/safeds/ml/classical/classification/test_decision_tree.py b/tests/safeds/ml/classical/classification/test_decision_tree.py index f1c35c6be..6642ac0a6 100644 --- a/tests/safeds/ml/classical/classification/test_decision_tree.py +++ b/tests/safeds/ml/classical/classification/test_decision_tree.py @@ -1,8 +1,11 @@ import pytest from safeds.data.labeled.containers import TabularDataset from safeds.data.tabular.containers import Table -from safeds.exceptions import OutOfBoundsError +from safeds.exceptions import ModelNotFittedError, OutOfBoundsError from safeds.ml.classical.classification import DecisionTreeClassifier +from syrupy import SnapshotAssertion + +from tests.helpers import os_mac, skip_if_os @pytest.fixture() @@ -41,3 +44,21 @@ def test_should_be_passed_to_sklearn(self, training_set: TabularDataset) -> None def test_should_raise_if_less_than_or_equal_to_0(self, min_sample_count_in_leaves: int) -> None: with pytest.raises(OutOfBoundsError): DecisionTreeClassifier(min_sample_count_in_leaves=min_sample_count_in_leaves) + + +class TestPlot: + def test_should_raise_if_model_is_not_fitted(self) -> None: + model = DecisionTreeClassifier() + with pytest.raises(ModelNotFittedError): + model.plot() + + def test_should_check_that_plot_image_is_same_as_snapshot( + self, + training_set: TabularDataset, + snapshot_png_image: SnapshotAssertion, + ) -> None: + skip_if_os([os_mac]) + + fitted_model = DecisionTreeClassifier().fit(training_set) + image = fitted_model.plot() + assert image == snapshot_png_image diff --git a/tests/safeds/ml/classical/regression/__snapshots__/test_decision_tree/TestPlot.test_should_check_that_plot_image_is_same_as_snapshot.png b/tests/safeds/ml/classical/regression/__snapshots__/test_decision_tree/TestPlot.test_should_check_that_plot_image_is_same_as_snapshot.png new file mode 100644 index 000000000..2479862c3 Binary files /dev/null and b/tests/safeds/ml/classical/regression/__snapshots__/test_decision_tree/TestPlot.test_should_check_that_plot_image_is_same_as_snapshot.png differ diff --git a/tests/safeds/ml/classical/regression/test_decision_tree.py b/tests/safeds/ml/classical/regression/test_decision_tree.py index 0cf2beb20..6a39e1968 100644 --- a/tests/safeds/ml/classical/regression/test_decision_tree.py +++ b/tests/safeds/ml/classical/regression/test_decision_tree.py @@ -1,8 +1,9 @@ import pytest from safeds.data.labeled.containers import TabularDataset from safeds.data.tabular.containers import Table -from safeds.exceptions import OutOfBoundsError +from safeds.exceptions import ModelNotFittedError, OutOfBoundsError from safeds.ml.classical.regression import DecisionTreeRegressor +from syrupy import SnapshotAssertion @pytest.fixture() @@ -41,3 +42,19 @@ def test_should_be_passed_to_sklearn(self, training_set: TabularDataset) -> None def test_should_raise_if_less_than_or_equal_to_0(self, min_sample_count_in_leaves: int) -> None: with pytest.raises(OutOfBoundsError): DecisionTreeRegressor(min_sample_count_in_leaves=min_sample_count_in_leaves) + + +class TestPlot: + def test_should_raise_if_model_is_not_fitted(self) -> None: + model = DecisionTreeRegressor() + with pytest.raises(ModelNotFittedError): + model.plot() + + def test_should_check_that_plot_image_is_same_as_snapshot( + self, + training_set: TabularDataset, + snapshot_png_image: SnapshotAssertion, + ) -> None: + fitted_model = DecisionTreeRegressor().fit(training_set) + image = fitted_model.plot() + assert image == snapshot_png_image