diff --git a/docs/changes.rst b/docs/changes.rst index bf0d29ea..bfcadbc2 100644 --- a/docs/changes.rst +++ b/docs/changes.rst @@ -25,6 +25,9 @@ v0.6 - ``add_*`` methods on :class:`.Card` now have default section names (but ``None`` is no longer valid) and no longer add descriptions by default. :pr:`321` by `Benjamin Bossan`_. +- Add possibility to visualize a skops object and show untrusted types by using + :func:`skops.io.visualize`. For colored output, install `rich`: `pip install + rich`. :pr:`317` by `Benjamin Bossan`_. v0.5 ---- diff --git a/docs/persistence.rst b/docs/persistence.rst index 341b8c74..266d171b 100644 --- a/docs/persistence.rst +++ b/docs/persistence.rst @@ -134,6 +134,64 @@ For example, to convert all ``.pkl`` flies in the current directory: Further help for the different supported options can be found by calling ``skops convert --help`` in a terminal. +Visualization +############# + +Skops files can be visualized using :func:`skops.io.visualize`. If you have +a skops file called ``my-model.skops``, you can visualize it like this: + +.. code:: python + + import skops.io as sio + sio.visualize("my-model.skops") + +The output could look like this: + +.. code:: + + root: sklearn.preprocessing._data.MinMaxScaler + └── attrs: builtins.dict + ├── feature_range: builtins.tuple + │ ├── content: json-type(-555) + │ └── content: json-type(123) + ├── copy: unsafe_lib.UnsafeType [UNSAFE] + ├── clip: json-type(false) + └── _sklearn_version: json-type("1.2.0") + +``unsafe_lib.UnsafeType`` was recognized as untrusted and marked. + +It's also possible to visualize the object dumped as bytes: + + import skops.io as sio + my_model = ... + sio.visualize(sio.dumps(my_model)) + +There are various options to customize the output. By default, the security of +nodes is color coded if `rich `_ is +installed, otherwise they all have the same color. To install ``rich``, run: + +.. code:: + + python -m pip install rich + +or, when installing skops, install it like this: + + python -m pip install skops[rich] + +To disable colors, even if ``rich`` is installed, pass ``use_colors=False`` to +:func:`skops.io.visualize`. + +It's also possible to change what colors are being used, e.g. by passing +``visualize(..., color_safe="cyan")`` to change the color for trusted nodes from +green to cyan. The ``rich`` docs list the `supported standard colors +`_. + +Note that the visualization feature is intended to help understand the structure +of the object, e.g. what attributes are identified as untrusted. It is not a +replacement for a proper security check. In particular, just because an object's +visualization looks innocent does *not* mean you can just call `sio.load(, +trusted=True)` on this object -- only pass the types you really trust to the +``trusted`` argument. Supported libraries ------------------- diff --git a/setup.py b/setup.py index 658861c8..3cc6aa6a 100644 --- a/setup.py +++ b/setup.py @@ -77,6 +77,7 @@ def setup_package(): extras_require={ "docs": min_deps.tag_to_packages["docs"], "tests": min_deps.tag_to_packages["tests"], + "rich": min_deps.tag_to_packages["rich"], }, include_package_data=True, ) diff --git a/skops/_min_dependencies.py b/skops/_min_dependencies.py index ff3483d1..d66b554f 100644 --- a/skops/_min_dependencies.py +++ b/skops/_min_dependencies.py @@ -6,7 +6,8 @@ # 'build' and 'install' is included to have structured metadata for CI. # It will NOT be included in setup's extras_require # The values are (version_spec, comma separated tags, condition) -# tags can be: 'build', 'install', 'docs', 'examples', 'tests', 'benchmark' +# tags can be: 'build', 'install', 'docs', 'examples', 'tests', 'benchmark', +# 'rich' # example: # "tomli": ("1.1.0", "install", "python_full_version < '3.11.0a7'"), dependent_packages = { @@ -34,13 +35,14 @@ # TODO: remove condition when catboost supports python 3.11 "catboost": ("1.0", "tests", "python_version < '3.11'"), "fairlearn": ("0.7.0", "docs, tests", None), + "rich": ("12", "tests, rich", None), } # create inverse mapping for setuptools tag_to_packages: dict = { extra: [] - for extra in ["build", "install", "docs", "examples", "tests", "benchmark"] + for extra in ["build", "install", "docs", "examples", "tests", "benchmark", "rich"] } for package, (min_version, extras, condition) in dependent_packages.items(): for extra in extras.split(", "): diff --git a/skops/conftest.py b/skops/conftest.py index 9ee0a4db..93430f83 100644 --- a/skops/conftest.py +++ b/skops/conftest.py @@ -43,3 +43,16 @@ def mock_import(name, *args, **kwargs): yield import matplotlib # noqa + + +@pytest.fixture +def rich_not_installed(): + orig_import = builtins.__import__ + + def mock_import(name, *args, **kwargs): + if name == "rich": + raise ImportError + return orig_import(name, *args, **kwargs) + + with patch("builtins.__import__", side_effect=mock_import): + yield diff --git a/skops/io/__init__.py b/skops/io/__init__.py index 7990e9eb..c60c99c6 100644 --- a/skops/io/__init__.py +++ b/skops/io/__init__.py @@ -1,3 +1,4 @@ from ._persist import dump, dumps, get_untrusted_types, load, loads +from ._visualize import visualize -__all__ = ["dumps", "load", "loads", "dump", "get_untrusted_types"] +__all__ = ["dumps", "load", "loads", "dump", "get_untrusted_types", "visualize"] diff --git a/skops/io/_audit.py b/skops/io/_audit.py index 1e379308..5926c536 100644 --- a/skops/io/_audit.py +++ b/skops/io/_audit.py @@ -287,6 +287,10 @@ def get_unsafe_set(self) -> set[str]: return res + def format(self) -> str: + """Representation of the node's content.""" + return f"{self.module_name}.{self.class_name}" + class CachedNode(Node): def __init__( diff --git a/skops/io/_general.py b/skops/io/_general.py index b5685c3f..dfb7f223 100644 --- a/skops/io/_general.py +++ b/skops/io/_general.py @@ -482,6 +482,14 @@ def get_unsafe_set(self) -> set[str]: def _construct(self): return json.loads(self.content) + def format(self) -> str: + """Representation of the node's content. + + Since no module is used, just show the content. + + """ + return f"json-type({self.content})" + def bytes_get_state(obj: Any, save_context: SaveContext) -> dict[str, Any]: f_name = f"{uuid.uuid4()}.bin" diff --git a/skops/io/_scipy.py b/skops/io/_scipy.py index cbf5c1d1..1c494a49 100644 --- a/skops/io/_scipy.py +++ b/skops/io/_scipy.py @@ -40,9 +40,9 @@ def __init__( trusted: bool | Sequence[str] = False, ) -> None: super().__init__(state, load_context, trusted) - type = state["type"] + self.type = state["type"] self.trusted = self._get_trusted(trusted, [spmatrix]) - if type != "scipy": + if self.type != "scipy": raise TypeError( f"Cannot load object of type {self.module_name}.{self.class_name}" ) diff --git a/skops/io/_visualize.py b/skops/io/_visualize.py new file mode 100644 index 00000000..953345e7 --- /dev/null +++ b/skops/io/_visualize.py @@ -0,0 +1,354 @@ +from __future__ import annotations + +import io +import json +from dataclasses import dataclass +from pathlib import Path +from typing import Any, Callable, Iterator, Literal +from zipfile import ZipFile + +from ._audit import Node, get_tree +from ._general import FunctionNode, JsonNode, ListNode +from ._numpy import NdArrayNode +from ._scipy import SparseMatrixNode +from ._utils import LoadContext + + +@dataclass +class NodeInfo: + """Information pertinent for visualizatoin, extracted from ``Node``s. + + This class contains all information necessary for visualizing nodes. This + way, we can have separate functions for: + + - visiting nodes and determining their safety + - visualizing the nodes + + The visualization function will only receive the ``NodeInfo`` and does not + have to concern itself with how to discover children or determine safety. + + """ + + level: int + key: str # the key to the node + val: str # the value of the node + is_self_safe: bool # whether this specific node is safe + is_safe: bool # whether this node and all of its children are safe + is_last: bool # whether this is the last child of parent node + + +def _check_visibility( + is_self_safe: bool, + is_safe: bool, + show: Literal["all", "untrusted", "trusted"], +) -> bool: + """Determine visibility of the node. + + Users can indicate if they want to see all nodes, all trusted nodes, or all + untrusted nodes. + + """ + if show == "all": + return True + + if show == "untrusted": + return not is_safe + + # case: show only safe node + return is_self_safe + + +def _get_node_label( + node: NodeInfo, + tag_safe: str = "", + tag_unsafe: str = "[UNSAFE]", + use_colors: bool = True, + # use rich for coloring + color_safe: str = "green", + color_unsafe: str = "red", + color_child_unsafe: str = "yellow", +): + """Determine the label of a node. + + Nodes are labeled differently based on how they're trusted. + + """ + # note: when changing the arguments to this function, also update the + # docstring of visualize! + + if use_colors: + try: + import rich # noqa + except ImportError: + use_colors = False + + # add tag if necessary + node_val = node.val + tag = tag_safe if node.is_self_safe else tag_unsafe + if tag: + node_val += f" {tag}" + + # colorize if so desired and if rich is installed + if use_colors: + if node.is_safe: + color = color_safe + elif node.is_self_safe: + color = color_child_unsafe + else: + color = color_unsafe + node_val = f"[{color}]{node_val}" + + return node_val + + +def pretty_print_tree( + nodes_iter: Iterator[NodeInfo], + show: Literal["all", "untrusted", "trusted"], + **kwargs, +) -> None: + # This function loops through the flattened nodes of the tree and creates a + # tree visualization based on the node information. If rich is installed, + # nodes can be colored. + + print_ = print + try: + import rich + + # use rich for printing if available + print_ = rich.print # type: ignore + except ImportError: + pass + + # start with root node + node = next(nodes_iter) + label = _get_node_label(node, **kwargs) + print_(f"{node.key}: {label}") + prev_level = node.level # should be 0 + prefix = "" + + for node in nodes_iter: + visible = _check_visibility(node.is_self_safe, node.is_safe, show=show) + if not visible: + continue + + level_diff = prev_level - node.level + if level_diff < -1: + # this would mean it is a "(great-)grandchild" node + raise ValueError( + "While constructing the tree of the object, a level difference of " + f"{level_diff} was encountered, which should not be possible, please " + "report the issue here: https://github.com/skops-dev/skops/issues" + ) + + # Level diff of -1 means that this node is a child of the previous node. + # E.g. if the current level if 4 and the previous level was 3, the + # current node is a child node of the previous one. Since the prefix for + # a child node was already added, there is nothing more left to do. + + for _ in range(level_diff + 1): + # This loop is entered if the current node is at the same level as, + # or higher than, the previous node. This means the prefix has to be + # truncated according to the level difference. E.g. if the current + # level is 2 and previous level was 3, it means that we should move + # up 2 layers of nesting, therefore, we trunce 3-2+1 = 2 times. + prefix = prefix[:-4] + + print_(prefix, end="") + if node.is_last: + print_("└──", end="") + prefix += " " + else: + print_("├──", end="") + prefix += "│ " + + label = _get_node_label(node, **kwargs) + print_(f" {node.key}: {label}") + + prev_level = node.level + + +def walk_tree( + node: Node | dict[str, Node] | list[Node], + node_name: str = "root", + level: int = 0, + is_last: bool = False, +) -> Iterator[NodeInfo]: + """Visit all nodes of the tree and yield their important attributes. + + This function visits all nodes of the object tree and determines: + + - level: how nested the node is + - key: the key of the node. E.g. if the node is an attribute of an object, + the key would be the name of the attribute. + - val: the value of the node, e.g. builtins.list + - safety: whether it, and its children, are trusted + + These values are just yielded in a flat manner. This way, the consumer of + this function doesn't need to know how nodes can be nested and how safety of + a node is determined. + + Parameters + ---------- + node: :class:`skops.io._audit.Node` + The current node to visit. Children are visited recursively. + + show: "all" or "untrusted" or "trusted" + Whether to print all nodes, only untrusted nodes, or only trusted nodes. + + node_name: str (default="root") + The key to the current node. If "key_types" is encountered, it is + skipped. + + level: int (default=0) + The current level of nesting. + + is_last: bool (default=False) + Whether this is the last node among its sibling nodes. + + Yields + ------ + :class:`~NodeInfo`: + A dataclass containing the aforementioned information. + + """ + # key_types is not helpful, as it is artificially added by skops to + # circumvent the fact that json only allows keys to be strings. It is not + # useful to the user and adds a lot of noise, thus skip key_types. + if node_name == "key_types": + if isinstance(node, ListNode) and node.is_safe(): + return + raise ValueError( + "An invalid 'key_types' node was encountered, please report the issue " + "here: https://github.com/skops-dev/skops/issues" + ) + + if isinstance(node, dict): + num_nodes = len(node) + for i, (key, val) in enumerate(node.items(), start=1): + yield from walk_tree( + val, + node_name=key, + level=level, + is_last=i == num_nodes, + ) + return + + if isinstance(node, (list, tuple)): + num_nodes = len(node) + for i, val in enumerate(node, start=1): + yield from walk_tree( + val, + node_name=node_name, + level=level, + is_last=i == num_nodes, + ) + return + + # NO MATCH: RAISE ERROR + if not isinstance(node, Node): + raise TypeError( + f"Cannot deal with {type(node)}, please report the issue here: " + "https://github.com/skops-dev/skops/issues" + ) + + # YIELDING THE ACTUAL NODE INFORMATION HERE + + # Note: calling node.is_safe() on all nodes is potentially wasteful because + # it is already a recursive call, i.e. child nodes will be checked many + # times. A solution to this would be to add caching to its call. + yield NodeInfo( + level=level, + key=node_name, + val=node.format(), + is_self_safe=node.is_self_safe(), + is_safe=node.is_safe(), + is_last=is_last, + ) + + # TYPES WHOSE CHILDREN IT MAKES NO SENSE TO VISIT + # TODO: For better security, we should check the schema if we return early, + # otherwise something nefarious could be hidden inside (however, if there + # is, the node should be marked as unsafe) + if isinstance(node, (NdArrayNode, SparseMatrixNode, FunctionNode, JsonNode)): + return + + yield from walk_tree( + node.children, + node_name=node_name, + level=level + 1, + ) + + +def visualize( + file: Path | str | bytes, + show: Literal["all", "untrusted", "trusted"] = "all", + sink: Callable[..., None] = pretty_print_tree, + **kwargs: Any, +) -> None: + """Visualize the contents of a skops file. + + Shows the schema of a skops file as a tree view. In particular, highlights + untrusted nodes. A node is considered untrusted if at least one of its child + nodes is untrusted. + + Visualizing the tree using the default visualization function requires the + ``rich`` library, which can be installed as: + + python -m pip install rich + + If passing a custom visualization function to ``sink``, ``rich`` is not + required. + + Parameters + ---------- + file: str or pathlib.Path + The file name of the object to be loaded. + + show: "all" or "untrusted" or "trusted" + Whether to print all nodes, only untrusted nodes, or only trusted nodes. + + sink: function (default=:func:`~pretty_print_tree`) + + This function should take at least two arguments, an iterator of + :class:`~NodeInfo` instances and an indicator of what to show. The + ``NodeInfo`` contains the information about the node, namely: + + - the level of nesting (int) + - the key of the node (str) + - the value of the node as a string representation (str) + - the safety of the node and its children + + The ``show`` argument is explained above. Any additional ``kwargs`` + passed to ``visualize`` will also be passed to ``sink``. + + The default sink is :func:`~pretty_print_tree`, which takes these + additional parameters: + + - tag_safe: The tag used to mark trusted nodes (default="", i.e no + tag) + - tag_unsafe: The tag used to mark untrusted nodes + (default="[UNSAFE]") + - use_colors: Whether to colorize the nodes (default=True). Colors + requires the ``rich`` package to be installed. + - color_safe: Color to use for trusted nodes (default="green") + - color_unsafe: Color to use for untrusted nodes (default="red") + - color_child_unsafe: Color to use for nodes that are trusted but + that have untrusted child ndoes (default="yellow") + + So if you don't want to have colored output, just pass + ``use_colors=False`` to ``visualize``. The colors themselves, such + as "red" and "green", refer to the standard colors used by ``rich``. + + """ + if isinstance(file, bytes): + zf = ZipFile(io.BytesIO(file), "r") + else: + zf = ZipFile(file, "r") + + with zf as zip_file: + schema = json.loads(zip_file.read("schema.json")) + tree = get_tree(schema, load_context=LoadContext(src=zip_file)) + + nodes = walk_tree(tree) + # TODO: it would be nice to print html representation if inside a notebook + sink(nodes, show, **kwargs) diff --git a/skops/io/tests/test_audit.py b/skops/io/tests/test_audit.py index 71914b4d..24cdbc94 100644 --- a/skops/io/tests/test_audit.py +++ b/skops/io/tests/test_audit.py @@ -12,8 +12,8 @@ from skops.io import dumps, get_untrusted_types from skops.io._audit import Node, audit_tree, check_type, get_tree, temp_setattr -from skops.io._general import DictNode, dict_get_state -from skops.io._utils import LoadContext, SaveContext, gettype +from skops.io._general import DictNode, JsonNode, ObjectNode, dict_get_state +from skops.io._utils import LoadContext, SaveContext, get_state, gettype class CustomType: @@ -163,3 +163,28 @@ def test_complex_pipeline_untrusted_set(): untrusted = get_untrusted_types(data=dumps(clf)) type_names = [x.split(".")[-1] for x in untrusted] assert type_names == ["sqrt", "square"] + + +def test_format_object_node(): + estimator = LogisticRegression(random_state=0, solver="liblinear") + state = get_state(estimator, SaveContext(None)) + node = ObjectNode(state, LoadContext(None)) + expected = "sklearn.linear_model._logistic.LogisticRegression" + assert node.format() == expected + + +@pytest.mark.parametrize( + "inp, expected", + [ + ("hello", 'json-type("hello")'), + (123, "json-type(123)"), + (0.456, "json-type(0.456)"), + (True, "json-type(true)"), + (False, "json-type(false)"), + (None, "json-type(null)"), + ], +) +def test_format_json_node(inp, expected): + state = get_state(inp, SaveContext(None)) + node = JsonNode(state, LoadContext(None)) + assert node.format() == expected diff --git a/skops/io/tests/test_visualize.py b/skops/io/tests/test_visualize.py new file mode 100644 index 00000000..62aa51b0 --- /dev/null +++ b/skops/io/tests/test_visualize.py @@ -0,0 +1,258 @@ +"""Tests for skops.io.visualize""" + +from unittest.mock import Mock, patch + +import pytest +import sklearn +from sklearn.linear_model import LogisticRegression +from sklearn.pipeline import FeatureUnion, Pipeline +from sklearn.preprocessing import ( + FunctionTransformer, + MinMaxScaler, + PolynomialFeatures, + StandardScaler, +) + +import skops.io as sio + + +class TestVisualizeTree: + @pytest.fixture + def simple(self): + return MinMaxScaler(feature_range=(-555, 123)) + + @pytest.fixture + def pipeline(self): + def unsafe_function(x): + return x + + # fmt: off + pipeline = Pipeline([ + ("features", FeatureUnion([ + ("scaler", StandardScaler()), + ("scaled-poly", Pipeline([ + ("polys", FeatureUnion([ + ("poly1", PolynomialFeatures()), + ("poly2", PolynomialFeatures(degree=3, include_bias=False)) + ])), + ("square-root", FunctionTransformer(unsafe_function)), + ("scale", MinMaxScaler()), + ])), + ])), + ("clf", LogisticRegression(random_state=0, solver="liblinear")), + ]).fit([[0, 1], [2, 3], [4, 5]], [0, 1, 2]) + # fmt: on + return pipeline + + @pytest.mark.parametrize("show", ["all", "trusted", "untrusted"]) + def test_print_simple(self, simple, show, capsys): + file = sio.dumps(simple) + sio.visualize(file, show=show) + + # Output is the same for "all" and "trusted" because all nodes are + # trusted. Colors are not recorded by capsys. + expected = [ + "root: sklearn.preprocessing._data.MinMaxScaler", + "└── attrs: builtins.dict", + " ├── feature_range: builtins.tuple", + " │ ├── content: json-type(-555)", + " │ └── content: json-type(123)", + " ├── copy: json-type(true)", + " ├── clip: json-type(false)", + ' └── _sklearn_version: json-type("{}")'.format(sklearn.__version__), + ] + if show == "untrusted": + # since no untrusted, only show root + expected = expected[:1] + + stdout, _ = capsys.readouterr() + assert stdout.strip() == "\n".join(expected) + + def test_print_pipeline(self, pipeline, capsys): + file = sio.dumps(pipeline) + sio.visualize(file) + + # no point in checking the whole output with > 120 lines + expected_start = [ + "root: sklearn.pipeline.Pipeline", + "└── attrs: builtins.dict", + " ├── steps: builtins.list", + " │ ├── content: builtins.tuple", + ' │ │ ├── content: json-type("features")', + ] + expected_end = [ + " ├── memory: json-type(null)", + " ├── verbose: json-type(false)", + ' └── _sklearn_version: json-type("{}")'.format(sklearn.__version__), + ] + + stdout, _ = capsys.readouterr() + assert stdout.startswith("\n".join(expected_start)) + assert stdout.rstrip().endswith("\n".join(expected_end)) + + def test_unsafe_nodes(self, pipeline): + file = sio.dumps(pipeline) + nodes = [] + + def sink(nodes_iter, *args, **kwargs): + nodes.extend(nodes_iter) + + sio.visualize(file, sink=sink) + nodes_self_unsafe = [node for node in nodes if not node.is_self_safe] + nodes_unsafe = [node for node in nodes if not node.is_safe] + + # there are currently 2 unsafe nodes, a numpy int and the custom + # functions. The former might be considered safe in the future, in which + # case this test needs to be changed. + assert len(nodes_self_unsafe) == 2 + assert nodes_self_unsafe[0].val == "numpy.int64" + assert nodes_self_unsafe[1].val == "test_visualize.function" + + # it's not easy to test the number of indirectly unsafe nodes, because + # it will depend on the nesting structure; we can only be sure that it's + # more than 2, and one of them should be the FunctionTransformer + assert len(nodes_unsafe) > 2 + assert any("FunctionTransformer" in node.val for node in nodes_unsafe) + + @pytest.mark.parametrize( + "kwargs", + [ + {}, + {"use_colors": False}, + {"tag_unsafe": "", "color_unsafe": "blue"}, + ], + ) + def test_custom_print_config_passed_to_sink(self, simple, kwargs): + # check that arguments are passed to sink + def my_sink(nodes_iter, show, **sink_kwargs): + for key, val in kwargs.items(): + assert sink_kwargs[key] == val + + file = sio.dumps(simple) + sio.visualize(file, sink=my_sink, **kwargs) + + def test_custom_tags(self, simple, capsys): + class UnsafeType: + pass + + simple.copy = UnsafeType + + file = sio.dumps(simple) + sio.visualize(file, tag_safe="NICE", tag_unsafe="OHNO") + expected = [ + "root: sklearn.preprocessing._data.MinMaxScaler NICE", + "└── attrs: builtins.dict NICE", + " ├── feature_range: builtins.tuple NICE", + " │ ├── content: json-type(-555) NICE", + " │ └── content: json-type(123) NICE", + " ├── copy: test_visualize.UnsafeType OHNO", + " ├── clip: json-type(false) NICE", + ' └── _sklearn_version: json-type("{}") NICE'.format( + sklearn.__version__ + ), + ] + + stdout, _ = capsys.readouterr() + assert stdout.strip() == "\n".join(expected) + + def test_custom_colors(self, simple): + # test that custom colors are used in node representation, requires rich + # to work + pytest.importorskip("rich") + + class UnsafeType: + pass + + simple.copy = UnsafeType + file = sio.dumps(simple) + + # Colors are not recorded by capsys, so we cannot use it and must mock + # printing + mock_print = Mock() + with patch("rich.print", mock_print): + sio.visualize( + file, + color_safe="black", + color_unsafe="cyan", + color_child_unsafe="orange3", + ) + + mock_print.assert_called() + + calls = mock_print.call_args_list + # The root node is indirectly unsafe through child + assert ( + calls[0].args[0] + == "root: [orange3]sklearn.preprocessing._data.MinMaxScaler" + ) + # 'feature_range' is safe + assert calls[6].args[0] == " feature_range: [black]builtins.tuple" + # 'copy' is unsafe + assert calls[15].args[0] == " copy: [cyan]test_visualize.UnsafeType [UNSAFE]" + + @pytest.mark.usefixtures("rich_not_installed") + def test_no_colors_if_rich_not_installed(self, simple): + # this test is similar to the previous one, except that we test that the + # colors are *not* used if rich is not installed + file = sio.dumps(simple) + + # don't use capsys, because it wouldn't capture the colors, thus need to + # use mock + mock_print = Mock() + with patch("builtins.print", mock_print): + sio.visualize( + file, + color_safe="black", + color_unsafe="cyan", + color_child_unsafe="orange3", + ) + mock_print.assert_called() + + # check that none of the colors are being used + colors = ("black", "cyan", "orange3") + for call in mock_print.call_args_list: + for color in colors: + assert color not in call.args[0] + + def test_no_colors_if_use_colors_false(self, simple): + # this test is similar to the previous one, except that we test that the + # colors are *not* used, even if rich is installed, when passing + # use_colors=False + file = sio.dumps(simple) + + # don't use capsys, because it wouldn't capture the colors, thus need to + # use mock + mock_print = Mock() + with patch("rich.print", mock_print): + sio.visualize( + file, + color_safe="black", + color_unsafe="cyan", + color_child_unsafe="orange3", + use_colors=False, + ) + mock_print.assert_called() + + # check that none of the colors are being used + colors = ("black", "cyan", "orange3") + for call in mock_print.call_args_list: + for color in colors: + assert color not in call.args[0] + + def test_from_file(self, simple, tmp_path, capsys): + f_name = tmp_path / "estimator.skops" + sio.dump(simple, f_name) + sio.visualize(f_name) + + expected = [ + "root: sklearn.preprocessing._data.MinMaxScaler", + "└── attrs: builtins.dict", + " ├── feature_range: builtins.tuple", + " │ ├── content: json-type(-555)", + " │ └── content: json-type(123)", + " ├── copy: json-type(true)", + " ├── clip: json-type(false)", + ' └── _sklearn_version: json-type("{}")'.format(sklearn.__version__), + ] + stdout, _ = capsys.readouterr() + assert stdout.strip() == "\n".join(expected)