From 289a4e467dec5095bfd81230e7dd19de07ddf303 Mon Sep 17 00:00:00 2001 From: Benjamin Bossan Date: Mon, 24 Apr 2023 16:32:29 +0200 Subject: [PATCH 1/7] Fix bug when visualizing byte nodes The visualize function failed when trying to visualize models that had byte attributes. This is because the node's children would contain raw bytes, which the function doesn't know how to visualize. Therefore, children of byte and byte array nodes are now skipped (in addition to the node types already being skipped). To test this bug, I added a visualization test to the external packages like LightGBM, which make use of bytes. I'm not sure if some of the sklearn estimators could also be candidates, but it would surely be overkill to test visualizing all of them, whereas the overhead is not so big for the external packages. --- skops/io/_visualize.py | 14 ++++++- skops/io/tests/test_external.py | 65 +++++++++++++++++++++++++++------ 2 files changed, 65 insertions(+), 14 deletions(-) diff --git a/skops/io/_visualize.py b/skops/io/_visualize.py index 1aa3dca2..9df80928 100644 --- a/skops/io/_visualize.py +++ b/skops/io/_visualize.py @@ -8,11 +8,21 @@ from zipfile import ZipFile from ._audit import VALID_NODE_CHILD_TYPES, Node, get_tree -from ._general import FunctionNode, JsonNode, ListNode +from ._general import BytearrayNode, BytesNode, FunctionNode, JsonNode, ListNode from ._numpy import NdArrayNode from ._scipy import SparseMatrixNode from ._utils import LoadContext +# The children of these types are not visualized +SKIPPED_TYPES = ( + BytearrayNode, + BytesNode, + FunctionNode, + JsonNode, + NdArrayNode, + SparseMatrixNode, +) + @dataclass class NodeInfo: @@ -269,7 +279,7 @@ def walk_tree( # TODO: For better security, we should check the schema if we return early, # otherwise something nefarious could be hidden inside (however, if there # is, the node should be marked as unsafe) - if isinstance(node, (NdArrayNode, SparseMatrixNode, FunctionNode, JsonNode)): + if isinstance(node, SKIPPED_TYPES): return yield from walk_tree( diff --git a/skops/io/tests/test_external.py b/skops/io/tests/test_external.py index a8e6038f..38d927e4 100644 --- a/skops/io/tests/test_external.py +++ b/skops/io/tests/test_external.py @@ -2,12 +2,20 @@ Packages that are not builtins, standard lib, numpy, scipy, or scikit-learn. +Testing: + +- persistence of unfitted models +- persistence of fitted models +- visualization of dumped models + +with a range of hyperparameters. + """ import pytest from sklearn.datasets import make_classification, make_regression -from skops.io import dumps, loads +from skops.io import dumps, loads, visualize from skops.io.tests._utils import assert_method_outputs_equal, assert_params_equal # Default settings for generated data @@ -83,9 +91,12 @@ def test_classifier(self, lgbm, clf_data, trusted, boosting_type): X, y = clf_data estimator.fit(X, y) - loaded = loads(dumps(estimator), trusted=trusted) + dumped = dumps(estimator) + loaded = loads(dumped, trusted=trusted) assert_method_outputs_equal(estimator, loaded, X) + visualize(dumped, trusted=trusted) + @pytest.mark.parametrize("boosting_type", boosting_types) def test_regressor(self, lgbm, regr_data, trusted, boosting_type): kw = {} @@ -99,9 +110,12 @@ def test_regressor(self, lgbm, regr_data, trusted, boosting_type): X, y = regr_data estimator.fit(X, y) - loaded = loads(dumps(estimator), trusted=trusted) + dumped = dumps(estimator) + loaded = loads(dumped, trusted=trusted) assert_method_outputs_equal(estimator, loaded, X) + visualize(dumped, trusted=trusted) + @pytest.mark.parametrize("boosting_type", boosting_types) def test_ranker(self, lgbm, rank_data, trusted, boosting_type): kw = {} @@ -115,9 +129,12 @@ def test_ranker(self, lgbm, rank_data, trusted, boosting_type): X, y, group = rank_data estimator.fit(X, y, group=group) - loaded = loads(dumps(estimator), trusted=trusted) + dumped = dumps(estimator) + loaded = loads(dumped, trusted=trusted) assert_method_outputs_equal(estimator, loaded, X) + visualize(dumped, trusted=trusted) + class TestXGBoost: """Tests for XGBClassifier, XGBRegressor, XGBRFClassifier, XGBRFRegressor, XGBRanker @@ -170,9 +187,12 @@ def test_classifier(self, xgboost, clf_data, trusted, booster, tree_method): X, y = clf_data estimator.fit(X, y) - loaded = loads(dumps(estimator), trusted=trusted) + dumped = dumps(estimator) + loaded = loads(dumped, trusted=trusted) assert_method_outputs_equal(estimator, loaded, X) + visualize(dumped, trusted=trusted) + @pytest.mark.parametrize("booster", boosters) @pytest.mark.parametrize("tree_method", tree_methods) def test_regressor(self, xgboost, regr_data, trusted, booster, tree_method): @@ -186,9 +206,12 @@ def test_regressor(self, xgboost, regr_data, trusted, booster, tree_method): X, y = regr_data estimator.fit(X, y) - loaded = loads(dumps(estimator), trusted=trusted) + dumped = dumps(estimator) + loaded = loads(dumped, trusted=trusted) assert_method_outputs_equal(estimator, loaded, X) + visualize(dumped, trusted=trusted) + @pytest.mark.parametrize("booster", boosters) @pytest.mark.parametrize("tree_method", tree_methods) def test_rf_classifier(self, xgboost, clf_data, trusted, booster, tree_method): @@ -202,9 +225,12 @@ def test_rf_classifier(self, xgboost, clf_data, trusted, booster, tree_method): X, y = clf_data estimator.fit(X, y) - loaded = loads(dumps(estimator), trusted=trusted) + dumped = dumps(estimator) + loaded = loads(dumped, trusted=trusted) assert_method_outputs_equal(estimator, loaded, X) + visualize(dumped, trusted=trusted) + @pytest.mark.parametrize("booster", boosters) @pytest.mark.parametrize("tree_method", tree_methods) def test_rf_regressor(self, xgboost, regr_data, trusted, booster, tree_method): @@ -218,9 +244,12 @@ def test_rf_regressor(self, xgboost, regr_data, trusted, booster, tree_method): X, y = regr_data estimator.fit(X, y) - loaded = loads(dumps(estimator), trusted=trusted) + dumped = dumps(estimator) + loaded = loads(dumped, trusted=trusted) assert_method_outputs_equal(estimator, loaded, X) + visualize(dumped, trusted=trusted) + @pytest.mark.parametrize("booster", boosters) @pytest.mark.parametrize("tree_method", tree_methods) def test_ranker(self, xgboost, rank_data, trusted, booster, tree_method): @@ -234,9 +263,12 @@ def test_ranker(self, xgboost, rank_data, trusted, booster, tree_method): X, y, group = rank_data estimator.fit(X, y, group=group) - loaded = loads(dumps(estimator), trusted=trusted) + dumped = dumps(estimator) + loaded = loads(dumped, trusted=trusted) assert_method_outputs_equal(estimator, loaded, X) + visualize(dumped, trusted=trusted) + class TestCatboost: """Tests for CatBoostClassifier, CatBoostRegressor, and CatBoostRanker""" @@ -290,9 +322,12 @@ def test_classifier(self, catboost, cb_clf_data, trusted, boosting_type): X, y = cb_clf_data estimator.fit(X, y, cat_features=[0, 1]) - loaded = loads(dumps(estimator), trusted=trusted) + dumped = dumps(estimator) + loaded = loads(dumped, trusted=trusted) assert_method_outputs_equal(estimator, loaded, X) + visualize(dumped, trusted=trusted) + @pytest.mark.parametrize("boosting_type", boosting_types) def test_regressor(self, catboost, cb_regr_data, trusted, boosting_type): estimator = catboost.CatBoostRegressor( @@ -303,9 +338,12 @@ def test_regressor(self, catboost, cb_regr_data, trusted, boosting_type): X, y = cb_regr_data estimator.fit(X, y, cat_features=[0, 1]) - loaded = loads(dumps(estimator), trusted=trusted) + dumped = dumps(estimator) + loaded = loads(dumped, trusted=trusted) assert_method_outputs_equal(estimator, loaded, X) + visualize(dumped, trusted=trusted) + @pytest.mark.parametrize("boosting_type", boosting_types) def test_ranker(self, catboost, cb_rank_data, trusted, boosting_type): estimator = catboost.CatBoostRanker( @@ -316,5 +354,8 @@ def test_ranker(self, catboost, cb_rank_data, trusted, boosting_type): X, y, group_id = cb_rank_data estimator.fit(X, y, cat_features=[0, 1], group_id=group_id) - loaded = loads(dumps(estimator), trusted=trusted) + dumped = dumps(estimator) + loaded = loads(dumped, trusted=trusted) assert_method_outputs_equal(estimator, loaded, X) + + visualize(dumped, trusted=trusted) From 3371ed7d5bec304c12aced8304e571762f3c3724 Mon Sep 17 00:00:00 2001 From: Benjamin Bossan Date: Mon, 24 Apr 2023 16:55:19 +0200 Subject: [PATCH 2/7] Try fixing windows encoding errors https://github.com/Textualize/rich/issues/330#issuecomment-703246028 --- .github/workflows/build-test.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/build-test.yml b/.github/workflows/build-test.yml index 1dfa0e30..abb9438b 100644 --- a/.github/workflows/build-test.yml +++ b/.github/workflows/build-test.yml @@ -79,6 +79,7 @@ jobs: - name: Tests env: SUPER_SECRET: ${{ secrets.HF_HUB_TOKEN }} + PYTHONIOENCODING: "utf-8" run: | python -m pytest -s -v --cov-report=xml -m "not inference" skops/ From ee49944bba3c19b04a6ba77356dc0bb7f37a7234 Mon Sep 17 00:00:00 2001 From: Benjamin Bossan Date: Tue, 25 Apr 2023 12:22:37 +0200 Subject: [PATCH 3/7] Show preview of bytes in visualize Truncate the number of bytes shown at 24. --- skops/io/_general.py | 18 ++++++++++++++++++ skops/io/tests/test_visualize.py | 20 ++++++++++++++++++++ 2 files changed, 38 insertions(+) diff --git a/skops/io/_general.py b/skops/io/_general.py index b6b00d11..275282b6 100644 --- a/skops/io/_general.py +++ b/skops/io/_general.py @@ -5,6 +5,7 @@ import operator import uuid from functools import partial +from reprlib import Repr from types import FunctionType, MethodType from typing import Any, Sequence @@ -27,6 +28,9 @@ ) from .exceptions import UnsupportedTypeException +arepr = Repr() +arepr.maxstring = 24 + def dict_get_state(obj: Any, save_context: SaveContext) -> dict[str, Any]: res = { @@ -527,6 +531,17 @@ def _construct(self): content = self.children["content"].getvalue() return content + def format(self): + try: + content = self.children["content"].getvalue() + byte_repr = arepr.repr(content) + except Exception: + byte_repr = "b'...'" + finally: + # ensure that no matter what happens, the file pointer is reset + self.children["content"].seek(0) + return byte_repr + class BytearrayNode(BytesNode): def __init__( @@ -543,6 +558,9 @@ def _construct(self): content_bytearray = bytearray(list(content_bytes)) return content_bytearray + def format(self): + return f"bytearray({super().format()})" + def operator_func_get_state(obj: Any, save_context: SaveContext) -> dict[str, Any]: _, attrs = obj.__reduce__() diff --git a/skops/io/tests/test_visualize.py b/skops/io/tests/test_visualize.py index 33a00436..77d2fd7d 100644 --- a/skops/io/tests/test_visualize.py +++ b/skops/io/tests/test_visualize.py @@ -269,3 +269,23 @@ def test_from_file(self, simple, tmp_path, capsys): ] stdout, _ = capsys.readouterr() assert stdout.strip() == "\n".join(expected) + + def test_long_bytes(self, capsys): + obj = { + "short_byte": b"abc", + "long_byte": b"010203040506070809101112131415", + "short_bytearray": bytearray(b"abc"), + "long_bytearray": bytearray(b"010203040506070809101112131415"), + } + dumped = sio.dumps(obj) + sio.visualize(dumped) + + expected = [ + "root: builtins.dict", + "├── short_byte: b'abc'", + "├── long_byte: b'01020304050...9101112131415'", + "├── short_bytearray: bytearray(b'abc')", + "└── long_bytearray: bytearray(b'01020304050...9101112131415')", + ] + stdout, _ = capsys.readouterr() + assert stdout.strip() == "\n".join(expected) From 5ab74453c922ad8bd20abd04eb2d6e923f0ce962 Mon Sep 17 00:00:00 2001 From: Benjamin Bossan Date: Tue, 25 Apr 2023 12:24:53 +0200 Subject: [PATCH 4/7] Add entry to changes.rst --- docs/changes.rst | 2 ++ 1 file changed, 2 insertions(+) diff --git a/docs/changes.rst b/docs/changes.rst index 295b876b..f70dcaaf 100644 --- a/docs/changes.rst +++ b/docs/changes.rst @@ -15,6 +15,8 @@ v0.7 - `compression` and `compresslevel` from :class:`~zipfile.ZipFile` are now exposed to the user via :func:`.io.dumps` and :func:`.io.dump`. :pr:`345` by `Adrin Jalali`_. +- Fix: :func:`skops.io.visualize` is now capable of showing bytes. :pr:`352` by + `Benjamin Bossan`_. v0.6 ---- From af522efa2e53efd33ab72bdce4e9183c52076066 Mon Sep 17 00:00:00 2001 From: Benjamin Bossan Date: Tue, 25 Apr 2023 12:37:42 +0200 Subject: [PATCH 5/7] Simplify code Calling getvalue does not require a seek(0) afterwards. --- skops/io/_general.py | 10 ++-------- 1 file changed, 2 insertions(+), 8 deletions(-) diff --git a/skops/io/_general.py b/skops/io/_general.py index 275282b6..0902857c 100644 --- a/skops/io/_general.py +++ b/skops/io/_general.py @@ -532,14 +532,8 @@ def _construct(self): return content def format(self): - try: - content = self.children["content"].getvalue() - byte_repr = arepr.repr(content) - except Exception: - byte_repr = "b'...'" - finally: - # ensure that no matter what happens, the file pointer is reset - self.children["content"].seek(0) + content = self.children["content"].getvalue() + byte_repr = arepr.repr(content) return byte_repr From 9307c772c41c924af4fbebae74d37dbf1603c55e Mon Sep 17 00:00:00 2001 From: Benjamin Bossan Date: Thu, 27 Apr 2023 11:48:23 +0200 Subject: [PATCH 6/7] Prevent printing to stdout for test_external.py --- skops/io/tests/test_external.py | 27 ++++++++++++++++----------- 1 file changed, 16 insertions(+), 11 deletions(-) diff --git a/skops/io/tests/test_external.py b/skops/io/tests/test_external.py index 38d927e4..d4dc7d08 100644 --- a/skops/io/tests/test_external.py +++ b/skops/io/tests/test_external.py @@ -54,6 +54,11 @@ def rank_data(clf_data): return X, y, group +def _null(*args, **kwargs): + # used to prevent printing anything to stdout when calling visualize + return + + class TestLightGBM: """Tests for LGBMClassifier, LGBMRegressor, LGBMRanker""" @@ -95,7 +100,7 @@ def test_classifier(self, lgbm, clf_data, trusted, boosting_type): loaded = loads(dumped, trusted=trusted) assert_method_outputs_equal(estimator, loaded, X) - visualize(dumped, trusted=trusted) + visualize(dumped, trusted=trusted, sink=_null) @pytest.mark.parametrize("boosting_type", boosting_types) def test_regressor(self, lgbm, regr_data, trusted, boosting_type): @@ -114,7 +119,7 @@ def test_regressor(self, lgbm, regr_data, trusted, boosting_type): loaded = loads(dumped, trusted=trusted) assert_method_outputs_equal(estimator, loaded, X) - visualize(dumped, trusted=trusted) + visualize(dumped, trusted=trusted, sink=_null) @pytest.mark.parametrize("boosting_type", boosting_types) def test_ranker(self, lgbm, rank_data, trusted, boosting_type): @@ -133,7 +138,7 @@ def test_ranker(self, lgbm, rank_data, trusted, boosting_type): loaded = loads(dumped, trusted=trusted) assert_method_outputs_equal(estimator, loaded, X) - visualize(dumped, trusted=trusted) + visualize(dumped, trusted=trusted, sink=_null) class TestXGBoost: @@ -191,7 +196,7 @@ def test_classifier(self, xgboost, clf_data, trusted, booster, tree_method): loaded = loads(dumped, trusted=trusted) assert_method_outputs_equal(estimator, loaded, X) - visualize(dumped, trusted=trusted) + visualize(dumped, trusted=trusted, sink=_null) @pytest.mark.parametrize("booster", boosters) @pytest.mark.parametrize("tree_method", tree_methods) @@ -210,7 +215,7 @@ def test_regressor(self, xgboost, regr_data, trusted, booster, tree_method): loaded = loads(dumped, trusted=trusted) assert_method_outputs_equal(estimator, loaded, X) - visualize(dumped, trusted=trusted) + visualize(dumped, trusted=trusted, sink=_null) @pytest.mark.parametrize("booster", boosters) @pytest.mark.parametrize("tree_method", tree_methods) @@ -229,7 +234,7 @@ def test_rf_classifier(self, xgboost, clf_data, trusted, booster, tree_method): loaded = loads(dumped, trusted=trusted) assert_method_outputs_equal(estimator, loaded, X) - visualize(dumped, trusted=trusted) + visualize(dumped, trusted=trusted, sink=_null) @pytest.mark.parametrize("booster", boosters) @pytest.mark.parametrize("tree_method", tree_methods) @@ -248,7 +253,7 @@ def test_rf_regressor(self, xgboost, regr_data, trusted, booster, tree_method): loaded = loads(dumped, trusted=trusted) assert_method_outputs_equal(estimator, loaded, X) - visualize(dumped, trusted=trusted) + visualize(dumped, trusted=trusted, sink=_null) @pytest.mark.parametrize("booster", boosters) @pytest.mark.parametrize("tree_method", tree_methods) @@ -267,7 +272,7 @@ def test_ranker(self, xgboost, rank_data, trusted, booster, tree_method): loaded = loads(dumped, trusted=trusted) assert_method_outputs_equal(estimator, loaded, X) - visualize(dumped, trusted=trusted) + visualize(dumped, trusted=trusted, sink=_null) class TestCatboost: @@ -326,7 +331,7 @@ def test_classifier(self, catboost, cb_clf_data, trusted, boosting_type): loaded = loads(dumped, trusted=trusted) assert_method_outputs_equal(estimator, loaded, X) - visualize(dumped, trusted=trusted) + visualize(dumped, trusted=trusted, sink=_null) @pytest.mark.parametrize("boosting_type", boosting_types) def test_regressor(self, catboost, cb_regr_data, trusted, boosting_type): @@ -342,7 +347,7 @@ def test_regressor(self, catboost, cb_regr_data, trusted, boosting_type): loaded = loads(dumped, trusted=trusted) assert_method_outputs_equal(estimator, loaded, X) - visualize(dumped, trusted=trusted) + visualize(dumped, trusted=trusted, sink=_null) @pytest.mark.parametrize("boosting_type", boosting_types) def test_ranker(self, catboost, cb_rank_data, trusted, boosting_type): @@ -358,4 +363,4 @@ def test_ranker(self, catboost, cb_rank_data, trusted, boosting_type): loaded = loads(dumped, trusted=trusted) assert_method_outputs_equal(estimator, loaded, X) - visualize(dumped, trusted=trusted) + visualize(dumped, trusted=trusted, sink=_null) From 5cb3038dead776dfc85c1056cd8d815ce7f62f70 Mon Sep 17 00:00:00 2001 From: Benjamin Bossan Date: Thu, 27 Apr 2023 14:19:15 +0200 Subject: [PATCH 7/7] Other method to suppress printing to stdout More of the visualize stack is running this way, improving test coverage. --- skops/io/tests/test_external.py | 53 +++++++++++++++++++++++---------- 1 file changed, 37 insertions(+), 16 deletions(-) diff --git a/skops/io/tests/test_external.py b/skops/io/tests/test_external.py index d4dc7d08..b41253c0 100644 --- a/skops/io/tests/test_external.py +++ b/skops/io/tests/test_external.py @@ -12,6 +12,8 @@ """ +from unittest.mock import Mock, patch + import pytest from sklearn.datasets import make_classification, make_regression @@ -54,14 +56,17 @@ def rank_data(clf_data): return X, y, group -def _null(*args, **kwargs): - # used to prevent printing anything to stdout when calling visualize - return - - class TestLightGBM: """Tests for LGBMClassifier, LGBMRegressor, LGBMRanker""" + @pytest.fixture(autouse=True) + def capture_stdout(self): + # Mock print and rich.print so that running these tests with pytest -s + # does not spam stdout. Other, more common methods of suppressing + # printing to stdout don't seem to work, perhaps because of pytest. + with patch("builtins.print", Mock()), patch("rich.print", Mock()): + yield + @pytest.fixture(autouse=True) def lgbm(self): lgbm = pytest.importorskip("lightgbm") @@ -100,7 +105,7 @@ def test_classifier(self, lgbm, clf_data, trusted, boosting_type): loaded = loads(dumped, trusted=trusted) assert_method_outputs_equal(estimator, loaded, X) - visualize(dumped, trusted=trusted, sink=_null) + visualize(dumped, trusted=trusted) @pytest.mark.parametrize("boosting_type", boosting_types) def test_regressor(self, lgbm, regr_data, trusted, boosting_type): @@ -119,7 +124,7 @@ def test_regressor(self, lgbm, regr_data, trusted, boosting_type): loaded = loads(dumped, trusted=trusted) assert_method_outputs_equal(estimator, loaded, X) - visualize(dumped, trusted=trusted, sink=_null) + visualize(dumped, trusted=trusted) @pytest.mark.parametrize("boosting_type", boosting_types) def test_ranker(self, lgbm, rank_data, trusted, boosting_type): @@ -138,7 +143,7 @@ def test_ranker(self, lgbm, rank_data, trusted, boosting_type): loaded = loads(dumped, trusted=trusted) assert_method_outputs_equal(estimator, loaded, X) - visualize(dumped, trusted=trusted, sink=_null) + visualize(dumped, trusted=trusted) class TestXGBoost: @@ -158,6 +163,14 @@ class TestXGBoost: """ + @pytest.fixture(autouse=True) + def capture_stdout(self): + # Mock print and rich.print so that running these tests with pytest -s + # does not spam stdout. Other, more common methods of suppressing + # printing to stdout don't seem to work, perhaps because of pytest. + with patch("builtins.print", Mock()), patch("rich.print", Mock()): + yield + @pytest.fixture(autouse=True) def xgboost(self): xgboost = pytest.importorskip("xgboost") @@ -196,7 +209,7 @@ def test_classifier(self, xgboost, clf_data, trusted, booster, tree_method): loaded = loads(dumped, trusted=trusted) assert_method_outputs_equal(estimator, loaded, X) - visualize(dumped, trusted=trusted, sink=_null) + visualize(dumped, trusted=trusted) @pytest.mark.parametrize("booster", boosters) @pytest.mark.parametrize("tree_method", tree_methods) @@ -215,7 +228,7 @@ def test_regressor(self, xgboost, regr_data, trusted, booster, tree_method): loaded = loads(dumped, trusted=trusted) assert_method_outputs_equal(estimator, loaded, X) - visualize(dumped, trusted=trusted, sink=_null) + visualize(dumped, trusted=trusted) @pytest.mark.parametrize("booster", boosters) @pytest.mark.parametrize("tree_method", tree_methods) @@ -234,7 +247,7 @@ def test_rf_classifier(self, xgboost, clf_data, trusted, booster, tree_method): loaded = loads(dumped, trusted=trusted) assert_method_outputs_equal(estimator, loaded, X) - visualize(dumped, trusted=trusted, sink=_null) + visualize(dumped, trusted=trusted) @pytest.mark.parametrize("booster", boosters) @pytest.mark.parametrize("tree_method", tree_methods) @@ -253,7 +266,7 @@ def test_rf_regressor(self, xgboost, regr_data, trusted, booster, tree_method): loaded = loads(dumped, trusted=trusted) assert_method_outputs_equal(estimator, loaded, X) - visualize(dumped, trusted=trusted, sink=_null) + visualize(dumped, trusted=trusted) @pytest.mark.parametrize("booster", boosters) @pytest.mark.parametrize("tree_method", tree_methods) @@ -272,12 +285,20 @@ def test_ranker(self, xgboost, rank_data, trusted, booster, tree_method): loaded = loads(dumped, trusted=trusted) assert_method_outputs_equal(estimator, loaded, X) - visualize(dumped, trusted=trusted, sink=_null) + visualize(dumped, trusted=trusted) class TestCatboost: """Tests for CatBoostClassifier, CatBoostRegressor, and CatBoostRanker""" + @pytest.fixture(autouse=True) + def capture_stdout(self): + # Mock print and rich.print so that running these tests with pytest -s + # does not spam stdout. Other, more common methods of suppressing + # printing to stdout don't seem to work, perhaps because of pytest. + with patch("builtins.print", Mock()), patch("rich.print", Mock()): + yield + # CatBoost data is a little different so that it works as categorical data @pytest.fixture(scope="module") def cb_clf_data(self, clf_data): @@ -331,7 +352,7 @@ def test_classifier(self, catboost, cb_clf_data, trusted, boosting_type): loaded = loads(dumped, trusted=trusted) assert_method_outputs_equal(estimator, loaded, X) - visualize(dumped, trusted=trusted, sink=_null) + visualize(dumped, trusted=trusted) @pytest.mark.parametrize("boosting_type", boosting_types) def test_regressor(self, catboost, cb_regr_data, trusted, boosting_type): @@ -347,7 +368,7 @@ def test_regressor(self, catboost, cb_regr_data, trusted, boosting_type): loaded = loads(dumped, trusted=trusted) assert_method_outputs_equal(estimator, loaded, X) - visualize(dumped, trusted=trusted, sink=_null) + visualize(dumped, trusted=trusted) @pytest.mark.parametrize("boosting_type", boosting_types) def test_ranker(self, catboost, cb_rank_data, trusted, boosting_type): @@ -363,4 +384,4 @@ def test_ranker(self, catboost, cb_rank_data, trusted, boosting_type): loaded = loads(dumped, trusted=trusted) assert_method_outputs_equal(estimator, loaded, X) - visualize(dumped, trusted=trusted, sink=_null) + visualize(dumped, trusted=trusted)