From 4e3dba24086d807d477a65de8e6978793667a622 Mon Sep 17 00:00:00 2001 From: Benjamin Bossan Date: Fri, 10 Mar 2023 16:52:52 +0100 Subject: [PATCH 01/12] [WIP] Function to visualize skops files Resolves #301 This is not finished, just a basis for discussion. --- skops/io/_scipy.py | 4 +- skops/io/_visualize.py | 212 +++++++++++++++++++++++++++++++ skops/io/tests/test_visualize.py | 91 +++++++++++++ 3 files changed, 305 insertions(+), 2 deletions(-) create mode 100644 skops/io/_visualize.py create mode 100644 skops/io/tests/test_visualize.py diff --git a/skops/io/_scipy.py b/skops/io/_scipy.py index cbf5c1d1..1c494a49 100644 --- a/skops/io/_scipy.py +++ b/skops/io/_scipy.py @@ -40,9 +40,9 @@ def __init__( trusted: bool | Sequence[str] = False, ) -> None: super().__init__(state, load_context, trusted) - type = state["type"] + self.type = state["type"] self.trusted = self._get_trusted(trusted, [spmatrix]) - if type != "scipy": + if self.type != "scipy": raise TypeError( f"Cannot load object of type {self.module_name}.{self.class_name}" ) diff --git a/skops/io/_visualize.py b/skops/io/_visualize.py new file mode 100644 index 00000000..d949511e --- /dev/null +++ b/skops/io/_visualize.py @@ -0,0 +1,212 @@ +from __future__ import annotations + +import json +from functools import singledispatch +from pathlib import Path +from typing import Callable, Literal, Sequence +from zipfile import ZipFile + +from ._audit import Node, get_tree +from ._general import FunctionNode, JsonNode +from ._numpy import NdArrayNode +from ._scipy import SparseMatrixNode +from ._utils import LoadContext + +PrintFn = Callable[ + [Node, str, int, bool | Sequence[str], Literal["all", "untrusted", "trusted"]], None +] + + +def _check_should_print( + node: Node, + trusted: bool | Sequence[str], + show: Literal["all", "untrusted", "trusted"], +) -> tuple[bool, bool]: + if trusted is True: + is_safe = True + elif trusted is False: + is_safe = not node.get_unsafe_set() + else: + is_safe = not (node.get_unsafe_set() - set(trusted)) + + should_print = ( + (show == "all") + or ((show == "untrusted") and (not is_safe)) + or ((show == "trusted") and is_safe) + ) + return should_print, is_safe + + +def _print_node( + node: Node, + name: str, + key: str, + level: int, + trusted: bool | Sequence[str], + show: Literal["all", "untrusted", "trusted"], +): + should_print, is_safe = _check_should_print(node, trusted=trusted, show=show) + if not should_print: + return + + prefix = "" + if level > 0: + prefix += "├-" + if level > 1: + prefix += "--" * (level - 1) + + text = f"{prefix}{key}: {name}{'' if is_safe else ' [UNSAFE]'}" + print(text) + + +# use singledispatch so that we can register specialized visualization functions +@singledispatch +def print_node( + node: Node, + key: str, + level: int, + trusted: bool | Sequence[str], + show: Literal["all", "untrusted", "trusted"], +): + name = f"{node.module_name}.{node.class_name}" + _print_node(node, name=name, key=key, level=level, trusted=trusted, show=show) + + +@print_node.register +def _print_function_node( + node: FunctionNode, + key: str, + level: int, + trusted: bool | Sequence[str], + show: Literal["all", "untrusted", "trusted"], +): + # if a FunctionNode, children are not visited, but safety should still be checked + child = node.children["content"] + fn_name = f"{child['module_path']}.{child['function']}" + name = f"{node.module_name}.{node.class_name} => {fn_name}" + _print_node(node, name=name, key=key, level=level, trusted=trusted, show=show) + + +@print_node.register +def _print_json_node( + node: JsonNode, + key: str, + level: int, + trusted: bool | Sequence[str], + show: Literal["all", "untrusted", "trusted"], +): + name = f"json-type({node.content})" + return _print_node( + node, name=name, key=key, level=level, trusted=trusted, show=show + ) + + +def _visualize_tree( + node: Node | dict[str, Node] | Sequence[Node], + trusted: bool | Sequence[str] = False, + show: Literal["all", "untrusted", "trusted"] = "all", + node_name: str = "root", + level: int = 0, + sink: PrintFn = print_node, +) -> None: + # helper function to pretty-print the nodes + if node_name == "key_types": + # _check_key_types_schema(node) + return + + # COMPOSITE TYPES: CHECK ALL ITEMS + if isinstance(node, dict): + for key, val in node.items(): + _visualize_tree( + val, node_name=key, level=level, trusted=trusted, show=show, sink=sink + ) + return + + if isinstance(node, (list, tuple)): + for val in node: + _visualize_tree( + val, + node_name=node_name, + level=level, + trusted=trusted, + show=show, + sink=sink, + ) + return + + # NO MATCH: RAISE ERROR + if not isinstance(node, Node): + raise TypeError(f"{type(node)}") + + # TRIGGER SIDE-EFFECT + sink(node, node_name, level, trusted, show) + + # TYPES WHOSE CHILDREN IT MAKES NO SENSE TO VISIT + if isinstance(node, (NdArrayNode, SparseMatrixNode)) and (node.type != "json"): + # _check_array_schema(node) + return + + if isinstance(node, (NdArrayNode, SparseMatrixNode)) and (node.type == "json"): + # _check_array_json_schema(node) + return + + if isinstance(node, FunctionNode): + # _check_function_schema(node) + pass + return + + if isinstance(node, JsonNode): + # _check_json_schema(node) + pass + + # RECURSE + _visualize_tree( + node.children, + node_name=node_name, + level=level + 1, + trusted=trusted, + show=show, + sink=sink, + ) + + +def visualize_tree( + file: Path | str, # TODO: from bytes + trusted: bool | Sequence[str] = False, + show: Literal["all", "untrusted", "trusted"] = "all", + sink: PrintFn = print_node, +) -> None: + """Visualize the contents of a skops file. + + Shows the schema of a skops file as a tree view. In particular, highlights + untrusted nodes. A node is considered untrusted if at least one of its child + nodes is untrusted. + + Parameters + ---------- + file: str or pathlib.Path + The file name of the object to be loaded. + + trusted: bool, or list of str, default=False + If ``True``, the object will be loaded without any security checks. If + ``False``, the object will be loaded only if there are only trusted + objects in the dumped file. If a list of strings, the object will be + loaded only if there are only trusted objects and objects of types + listed in ``trusted`` are in the dumped file. + + show: "all" or "untrusted" or "trusted" + Whether to print all nodes, only untrusted nodes, or only trusted nodes. + + sink: function + + Function used to print the schema. By default, this generates a tree + view and prints it to stdout. If you want to do something else with the + output, e.g. log it to a file, pass a function here that does that. The + signature of this function should be ``Callable[[Node, str, int, bool | + Sequence[str], Literal["all", "untrusted", "trusted"]], None]``. + + """ + with ZipFile(file, "r") as zip_file: + schema = json.loads(zip_file.read("schema.json")) + tree = get_tree(schema, load_context=LoadContext(src=zip_file)) + _visualize_tree(tree, trusted=trusted, show=show, sink=sink) diff --git a/skops/io/tests/test_visualize.py b/skops/io/tests/test_visualize.py new file mode 100644 index 00000000..b0cf0736 --- /dev/null +++ b/skops/io/tests/test_visualize.py @@ -0,0 +1,91 @@ +import numpy as np +import pytest +from sklearn.linear_model import LogisticRegression +from sklearn.pipeline import FeatureUnion, Pipeline +from sklearn.preprocessing import ( + FunctionTransformer, + MinMaxScaler, + PolynomialFeatures, + StandardScaler, +) + +import skops.io as sio +from skops.io._visualize import _check_should_print, visualize_tree + + +class TestVisualizeTree: + @pytest.fixture + def simple(self): + return MinMaxScaler(feature_range=(-555, 123)) + + @pytest.fixture + def simple_file(self, simple, tmp_path): + f_name = tmp_path / "estimator.skops" + sio.dump(simple, f_name) + return f_name + + @pytest.fixture + def pipeline(self): + # fmt: off + pipeline = Pipeline([ + ("features", FeatureUnion([ + ("scaler", StandardScaler()), + ("scaled-poly", Pipeline([ + ("polys", FeatureUnion([ + ("poly1", PolynomialFeatures()), + ("poly2", PolynomialFeatures(degree=3, include_bias=False)) + ])), + ("square-root", FunctionTransformer(np.sqrt)), + ("scale", MinMaxScaler()), + ])), + ])), + ("clf", LogisticRegression(random_state=0, solver="liblinear")), + ]).fit([[0, 1], [2, 3], [4, 5]], [0, 1, 2]) + # fmt: on + return pipeline + + @pytest.fixture + def pipeline_file(self, pipeline, tmp_path): + f_name = tmp_path / "estimator.skops" + sio.dump(pipeline, f_name) + return f_name + + @pytest.fixture + def side_effect_and_contents(self): + # This side effect collects the contents of what would normally be + # printed. That way, we can test more precisely than just capturing + # stdout and inspecting strings. + contents = [] + + def side_effect(node, key, level, trusted, show): + should_print, _ = _check_should_print(node, trusted, show) + if should_print: + contents.append((node, key, level, trusted, show)) + + return side_effect, contents + + @pytest.mark.parametrize("show", ["all", "trusted", "untrusted"]) + def test_print_simple(self, simple_file, show): + visualize_tree(simple_file, show=show) + + @pytest.mark.parametrize( + "show_tell", [("all", 8), ("trusted", 8), ("untrusted", 0)] + ) + def test_inspect_simple(self, simple_file, side_effect_and_contents, show_tell): + side_effect, contents = side_effect_and_contents + show, expected_elements = show_tell + visualize_tree(simple_file, sink=side_effect, show=show) + assert len(contents) == expected_elements + + @pytest.mark.parametrize("show", ["all", "trusted", "untrusted"]) + def test_print_pipeline(self, pipeline_file, show): + visualize_tree(pipeline_file, show=show) + + @pytest.mark.parametrize( + "show_tell", [("all", 129), ("trusted", 110), ("untrusted", 19)] + ) + def test_inspect_pipeline(self, pipeline_file, side_effect_and_contents, show_tell): + side_effect, contents = side_effect_and_contents + show, expected_elements = show_tell + visualize_tree(pipeline_file, sink=side_effect, show=show) + assert len(contents) == expected_elements From 5a955e549447597b181ac1fd318cbae56976beaf Mon Sep 17 00:00:00 2001 From: Benjamin Bossan Date: Tue, 14 Mar 2023 14:44:04 +0100 Subject: [PATCH 02/12] Only add [UNSAFE] tag if node itself is unsafe Don't add it to parent nodes. --- skops/io/_visualize.py | 20 ++++++++++++-------- skops/io/tests/test_visualize.py | 2 +- 2 files changed, 13 insertions(+), 9 deletions(-) diff --git a/skops/io/_visualize.py b/skops/io/_visualize.py index d949511e..ed33f3d0 100644 --- a/skops/io/_visualize.py +++ b/skops/io/_visualize.py @@ -21,20 +21,23 @@ def _check_should_print( node: Node, trusted: bool | Sequence[str], show: Literal["all", "untrusted", "trusted"], -) -> tuple[bool, bool]: +) -> bool: + # Note: this is very inefficient, because get_unsafe_set will be called many + # times on the same node (since parents recursively call children) but maybe + # that's acceptable for this context. If not, caching could be an option. if trusted is True: - is_safe = True + node_and_children_are_safe = True elif trusted is False: - is_safe = not node.get_unsafe_set() + node_and_children_are_safe = not node.get_unsafe_set() else: - is_safe = not (node.get_unsafe_set() - set(trusted)) + node_and_children_are_safe = not (node.get_unsafe_set() - set(trusted)) should_print = ( (show == "all") - or ((show == "untrusted") and (not is_safe)) - or ((show == "trusted") and is_safe) + or ((show == "untrusted") and (not node_and_children_are_safe)) + or ((show == "trusted") and node_and_children_are_safe) ) - return should_print, is_safe + return should_print def _print_node( @@ -45,7 +48,8 @@ def _print_node( trusted: bool | Sequence[str], show: Literal["all", "untrusted", "trusted"], ): - should_print, is_safe = _check_should_print(node, trusted=trusted, show=show) + is_safe = node.is_self_safe() + should_print = _check_should_print(node, trusted=trusted, show=show) if not should_print: return diff --git a/skops/io/tests/test_visualize.py b/skops/io/tests/test_visualize.py index b0cf0736..e7d27cad 100644 --- a/skops/io/tests/test_visualize.py +++ b/skops/io/tests/test_visualize.py @@ -58,7 +58,7 @@ def side_effect_and_contents(self): contents = [] def side_effect(node, key, level, trusted, show): - should_print, _ = _check_should_print(node, trusted, show) + should_print = _check_should_print(node, trusted, show) if should_print: contents.append((node, key, level, trusted, show)) From e753d25b4b235b17656c0fbb3ff32382f5c3122d Mon Sep 17 00:00:00 2001 From: Benjamin Bossan Date: Tue, 14 Mar 2023 16:07:09 +0100 Subject: [PATCH 03/12] Add print config, including colors --- skops/io/_visualize.py | 195 ++++++++++++++++++++++++++++--- skops/io/tests/test_visualize.py | 14 ++- 2 files changed, 188 insertions(+), 21 deletions(-) diff --git a/skops/io/_visualize.py b/skops/io/_visualize.py index ed33f3d0..e9e0b64d 100644 --- a/skops/io/_visualize.py +++ b/skops/io/_visualize.py @@ -17,11 +17,44 @@ ] +class PrintConfig: + # fmt: off + print_fn = print + template = "{prefix}{key}: {name}{tag}" + + tag_safe = "" # noqa: E222 + tag_unsafe = " [UNSAFE]" + + line_start = "├─" + line = "──" # noqa: E222 + + use_colors = True + color_safe = '\033[32m' # green # noqa: E222 + color_unsafe = '\033[31m' # red # noqa: E222 + color_child_unsafe = '\033[33m' # yellow + color_end = '\033[0m' # noqa: E222 + # fmt: on + + +print_config = PrintConfig() + + def _check_should_print( node: Node, - trusted: bool | Sequence[str], + node_is_safe: bool, + node_and_children_are_safe: bool, show: Literal["all", "untrusted", "trusted"], ) -> bool: + if show == "all": + should_print = True + elif show == "untrusted": + should_print = not node_and_children_are_safe + else: # only trusted + should_print = node_is_safe + return should_print + + +def _check_node_and_children_safe(node: Node, trusted: bool | Sequence[str]) -> bool: # Note: this is very inefficient, because get_unsafe_set will be called many # times on the same node (since parents recursively call children) but maybe # that's acceptable for this context. If not, caching could be an option. @@ -31,13 +64,7 @@ def _check_should_print( node_and_children_are_safe = not node.get_unsafe_set() else: node_and_children_are_safe = not (node.get_unsafe_set() - set(trusted)) - - should_print = ( - (show == "all") - or ((show == "untrusted") and (not node_and_children_are_safe)) - or ((show == "trusted") and node_and_children_are_safe) - ) - return should_print + return node_and_children_are_safe def _print_node( @@ -47,20 +74,43 @@ def _print_node( level: int, trusted: bool | Sequence[str], show: Literal["all", "untrusted", "trusted"], + config: PrintConfig = print_config, ): - is_safe = node.is_self_safe() - should_print = _check_should_print(node, trusted=trusted, show=show) + # determine if node itself and if its children are safe + node_is_safe = node.is_self_safe() + node_and_children_are_safe = _check_node_and_children_safe(node, trusted) + + should_print = _check_should_print( + node, + node_is_safe=node_is_safe, + node_and_children_are_safe=node_and_children_are_safe, + show=show, + ) if not should_print: return - prefix = "" - if level > 0: - prefix += "├-" - if level > 1: - prefix += "--" * (level - 1) + tag = config.tag_safe if node_is_safe else config.tag_unsafe + + prefix = " " * (level - 1) + config.line_start - text = f"{prefix}{key}: {name}{'' if is_safe else ' [UNSAFE]'}" - print(text) + # prefix = "" + # if level > 0: + # prefix += config.line_start + # if level > 1: + # prefix += config.line * (level - 1) + + template = config.template + if config.use_colors: + if node_and_children_are_safe: + color = config.color_safe + elif node_is_safe: + color = config.color_child_unsafe + else: + color = config.color_unsafe + name = f"{color}{name}{config.color_end}" + + text = template.format(prefix=prefix, key=key, name=name, tag=tag) + config.print_fn(text) # use singledispatch so that we can register specialized visualization functions @@ -156,7 +206,6 @@ def _visualize_tree( if isinstance(node, FunctionNode): # _check_function_schema(node) - pass return if isinstance(node, JsonNode): @@ -214,3 +263,113 @@ def visualize_tree( schema = json.loads(zip_file.read("schema.json")) tree = get_tree(schema, load_context=LoadContext(src=zip_file)) _visualize_tree(tree, trusted=trusted, show=show, sink=sink) + + +# def _walk_tree( +# node: Node | dict[str, Node] | Sequence[Node], +# trusted: bool | Sequence[str] = False, +# show: Literal["all", "untrusted", "trusted"] = "all", +# node_name: str = "root", +# level: int = 0, +# sink: PrintFn = print_node, +# ) -> Iterator[Any]: # TODO +# # helper function to pretty-print the nodes +# if node_name == "key_types": +# # _check_key_types_schema(node) +# return + +# # COMPOSITE TYPES: CHECK ALL ITEMS +# if isinstance(node, dict): +# for key, val in node.items(): +# yield from _walk_tree( +# val, node_name=key, level=level, trusted=trusted, show=show, sink=sink +# ) +# return + +# if isinstance(node, (list, tuple)): +# for val in node: +# yield from _walk_tree( +# val, +# node_name=node_name, +# level=level, +# trusted=trusted, +# show=show, +# sink=sink, +# ) +# return + +# # NO MATCH: RAISE ERROR +# if not isinstance(node, Node): +# raise TypeError(f"{type(node)}") + +# # TRIGGER SIDE-EFFECT +# sink(node, node_name, level, trusted, show) + +# # TYPES WHOSE CHILDREN IT MAKES NO SENSE TO VISIT +# if isinstance(node, (NdArrayNode, SparseMatrixNode)) and (node.type != "json"): +# # _check_array_schema(node) +# return + +# if isinstance(node, (NdArrayNode, SparseMatrixNode)) and (node.type == "json"): +# # _check_array_json_schema(node) +# return + +# if isinstance(node, FunctionNode): +# # _check_function_schema(node) +# return + +# if isinstance(node, JsonNode): +# # _check_json_schema(node) +# pass + +# # RECURSE +# yield from _walk_tree( +# node.children, +# node_name=node_name, +# level=level + 1, +# trusted=trusted, +# show=show, +# sink=sink, +# ) + + +# def visualize_tree( +# file: Path | str, # TODO: from bytes +# trusted: bool | Sequence[str] = False, +# show: Literal["all", "untrusted", "trusted"] = "all", +# sink: PrintFn = print_node, +# ) -> None: +# """Visualize the contents of a skops file. + +# Shows the schema of a skops file as a tree view. In particular, highlights +# untrusted nodes. A node is considered untrusted if at least one of its child +# nodes is untrusted. + +# Parameters +# ---------- +# file: str or pathlib.Path +# The file name of the object to be loaded. + +# trusted: bool, or list of str, default=False +# If ``True``, the object will be loaded without any security checks. If +# ``False``, the object will be loaded only if there are only trusted +# objects in the dumped file. If a list of strings, the object will be +# loaded only if there are only trusted objects and objects of types +# listed in ``trusted`` are in the dumped file. + +# show: "all" or "untrusted" or "trusted" +# Whether to print all nodes, only untrusted nodes, or only trusted nodes. + +# sink: function + +# Function used to print the schema. By default, this generates a tree +# view and prints it to stdout. If you want to do something else with the +# output, e.g. log it to a file, pass a function here that does that. The +# signature of this function should be ``Callable[[Node, str, int, bool | +# Sequence[str], Literal["all", "untrusted", "trusted"]], None]``. + +# """ +# with ZipFile(file, "r") as zip_file: +# schema = json.loads(zip_file.read("schema.json")) +# tree = get_tree(schema, load_context=LoadContext(src=zip_file)) +# _visualize_tree(tree, trusted=trusted, show=show, sink=sink) diff --git a/skops/io/tests/test_visualize.py b/skops/io/tests/test_visualize.py index e7d27cad..34744548 100644 --- a/skops/io/tests/test_visualize.py +++ b/skops/io/tests/test_visualize.py @@ -10,7 +10,11 @@ ) import skops.io as sio -from skops.io._visualize import _check_should_print, visualize_tree +from skops.io._visualize import ( + _check_node_and_children_safe, + _check_should_print, + visualize_tree, +) class TestVisualizeTree: @@ -58,7 +62,11 @@ def side_effect_and_contents(self): contents = [] def side_effect(node, key, level, trusted, show): - should_print = _check_should_print(node, trusted, show) + node_is_safe = node.is_self_safe() + node_and_children_are_safe = _check_node_and_children_safe(node, trusted) + should_print = _check_should_print( + node, node_is_safe, node_and_children_are_safe, show + ) if should_print: contents.append((node, key, level, trusted, show)) @@ -82,7 +90,7 @@ def test_print_pipeline(self, pipeline_file, show): visualize_tree(pipeline_file, show=show) @pytest.mark.parametrize( - "show_tell", [("all", 129), ("trusted", 110), ("untrusted", 19)] + "show_tell", [("all", 129), ("trusted", 127), ("untrusted", 19)] ) def test_inspect_pipeline(self, pipeline_file, side_effect_and_contents, show_tell): side_effect, contents = side_effect_and_contents From 5960d8195c89e87040513f4d9a3f5895fc20713f Mon Sep 17 00:00:00 2001 From: Benjamin Bossan Date: Tue, 14 Mar 2023 17:30:11 +0100 Subject: [PATCH 04/12] Major refactor: better separation of concerns Code should be much more straightforward now. Moreover, it is refactored so that the sink function now gets an iterator of all nodes, not just of a single node. This is better because printing of nodes can now happen in context. Just as an example, a node could now be printed differently if it is the first node of its level. --- skops/io/_visualize.py | 333 ++++++++++--------------------- skops/io/tests/test_visualize.py | 42 +--- 2 files changed, 121 insertions(+), 254 deletions(-) diff --git a/skops/io/_visualize.py b/skops/io/_visualize.py index e9e0b64d..76c35c3f 100644 --- a/skops/io/_visualize.py +++ b/skops/io/_visualize.py @@ -1,9 +1,10 @@ from __future__ import annotations import json +from dataclasses import dataclass from functools import singledispatch from pathlib import Path -from typing import Callable, Literal, Sequence +from typing import Callable, Iterator, Literal, Sequence from zipfile import ZipFile from ._audit import Node, get_tree @@ -12,34 +13,51 @@ from ._scipy import SparseMatrixNode from ._utils import LoadContext -PrintFn = Callable[ - [Node, str, int, bool | Sequence[str], Literal["all", "untrusted", "trusted"]], None -] - +@dataclass class PrintConfig: # fmt: off - print_fn = print - template = "{prefix}{key}: {name}{tag}" - - tag_safe = "" # noqa: E222 - tag_unsafe = " [UNSAFE]" + tag_safe: str = "" # noqa: E222 + tag_unsafe: str = " [UNSAFE]" - line_start = "├─" - line = "──" # noqa: E222 + line_start: str = "├─" + line: str = "──" # noqa: E222 - use_colors = True - color_safe = '\033[32m' # green # noqa: E222 - color_unsafe = '\033[31m' # red # noqa: E222 - color_child_unsafe = '\033[33m' # yellow - color_end = '\033[0m' # noqa: E222 + use_colors: bool = True + color_safe: str = '\033[32m' # green # noqa: E222 + color_unsafe: str = '\033[31m' # red # noqa: E222 + color_child_unsafe: str = '\033[33m' # yellow + color_end: str = '\033[0m' # noqa: E222 # fmt: on print_config = PrintConfig() -def _check_should_print( +@dataclass +class FormattedNode: + level: int + key: str # the key to the node + val: str # the value of the node + visible: bool # whether it should be shown + + +def pretty_print_tree( + formatted_nodes: Iterator[FormattedNode], config: PrintConfig +) -> None: + # TODO: the "tree" lines could be made prettier since all nodes are known + # here + for formatted_node in formatted_nodes: + if not formatted_node.visible: + continue + + line = print_config.line_start + line += (formatted_node.level - 1) * print_config.line + line += f"{formatted_node.key}: {formatted_node.val}" + print(line) + + +def _check_visibility( node: Node, node_is_safe: bool, node_and_children_are_safe: bool, @@ -67,102 +85,41 @@ def _check_node_and_children_safe(node: Node, trusted: bool | Sequence[str]) -> return node_and_children_are_safe -def _print_node( - node: Node, - name: str, - key: str, - level: int, - trusted: bool | Sequence[str], - show: Literal["all", "untrusted", "trusted"], - config: PrintConfig = print_config, -): - # determine if node itself and if its children are safe - node_is_safe = node.is_self_safe() - node_and_children_are_safe = _check_node_and_children_safe(node, trusted) - - should_print = _check_should_print( - node, - node_is_safe=node_is_safe, - node_and_children_are_safe=node_and_children_are_safe, - show=show, - ) - if not should_print: - return - - tag = config.tag_safe if node_is_safe else config.tag_unsafe - - prefix = " " * (level - 1) + config.line_start - - # prefix = "" - # if level > 0: - # prefix += config.line_start - # if level > 1: - # prefix += config.line * (level - 1) - - template = config.template - if config.use_colors: - if node_and_children_are_safe: - color = config.color_safe - elif node_is_safe: - color = config.color_child_unsafe - else: - color = config.color_unsafe - name = f"{color}{name}{config.color_end}" - - text = template.format(prefix=prefix, key=key, name=name, tag=tag) - config.print_fn(text) - - # use singledispatch so that we can register specialized visualization functions @singledispatch -def print_node( - node: Node, - key: str, - level: int, - trusted: bool | Sequence[str], - show: Literal["all", "untrusted", "trusted"], -): - name = f"{node.module_name}.{node.class_name}" - _print_node(node, name=name, key=key, level=level, trusted=trusted, show=show) +def format_node(node: Node) -> str: + """Format the name of the node. + By default, this is just the fully qualified name of the class, e.g. + ``"sklearn.preprocessing._data.MinMaxScaler"``. But for some types of nodes, + having a more specific output is desirable. These node types can be + registered with this function. -@print_node.register -def _print_function_node( - node: FunctionNode, - key: str, - level: int, - trusted: bool | Sequence[str], - show: Literal["all", "untrusted", "trusted"], -): + """ + return f"{node.module_name}.{node.class_name}" + + +@format_node.register +def _format_function_node(node: FunctionNode) -> str: # if a FunctionNode, children are not visited, but safety should still be checked child = node.children["content"] fn_name = f"{child['module_path']}.{child['function']}" - name = f"{node.module_name}.{node.class_name} => {fn_name}" - _print_node(node, name=name, key=key, level=level, trusted=trusted, show=show) + return f"{node.module_name}.{node.class_name} => {fn_name}" -@print_node.register -def _print_json_node( - node: JsonNode, - key: str, - level: int, - trusted: bool | Sequence[str], - show: Literal["all", "untrusted", "trusted"], -): - name = f"json-type({node.content})" - return _print_node( - node, name=name, key=key, level=level, trusted=trusted, show=show - ) +@format_node.register +def _format_json_node(node: JsonNode) -> str: + return f"json-type({node.content})" -def _visualize_tree( +def walk_tree( node: Node | dict[str, Node] | Sequence[Node], trusted: bool | Sequence[str] = False, show: Literal["all", "untrusted", "trusted"] = "all", node_name: str = "root", level: int = 0, - sink: PrintFn = print_node, -) -> None: + config: PrintConfig = print_config, +) -> Iterator[FormattedNode]: # helper function to pretty-print the nodes if node_name == "key_types": # _check_key_types_schema(node) @@ -171,20 +128,25 @@ def _visualize_tree( # COMPOSITE TYPES: CHECK ALL ITEMS if isinstance(node, dict): for key, val in node.items(): - _visualize_tree( - val, node_name=key, level=level, trusted=trusted, show=show, sink=sink + yield from walk_tree( + val, + node_name=key, + level=level, + trusted=trusted, + show=show, + config=config, ) return if isinstance(node, (list, tuple)): for val in node: - _visualize_tree( + yield from walk_tree( val, node_name=node_name, level=level, trusted=trusted, show=show, - sink=sink, + config=config, ) return @@ -192,8 +154,31 @@ def _visualize_tree( if not isinstance(node, Node): raise TypeError(f"{type(node)}") - # TRIGGER SIDE-EFFECT - sink(node, node_name, level, trusted, show) + # THE ACTUAL FORMATTING HAPPENS HERE + node_is_safe = node.is_self_safe() + node_and_children_are_safe = _check_node_and_children_safe(node, trusted) + visible = _check_visibility( + node, + node_is_safe=node_is_safe, + node_and_children_are_safe=node_and_children_are_safe, + show=show, + ) + + node_val = format_node(node) + tag = config.tag_safe if node_is_safe else config.tag_unsafe + if tag: + node_val += f" {tag}" + + if config.use_colors: + if node_and_children_are_safe: + color = config.color_safe + elif node_is_safe: + color = config.color_child_unsafe + else: + color = config.color_unsafe + node_val = f"{color}{node_val}{config.color_end}" + + yield FormattedNode(level=level, key=node_name, val=node_val, visible=visible) # TYPES WHOSE CHILDREN IT MAKES NO SENSE TO VISIT if isinstance(node, (NdArrayNode, SparseMatrixNode)) and (node.type != "json"): @@ -213,13 +198,13 @@ def _visualize_tree( pass # RECURSE - _visualize_tree( + yield from walk_tree( node.children, node_name=node_name, level=level + 1, trusted=trusted, show=show, - sink=sink, + config=config, ) @@ -227,7 +212,8 @@ def visualize_tree( file: Path | str, # TODO: from bytes trusted: bool | Sequence[str] = False, show: Literal["all", "untrusted", "trusted"] = "all", - sink: PrintFn = print_node, + sink: Callable[[Iterator[FormattedNode], PrintConfig], None] = pretty_print_tree, + print_config: PrintConfig = print_config, ) -> None: """Visualize the contents of a skops file. @@ -251,125 +237,26 @@ def visualize_tree( Whether to print all nodes, only untrusted nodes, or only trusted nodes. sink: function + This function should take two arguments, an iterator of + ``FormattedNode`` and a ``PrintConfig``. The ``FormattedNode`` contains + the information about the node, namely: + + - the level of nesting (int) + - the key of the node (str) + - the value of the node as a string representation (str) + - the visibility of the node, depending on the ``show`` argument (bool) - Function used to print the schema. By default, this generates a tree - view and prints it to stdout. If you want to do something else with the - output, e.g. log it to a file, pass a function here that does that. The - signature of this function should be ``Callable[[Node, str, int, bool | - Sequence[str], Literal["all", "untrusted", "trusted"]], None]``. + The second argument is the print config (see description of next argument). + + print_config: :class:`~PrintConfig` + The ``PrintConfig`` is a simple object with attributes that determine + how the node should be visualized, e.g. the ``use_colors`` attribute + determines if colors should be used. """ with ZipFile(file, "r") as zip_file: schema = json.loads(zip_file.read("schema.json")) tree = get_tree(schema, load_context=LoadContext(src=zip_file)) - _visualize_tree(tree, trusted=trusted, show=show, sink=sink) - - -# def _walk_tree( -# node: Node | dict[str, Node] | Sequence[Node], -# trusted: bool | Sequence[str] = False, -# show: Literal["all", "untrusted", "trusted"] = "all", -# node_name: str = "root", -# level: int = 0, -# sink: PrintFn = print_node, -# ) -> Iterator[Any]: # TODO -# # helper function to pretty-print the nodes -# if node_name == "key_types": -# # _check_key_types_schema(node) -# return - -# # COMPOSITE TYPES: CHECK ALL ITEMS -# if isinstance(node, dict): -# for key, val in node.items(): -# yield from _walk_tree( -# val, node_name=key, level=level, trusted=trusted, show=show, sink=sink -# ) -# return - -# if isinstance(node, (list, tuple)): -# for val in node: -# yield from _walk_tree( -# val, -# node_name=node_name, -# level=level, -# trusted=trusted, -# show=show, -# sink=sink, -# ) -# return - -# # NO MATCH: RAISE ERROR -# if not isinstance(node, Node): -# raise TypeError(f"{type(node)}") - -# # TRIGGER SIDE-EFFECT -# sink(node, node_name, level, trusted, show) - -# # TYPES WHOSE CHILDREN IT MAKES NO SENSE TO VISIT -# if isinstance(node, (NdArrayNode, SparseMatrixNode)) and (node.type != "json"): -# # _check_array_schema(node) -# return - -# if isinstance(node, (NdArrayNode, SparseMatrixNode)) and (node.type == "json"): -# # _check_array_json_schema(node) -# return - -# if isinstance(node, FunctionNode): -# # _check_function_schema(node) -# return - -# if isinstance(node, JsonNode): -# # _check_json_schema(node) -# pass - -# # RECURSE -# yield from _walk_tree( -# node.children, -# node_name=node_name, -# level=level + 1, -# trusted=trusted, -# show=show, -# sink=sink, -# ) - - -# def visualize_tree( -# file: Path | str, # TODO: from bytes -# trusted: bool | Sequence[str] = False, -# show: Literal["all", "untrusted", "trusted"] = "all", -# sink: PrintFn = print_node, -# ) -> None: -# """Visualize the contents of a skops file. - -# Shows the schema of a skops file as a tree view. In particular, highlights -# untrusted nodes. A node is considered untrusted if at least one of its child -# nodes is untrusted. - -# Parameters -# ---------- -# file: str or pathlib.Path -# The file name of the object to be loaded. - -# trusted: bool, or list of str, default=False -# If ``True``, the object will be loaded without any security checks. If -# ``False``, the object will be loaded only if there are only trusted -# objects in the dumped file. If a list of strings, the object will be -# loaded only if there are only trusted objects and objects of types -# listed in ``trusted`` are in the dumped file. - -# show: "all" or "untrusted" or "trusted" -# Whether to print all nodes, only untrusted nodes, or only trusted nodes. - -# sink: function - -# Function used to print the schema. By default, this generates a tree -# view and prints it to stdout. If you want to do something else with the -# output, e.g. log it to a file, pass a function here that does that. The -# signature of this function should be ``Callable[[Node, str, int, bool | -# Sequence[str], Literal["all", "untrusted", "trusted"]], None]``. - -# """ -# with ZipFile(file, "r") as zip_file: -# schema = json.loads(zip_file.read("schema.json")) -# tree = get_tree(schema, load_context=LoadContext(src=zip_file)) -# _visualize_tree(tree, trusted=trusted, show=show, sink=sink) + + nodes = walk_tree(tree, trusted=trusted, show=show, config=print_config) + sink(nodes, print_config) diff --git a/skops/io/tests/test_visualize.py b/skops/io/tests/test_visualize.py index 34744548..39b7104f 100644 --- a/skops/io/tests/test_visualize.py +++ b/skops/io/tests/test_visualize.py @@ -10,11 +10,7 @@ ) import skops.io as sio -from skops.io._visualize import ( - _check_node_and_children_safe, - _check_should_print, - visualize_tree, -) +from skops.io._visualize import visualize_tree class TestVisualizeTree: @@ -54,24 +50,6 @@ def pipeline_file(self, pipeline, tmp_path): sio.dump(pipeline, f_name) return f_name - @pytest.fixture - def side_effect_and_contents(self): - # This side effect collects the contents of what would normally be - # printed. That way, we can test more precisely than just capturing - # stdout and inspecting strings. - contents = [] - - def side_effect(node, key, level, trusted, show): - node_is_safe = node.is_self_safe() - node_and_children_are_safe = _check_node_and_children_safe(node, trusted) - should_print = _check_should_print( - node, node_is_safe, node_and_children_are_safe, show - ) - if should_print: - contents.append((node, key, level, trusted, show)) - - return side_effect, contents - @pytest.mark.parametrize("show", ["all", "trusted", "untrusted"]) def test_print_simple(self, simple_file, show): visualize_tree(simple_file, show=show) @@ -79,11 +57,11 @@ def test_print_simple(self, simple_file, show): @pytest.mark.parametrize( "show_tell", [("all", 8), ("trusted", 8), ("untrusted", 0)] ) - def test_inspect_simple(self, simple_file, side_effect_and_contents, show_tell): - side_effect, contents = side_effect_and_contents + def test_inspect_simple(self, simple_file, show_tell): + nodes = [] show, expected_elements = show_tell - visualize_tree(simple_file, sink=side_effect, show=show) - assert len(contents) == expected_elements + visualize_tree(simple_file, sink=lambda n, _: nodes.extend(list(n)), show=show) + assert len([node for node in nodes if node.visible]) == expected_elements @pytest.mark.parametrize("show", ["all", "trusted", "untrusted"]) def test_print_pipeline(self, pipeline_file, show): @@ -92,8 +70,10 @@ def test_print_pipeline(self, pipeline_file, show): @pytest.mark.parametrize( "show_tell", [("all", 129), ("trusted", 127), ("untrusted", 19)] ) - def test_inspect_pipeline(self, pipeline_file, side_effect_and_contents, show_tell): - side_effect, contents = side_effect_and_contents + def test_inspect_pipeline(self, pipeline_file, show_tell): + nodes = [] show, expected_elements = show_tell - visualize_tree(pipeline_file, sink=side_effect, show=show) - assert len(contents) == expected_elements + visualize_tree( + pipeline_file, sink=lambda n, _: nodes.extend(list(n)), show=show + ) + assert len([node for node in nodes if node.visible]) == expected_elements From 50a2125afec4459e99847235dc879921a1e8a43a Mon Sep 17 00:00:00 2001 From: Benjamin Bossan Date: Wed, 15 Mar 2023 12:06:29 +0100 Subject: [PATCH 05/12] Move format_node code to Node class --- skops/io/_audit.py | 4 ++++ skops/io/_visualize.py | 30 +----------------------------- 2 files changed, 5 insertions(+), 29 deletions(-) diff --git a/skops/io/_audit.py b/skops/io/_audit.py index 1e379308..5926c536 100644 --- a/skops/io/_audit.py +++ b/skops/io/_audit.py @@ -287,6 +287,10 @@ def get_unsafe_set(self) -> set[str]: return res + def format(self) -> str: + """Representation of the node's content.""" + return f"{self.module_name}.{self.class_name}" + class CachedNode(Node): def __init__( diff --git a/skops/io/_visualize.py b/skops/io/_visualize.py index 76c35c3f..7b5d74c0 100644 --- a/skops/io/_visualize.py +++ b/skops/io/_visualize.py @@ -2,7 +2,6 @@ import json from dataclasses import dataclass -from functools import singledispatch from pathlib import Path from typing import Callable, Iterator, Literal, Sequence from zipfile import ZipFile @@ -85,33 +84,6 @@ def _check_node_and_children_safe(node: Node, trusted: bool | Sequence[str]) -> return node_and_children_are_safe -# use singledispatch so that we can register specialized visualization functions -@singledispatch -def format_node(node: Node) -> str: - """Format the name of the node. - - By default, this is just the fully qualified name of the class, e.g. - ``"sklearn.preprocessing._data.MinMaxScaler"``. But for some types of nodes, - having a more specific output is desirable. These node types can be - registered with this function. - - """ - return f"{node.module_name}.{node.class_name}" - - -@format_node.register -def _format_function_node(node: FunctionNode) -> str: - # if a FunctionNode, children are not visited, but safety should still be checked - child = node.children["content"] - fn_name = f"{child['module_path']}.{child['function']}" - return f"{node.module_name}.{node.class_name} => {fn_name}" - - -@format_node.register -def _format_json_node(node: JsonNode) -> str: - return f"json-type({node.content})" - - def walk_tree( node: Node | dict[str, Node] | Sequence[Node], trusted: bool | Sequence[str] = False, @@ -164,7 +136,7 @@ def walk_tree( show=show, ) - node_val = format_node(node) + node_val = node.format() tag = config.tag_safe if node_is_safe else config.tag_unsafe if tag: node_val += f" {tag}" From 37bf95f4c92caaf1e67c66793210a553c40f7cc6 Mon Sep 17 00:00:00 2001 From: Benjamin Bossan Date: Wed, 15 Mar 2023 12:33:36 +0100 Subject: [PATCH 06/12] Move format_node code to Node class --- skops/io/_general.py | 8 ++++++++ skops/io/tests/test_audit.py | 28 ++++++++++++++++++++++++++-- 2 files changed, 34 insertions(+), 2 deletions(-) diff --git a/skops/io/_general.py b/skops/io/_general.py index b5685c3f..dfb7f223 100644 --- a/skops/io/_general.py +++ b/skops/io/_general.py @@ -482,6 +482,14 @@ def get_unsafe_set(self) -> set[str]: def _construct(self): return json.loads(self.content) + def format(self) -> str: + """Representation of the node's content. + + Since no module is used, just show the content. + + """ + return f"json-type({self.content})" + def bytes_get_state(obj: Any, save_context: SaveContext) -> dict[str, Any]: f_name = f"{uuid.uuid4()}.bin" diff --git a/skops/io/tests/test_audit.py b/skops/io/tests/test_audit.py index 71914b4d..17115d79 100644 --- a/skops/io/tests/test_audit.py +++ b/skops/io/tests/test_audit.py @@ -12,8 +12,8 @@ from skops.io import dumps, get_untrusted_types from skops.io._audit import Node, audit_tree, check_type, get_tree, temp_setattr -from skops.io._general import DictNode, dict_get_state -from skops.io._utils import LoadContext, SaveContext, gettype +from skops.io._general import DictNode, JsonNode, ObjectNode, dict_get_state +from skops.io._utils import LoadContext, SaveContext, get_state, gettype class CustomType: @@ -163,3 +163,27 @@ def test_complex_pipeline_untrusted_set(): untrusted = get_untrusted_types(data=dumps(clf)) type_names = [x.split(".")[-1] for x in untrusted] assert type_names == ["sqrt", "square"] + + +def test_format_object_node(): + estimator = LogisticRegression(random_state=0, solver="liblinear") + state = get_state(estimator, SaveContext(None)) + node = ObjectNode(state, LoadContext(None)) + expected = "sklearn.linear_model._logistic.LogisticRegression" + assert node.format() == expected + + +@pytest.mark.parametrize( + "inp, expected", + [ + ("hello", 'json-type("hello")'), + (123, "json-type(123)"), + (True, "json-type(true)"), + (False, "json-type(false)"), + (None, "json-type(null)"), + ], +) +def test_format_json_node(inp, expected): + state = get_state(inp, SaveContext(None)) + node = JsonNode(state, LoadContext(None)) + assert node.format() == expected From f91b3131ecde9470afed4dfa4e62878beee42562 Mon Sep 17 00:00:00 2001 From: Benjamin Bossan Date: Wed, 15 Mar 2023 12:36:07 +0100 Subject: [PATCH 07/12] Simplify code by using node.is_safe() --- skops/io/_visualize.py | 39 +++++++-------------------------------- 1 file changed, 7 insertions(+), 32 deletions(-) diff --git a/skops/io/_visualize.py b/skops/io/_visualize.py index 7b5d74c0..42c9866c 100644 --- a/skops/io/_visualize.py +++ b/skops/io/_visualize.py @@ -3,7 +3,7 @@ import json from dataclasses import dataclass from pathlib import Path -from typing import Callable, Iterator, Literal, Sequence +from typing import Callable, Iterator, Literal from zipfile import ZipFile from ._audit import Node, get_tree @@ -17,7 +17,7 @@ class PrintConfig: # fmt: off tag_safe: str = "" # noqa: E222 - tag_unsafe: str = " [UNSAFE]" + tag_unsafe: str = "[UNSAFE]" line_start: str = "├─" line: str = "──" # noqa: E222 @@ -71,22 +71,8 @@ def _check_visibility( return should_print -def _check_node_and_children_safe(node: Node, trusted: bool | Sequence[str]) -> bool: - # Note: this is very inefficient, because get_unsafe_set will be called many - # times on the same node (since parents recursively call children) but maybe - # that's acceptable for this context. If not, caching could be an option. - if trusted is True: - node_and_children_are_safe = True - elif trusted is False: - node_and_children_are_safe = not node.get_unsafe_set() - else: - node_and_children_are_safe = not (node.get_unsafe_set() - set(trusted)) - return node_and_children_are_safe - - def walk_tree( - node: Node | dict[str, Node] | Sequence[Node], - trusted: bool | Sequence[str] = False, + node: Node | dict[str, Node] | list[Node], show: Literal["all", "untrusted", "trusted"] = "all", node_name: str = "root", level: int = 0, @@ -104,19 +90,17 @@ def walk_tree( val, node_name=key, level=level, - trusted=trusted, show=show, config=config, ) return - if isinstance(node, (list, tuple)): + if isinstance(node, (list, tuple)): # shouldn't be tuple, but to be sure for val in node: yield from walk_tree( val, node_name=node_name, level=level, - trusted=trusted, show=show, config=config, ) @@ -128,7 +112,7 @@ def walk_tree( # THE ACTUAL FORMATTING HAPPENS HERE node_is_safe = node.is_self_safe() - node_and_children_are_safe = _check_node_and_children_safe(node, trusted) + node_and_children_are_safe = node.is_safe() visible = _check_visibility( node, node_is_safe=node_is_safe, @@ -139,7 +123,7 @@ def walk_tree( node_val = node.format() tag = config.tag_safe if node_is_safe else config.tag_unsafe if tag: - node_val += f" {tag}" + node_val += f" {tag}".rstrip(" ") if config.use_colors: if node_and_children_are_safe: @@ -174,7 +158,6 @@ def walk_tree( node.children, node_name=node_name, level=level + 1, - trusted=trusted, show=show, config=config, ) @@ -182,7 +165,6 @@ def walk_tree( def visualize_tree( file: Path | str, # TODO: from bytes - trusted: bool | Sequence[str] = False, show: Literal["all", "untrusted", "trusted"] = "all", sink: Callable[[Iterator[FormattedNode], PrintConfig], None] = pretty_print_tree, print_config: PrintConfig = print_config, @@ -198,13 +180,6 @@ def visualize_tree( file: str or pathlib.Path The file name of the object to be loaded. - trusted: bool, or list of str, default=False - If ``True``, the object will be loaded without any security checks. If - ``False``, the object will be loaded only if there are only trusted - objects in the dumped file. If a list of strings, the object will be - loaded only if there are only trusted objects and objects of types - listed in ``trusted`` are in the dumped file. - show: "all" or "untrusted" or "trusted" Whether to print all nodes, only untrusted nodes, or only trusted nodes. @@ -230,5 +205,5 @@ def visualize_tree( schema = json.loads(zip_file.read("schema.json")) tree = get_tree(schema, load_context=LoadContext(src=zip_file)) - nodes = walk_tree(tree, trusted=trusted, show=show, config=print_config) + nodes = walk_tree(tree, show=show, config=print_config) sink(nodes, print_config) From 80bb8af27f42dfc241508c5ed2ad3e5d9924e627 Mon Sep 17 00:00:00 2001 From: Benjamin Bossan Date: Wed, 15 Mar 2023 12:46:32 +0100 Subject: [PATCH 08/12] Simplify check for visibility --- skops/io/_visualize.py | 33 +++++++++++++-------------------- 1 file changed, 13 insertions(+), 20 deletions(-) diff --git a/skops/io/_visualize.py b/skops/io/_visualize.py index 42c9866c..b239ca3e 100644 --- a/skops/io/_visualize.py +++ b/skops/io/_visualize.py @@ -56,19 +56,17 @@ def pretty_print_tree( print(line) -def _check_visibility( - node: Node, - node_is_safe: bool, - node_and_children_are_safe: bool, - show: Literal["all", "untrusted", "trusted"], -) -> bool: +def _check_visibility(node: Node, show: Literal["all", "untrusted", "trusted"]) -> bool: if show == "all": - should_print = True - elif show == "untrusted": - should_print = not node_and_children_are_safe - else: # only trusted - should_print = node_is_safe - return should_print + return True + + if show == "untrusted": + node_or_children_unsafe = not node.is_safe() + return node_or_children_unsafe + + # case: show only safe node + node_safe = node.is_self_safe() + return node_safe def walk_tree( @@ -111,21 +109,16 @@ def walk_tree( raise TypeError(f"{type(node)}") # THE ACTUAL FORMATTING HAPPENS HERE - node_is_safe = node.is_self_safe() - node_and_children_are_safe = node.is_safe() - visible = _check_visibility( - node, - node_is_safe=node_is_safe, - node_and_children_are_safe=node_and_children_are_safe, - show=show, - ) + visible = _check_visibility(node, show=show) node_val = node.format() + node_is_safe = node.is_self_safe() tag = config.tag_safe if node_is_safe else config.tag_unsafe if tag: node_val += f" {tag}".rstrip(" ") if config.use_colors: + node_and_children_are_safe = node.is_safe() if node_and_children_are_safe: color = config.color_safe elif node_is_safe: From eb56d884edff77f15bc0fda22a2976424e28fe54 Mon Sep 17 00:00:00 2001 From: Benjamin Bossan Date: Wed, 15 Mar 2023 16:02:54 +0100 Subject: [PATCH 09/12] Refactor: better separation of concerns --- skops/io/_visualize.py | 229 +++++++++++++++++++++---------- skops/io/tests/test_audit.py | 1 + skops/io/tests/test_visualize.py | 36 ++--- 3 files changed, 172 insertions(+), 94 deletions(-) diff --git a/skops/io/_visualize.py b/skops/io/_visualize.py index b239ca3e..c5f2126d 100644 --- a/skops/io/_visualize.py +++ b/skops/io/_visualize.py @@ -3,9 +3,10 @@ import json from dataclasses import dataclass from pathlib import Path -from typing import Callable, Iterator, Literal +from typing import Any, Callable, Iterator, Literal from zipfile import ZipFile +from ..utils.importutils import import_or_raise from ._audit import Node, get_tree from ._general import FunctionNode, JsonNode from ._numpy import NdArrayNode @@ -15,67 +16,142 @@ @dataclass class PrintConfig: - # fmt: off - tag_safe: str = "" # noqa: E222 + tag_safe: str = "" tag_unsafe: str = "[UNSAFE]" - line_start: str = "├─" - line: str = "──" # noqa: E222 - use_colors: bool = True - color_safe: str = '\033[32m' # green # noqa: E222 - color_unsafe: str = '\033[31m' # red # noqa: E222 - color_child_unsafe: str = '\033[33m' # yellow - color_end: str = '\033[0m' # noqa: E222 - # fmt: on + # use rich for coloring + color_safe: str = "green" + color_unsafe: str = "red" + color_child_unsafe: str = "yellow" print_config = PrintConfig() +def _check_visibility( + is_self_safe: bool, + is_safe: bool, + show: Literal["all", "untrusted", "trusted"], +) -> bool: + if show == "all": + return True + + if show == "untrusted": + return not is_safe + + # case: show only safe node + return is_self_safe + + @dataclass -class FormattedNode: +class NodeInfo: level: int key: str # the key to the node val: str # the value of the node - visible: bool # whether it should be shown + is_self_safe: bool # whether this specific node is safe + is_safe: bool # whether this node and all of its children are safe def pretty_print_tree( - formatted_nodes: Iterator[FormattedNode], config: PrintConfig + formatted_nodes: Iterator[NodeInfo], + show: Literal["all", "untrusted", "trusted"], + config: PrintConfig, ) -> None: - # TODO: the "tree" lines could be made prettier since all nodes are known - # here - for formatted_node in formatted_nodes: - if not formatted_node.visible: + # TODO: print html representation if inside a notebook + rich = import_or_raise("rich", "pretty printing the object") + from rich.tree import Tree + + nodes = list(formatted_nodes) + node = nodes.pop(0) + cur_level = 0 + root = tree = Tree(f"{node.key}: {node.val}") + trace = [tree] # trace keeps track of what is the current node to add to + + while nodes: + node = nodes.pop(0) + visible = _check_visibility(node.is_self_safe, node.is_safe, show=show) + if not visible: continue - line = print_config.line_start - line += (formatted_node.level - 1) * print_config.line - line += f"{formatted_node.key}: {formatted_node.val}" - print(line) + level_diff = cur_level - node.level + if level_diff < -1: + # this would mean it is a "(great-)grandchild" node + raise ValueError( + "While constructing the tree of the object, a level difference of " + f"{level_diff} was encountered, which should not be possible, please " + "report the issue here: https://github.com/skops-dev/skops/issues" + ) + for _ in range(level_diff + 1): + trace.pop(-1) + # If level_diff == -1, we're fine with not popping, as the code + # assumes that the current node is a child node of the previous one, + # which corresponds to a level_diff of -1. -def _check_visibility(node: Node, show: Literal["all", "untrusted", "trusted"]) -> bool: - if show == "all": - return True + # add unsafe tag if necessary + node_val = node.val + tag = config.tag_safe if node.is_self_safe else config.tag_unsafe + if tag: + node_val += f" {tag}".rstrip(" ") - if show == "untrusted": - node_or_children_unsafe = not node.is_safe() - return node_or_children_unsafe + # colorize if so desired + if config.use_colors: + if node.is_safe: + color = config.color_safe + elif node.is_self_safe: + color = config.color_child_unsafe + else: + color = config.color_unsafe + node_val = f"[{color}]{node_val}" - # case: show only safe node - node_safe = node.is_self_safe() - return node_safe + text = f"{node.key}: {node_val}" + tree = trace[-1] + trace.append(tree.add(text)) + cur_level = node.level + + rich.print(root) def walk_tree( node: Node | dict[str, Node] | list[Node], - show: Literal["all", "untrusted", "trusted"] = "all", node_name: str = "root", level: int = 0, - config: PrintConfig = print_config, -) -> Iterator[FormattedNode]: +) -> Iterator[NodeInfo]: + """Visit all nodes of the tree and yield their important attributes. + + This function visits all nodes of the object tree and determines: + + - level: how nested the node is + - key: the key of the node, e.g. the key of a dict. + - val: the value of the node, e.g. builtins.list + - safety: whether it, and its children, are trusted + + These values are just yielded in a flat manner. This way, the consumer of + this function doesn't need to know how nodes can be nested and how safety of + a node is determined. + + Parameters + ---------- + node: :class:`skops.io._audit.Node` + The current node to visit. Children are visited recursively. + + show: "all" or "untrusted" or "trusted" + Whether to print all nodes, only untrusted nodes, or only trusted nodes. + + node_name: str (default="root") + The key to the current node. If "key_types" is encountered, it is + skipped. + + level: int (default=0) + The current level of nesting. + + Yields + ------ + :class:`~NodeInfo`: + A dataclass containing the aforementioned information. + + """ # helper function to pretty-print the nodes if node_name == "key_types": # _check_key_types_schema(node) @@ -88,46 +164,38 @@ def walk_tree( val, node_name=key, level=level, - show=show, - config=config, ) return - if isinstance(node, (list, tuple)): # shouldn't be tuple, but to be sure + if isinstance(node, (list, tuple)): + # shouldn't be tuple, but check just to be sure for val in node: yield from walk_tree( val, node_name=node_name, level=level, - show=show, - config=config, ) return # NO MATCH: RAISE ERROR if not isinstance(node, Node): - raise TypeError(f"{type(node)}") - - # THE ACTUAL FORMATTING HAPPENS HERE - visible = _check_visibility(node, show=show) - - node_val = node.format() - node_is_safe = node.is_self_safe() - tag = config.tag_safe if node_is_safe else config.tag_unsafe - if tag: - node_val += f" {tag}".rstrip(" ") - - if config.use_colors: - node_and_children_are_safe = node.is_safe() - if node_and_children_are_safe: - color = config.color_safe - elif node_is_safe: - color = config.color_child_unsafe - else: - color = config.color_unsafe - node_val = f"{color}{node_val}{config.color_end}" - - yield FormattedNode(level=level, key=node_name, val=node_val, visible=visible) + raise TypeError( + f"Cannot deal with {type(node)}, please report the issue here " + "https://github.com/skops-dev/skops/issues" + ) + + # YIELDING THE ACTUAL FORMATTED NODE HERE + + # Note: calling node.is_safe() on all nodes is potentially wasteful because + # it is already a recursive call, i.e. child nodes will be checked many + # times. A solution to this would be to add caching to its call. + yield NodeInfo( + level=level, + key=node_name, + val=node.format(), + is_self_safe=node.is_self_safe(), + is_safe=node.is_safe(), + ) # TYPES WHOSE CHILDREN IT MAKES NO SENSE TO VISIT if isinstance(node, (NdArrayNode, SparseMatrixNode)) and (node.type != "json"): @@ -151,16 +219,16 @@ def walk_tree( node.children, node_name=node_name, level=level + 1, - show=show, - config=config, ) def visualize_tree( file: Path | str, # TODO: from bytes show: Literal["all", "untrusted", "trusted"] = "all", - sink: Callable[[Iterator[FormattedNode], PrintConfig], None] = pretty_print_tree, - print_config: PrintConfig = print_config, + sink: Callable[ + [Iterator[NodeInfo], Literal["all", "untrusted", "trusted"], PrintConfig], None + ] = pretty_print_tree, + **kwargs: Any, ) -> None: """Visualize the contents of a skops file. @@ -177,26 +245,35 @@ def visualize_tree( Whether to print all nodes, only untrusted nodes, or only trusted nodes. sink: function - This function should take two arguments, an iterator of - ``FormattedNode`` and a ``PrintConfig``. The ``FormattedNode`` contains - the information about the node, namely: + + This function should take three arguments, an iterator of + :class:`~NodeInfo` instances, an indicator of what to show, and a config + of :class:`~PrintConfig`. The ``NodeInfo`` contains the information + about the node, namely: - the level of nesting (int) - the key of the node (str) - the value of the node as a string representation (str) - - the visibility of the node, depending on the ``show`` argument (bool) + - the safety of the node and its children - The second argument is the print config (see description of next argument). + The ``show`` argument is explained above. - print_config: :class:`~PrintConfig` - The ``PrintConfig`` is a simple object with attributes that determine - how the node should be visualized, e.g. the ``use_colors`` attribute - determines if colors should be used. + The last argument is a :class:`~PrintConfig` instance, which is a + simple dataclass with attributes that determine how the node should be + visualized, e.g. the ``use_colors`` attribute determines if colors + should be used. + + kwargs : TODO """ with ZipFile(file, "r") as zip_file: schema = json.loads(zip_file.read("schema.json")) tree = get_tree(schema, load_context=LoadContext(src=zip_file)) - nodes = walk_tree(tree, show=show, config=print_config) - sink(nodes, print_config) + if kwargs: + config = PrintConfig(**kwargs) + else: + config = print_config + + nodes = walk_tree(tree) + sink(nodes, show, config) diff --git a/skops/io/tests/test_audit.py b/skops/io/tests/test_audit.py index 17115d79..24cdbc94 100644 --- a/skops/io/tests/test_audit.py +++ b/skops/io/tests/test_audit.py @@ -178,6 +178,7 @@ def test_format_object_node(): [ ("hello", 'json-type("hello")'), (123, "json-type(123)"), + (0.456, "json-type(0.456)"), (True, "json-type(true)"), (False, "json-type(false)"), (None, "json-type(null)"), diff --git a/skops/io/tests/test_visualize.py b/skops/io/tests/test_visualize.py index 39b7104f..d2d7c810 100644 --- a/skops/io/tests/test_visualize.py +++ b/skops/io/tests/test_visualize.py @@ -54,26 +54,26 @@ def pipeline_file(self, pipeline, tmp_path): def test_print_simple(self, simple_file, show): visualize_tree(simple_file, show=show) - @pytest.mark.parametrize( - "show_tell", [("all", 8), ("trusted", 8), ("untrusted", 0)] - ) - def test_inspect_simple(self, simple_file, show_tell): - nodes = [] - show, expected_elements = show_tell - visualize_tree(simple_file, sink=lambda n, _: nodes.extend(list(n)), show=show) - assert len([node for node in nodes if node.visible]) == expected_elements + # @pytest.mark.parametrize( + # "show_tell", [("all", 8), ("trusted", 8), ("untrusted", 0)] + # ) + # def test_inspect_simple(self, simple_file, show_tell): + # nodes = [] + # show, expected_elements = show_tell + # visualize_tree(simple_file, sink=lambda n, *_: nodes.extend(list(n)), show=show) + # assert len([node for node in nodes if node.visible]) == expected_elements @pytest.mark.parametrize("show", ["all", "trusted", "untrusted"]) def test_print_pipeline(self, pipeline_file, show): visualize_tree(pipeline_file, show=show) - @pytest.mark.parametrize( - "show_tell", [("all", 129), ("trusted", 127), ("untrusted", 19)] - ) - def test_inspect_pipeline(self, pipeline_file, show_tell): - nodes = [] - show, expected_elements = show_tell - visualize_tree( - pipeline_file, sink=lambda n, _: nodes.extend(list(n)), show=show - ) - assert len([node for node in nodes if node.visible]) == expected_elements + # @pytest.mark.parametrize( + # "show_tell", [("all", 129), ("trusted", 127), ("untrusted", 19)] + # ) + # def test_inspect_pipeline(self, pipeline_file, show_tell): + # nodes = [] + # show, expected_elements = show_tell + # visualize_tree( + # pipeline_file, sink=lambda n, *_: nodes.extend(list(n)), show=show + # ) + # assert len([node for node in nodes if node.visible]) == expected_elements From 3cd95d63dd4ce5c3eaf100b2375b5dc4e2a44140 Mon Sep 17 00:00:00 2001 From: Benjamin Bossan Date: Wed, 15 Mar 2023 17:02:23 +0100 Subject: [PATCH 10/12] Allow reading from bytes, better tests --- skops/io/_visualize.py | 18 +++-- skops/io/tests/test_visualize.py | 120 +++++++++++++++++++++---------- 2 files changed, 93 insertions(+), 45 deletions(-) diff --git a/skops/io/_visualize.py b/skops/io/_visualize.py index c5f2126d..1d23f86e 100644 --- a/skops/io/_visualize.py +++ b/skops/io/_visualize.py @@ -1,5 +1,6 @@ from __future__ import annotations +import io import json from dataclasses import dataclass from pathlib import Path @@ -54,7 +55,7 @@ class NodeInfo: def pretty_print_tree( - formatted_nodes: Iterator[NodeInfo], + nodes_iter: Iterator[NodeInfo], show: Literal["all", "untrusted", "trusted"], config: PrintConfig, ) -> None: @@ -62,7 +63,7 @@ def pretty_print_tree( rich = import_or_raise("rich", "pretty printing the object") from rich.tree import Tree - nodes = list(formatted_nodes) + nodes = list(nodes_iter) node = nodes.pop(0) cur_level = 0 root = tree = Tree(f"{node.key}: {node.val}") @@ -93,7 +94,7 @@ def pretty_print_tree( node_val = node.val tag = config.tag_safe if node.is_self_safe else config.tag_unsafe if tag: - node_val += f" {tag}".rstrip(" ") + node_val += f" {tag}" # colorize if so desired if config.use_colors: @@ -184,7 +185,7 @@ def walk_tree( "https://github.com/skops-dev/skops/issues" ) - # YIELDING THE ACTUAL FORMATTED NODE HERE + # YIELDING THE ACTUAL NODE INFORMATION HERE # Note: calling node.is_safe() on all nodes is potentially wasteful because # it is already a recursive call, i.e. child nodes will be checked many @@ -223,7 +224,7 @@ def walk_tree( def visualize_tree( - file: Path | str, # TODO: from bytes + file: Path | str | bytes, show: Literal["all", "untrusted", "trusted"] = "all", sink: Callable[ [Iterator[NodeInfo], Literal["all", "untrusted", "trusted"], PrintConfig], None @@ -266,7 +267,12 @@ def visualize_tree( kwargs : TODO """ - with ZipFile(file, "r") as zip_file: + if isinstance(file, bytes): + zf = ZipFile(io.BytesIO(file), "r") + else: + zf = ZipFile(file, "r") + + with zf as zip_file: schema = json.loads(zip_file.read("schema.json")) tree = get_tree(schema, load_context=LoadContext(src=zip_file)) diff --git a/skops/io/tests/test_visualize.py b/skops/io/tests/test_visualize.py index d2d7c810..7d2e8d87 100644 --- a/skops/io/tests/test_visualize.py +++ b/skops/io/tests/test_visualize.py @@ -1,5 +1,5 @@ -import numpy as np import pytest +import sklearn from sklearn.linear_model import LogisticRegression from sklearn.pipeline import FeatureUnion, Pipeline from sklearn.preprocessing import ( @@ -18,14 +18,11 @@ class TestVisualizeTree: def simple(self): return MinMaxScaler(feature_range=(-555, 123)) - @pytest.fixture - def simple_file(self, simple, tmp_path): - f_name = tmp_path / "estimator.skops" - sio.dump(simple, f_name) - return f_name - @pytest.fixture def pipeline(self): + def unsafe_function(x): + return x + # fmt: off pipeline = Pipeline([ ("features", FeatureUnion([ @@ -35,7 +32,7 @@ def pipeline(self): ("poly1", PolynomialFeatures()), ("poly2", PolynomialFeatures(degree=3, include_bias=False)) ])), - ("square-root", FunctionTransformer(np.sqrt)), + ("square-root", FunctionTransformer(unsafe_function)), ("scale", MinMaxScaler()), ])), ])), @@ -44,36 +41,81 @@ def pipeline(self): # fmt: on return pipeline - @pytest.fixture - def pipeline_file(self, pipeline, tmp_path): - f_name = tmp_path / "estimator.skops" - sio.dump(pipeline, f_name) - return f_name - @pytest.mark.parametrize("show", ["all", "trusted", "untrusted"]) - def test_print_simple(self, simple_file, show): - visualize_tree(simple_file, show=show) - - # @pytest.mark.parametrize( - # "show_tell", [("all", 8), ("trusted", 8), ("untrusted", 0)] - # ) - # def test_inspect_simple(self, simple_file, show_tell): - # nodes = [] - # show, expected_elements = show_tell - # visualize_tree(simple_file, sink=lambda n, *_: nodes.extend(list(n)), show=show) - # assert len([node for node in nodes if node.visible]) == expected_elements + def test_print_simple(self, simple, show, capsys): + file = sio.dumps(simple) + visualize_tree(file, show=show) - @pytest.mark.parametrize("show", ["all", "trusted", "untrusted"]) - def test_print_pipeline(self, pipeline_file, show): - visualize_tree(pipeline_file, show=show) - - # @pytest.mark.parametrize( - # "show_tell", [("all", 129), ("trusted", 127), ("untrusted", 19)] - # ) - # def test_inspect_pipeline(self, pipeline_file, show_tell): - # nodes = [] - # show, expected_elements = show_tell - # visualize_tree( - # pipeline_file, sink=lambda n, *_: nodes.extend(list(n)), show=show - # ) - # assert len([node for node in nodes if node.visible]) == expected_elements + # output is always the same for "all" and "trusted" because all nodes + # are trusted + expected = [ + "root: sklearn.preprocessing._data.MinMaxScaler", + "└── attrs: builtins.dict", + " ├── feature_range: builtins.tuple", + " │ ├── content: json-type(-555)", + " │ └── content: json-type(123)", + " ├── copy: json-type(true)", + " ├── clip: json-type(false)", + f' └── _sklearn_version: json-type("{sklearn.__version__}")', + ] + if show == "untrusted": + # since no untrusted, only show root + expected = expected[:1] + + stdout, _ = capsys.readouterr() + assert stdout.strip() == "\n".join(expected) + + def test_print_pipelien(self, pipeline, capsys): + file = sio.dumps(pipeline) + visualize_tree(file) + + # no point in checking the whole output with > 120 lines + expected_start = [ + "root: sklearn.pipeline.Pipeline", + "└── attrs: builtins.dict", + " ├── steps: builtins.list", + " │ ├── content: builtins.tuple", + ' │ │ ├── content: json-type("features")', + ] + expected_end = [ + " ├── memory: json-type(null)", + " ├── verbose: json-type(false)", + f' └── _sklearn_version: json-type("{sklearn.__version__}")', + ] + + stdout, _ = capsys.readouterr() + assert stdout.startswith("\n".join(expected_start)) + assert stdout.rstrip().endswith("\n".join(expected_end)) + + def test_unsafe_nodes(self, pipeline): + file = sio.dumps(pipeline) + nodes = [] + + def sink(nodes_iter, *args, **kwargs): + nodes.extend(nodes_iter) + + visualize_tree(file, sink=sink) + nodes_self_unsafe = [node for node in nodes if not node.is_self_safe] + nodes_unsafe = [node for node in nodes if not node.is_safe] + + # there are currently 2 unsafe nodes, a numpy int and the custom + # functions. The former might be considered safe in the future, in which + # case this test needs to be changed. + assert len(nodes_self_unsafe) == 2 + assert nodes_self_unsafe[0].val == "numpy.int64" + assert nodes_self_unsafe[1].val == "test_visualize.function" + + # it's not easy to test the number of indirectly unsafe nodes, because + # it will depend on the nesting structure; we can only be sure that it's + # more than 2, and one of them should be the FunctionTransformer + assert len(nodes_unsafe) > 2 + assert any("FunctionTransformer" in node.val for node in nodes_unsafe) + + def test_custom_print_config(self): + pass + + def test_from_file(self): + pass + + def test_rich_not_installed(self): + pass From 0e6c71fb1dc1874bcdd3afd3d4efcf4daee68071 Mon Sep 17 00:00:00 2001 From: Benjamin Bossan Date: Fri, 17 Mar 2023 14:55:43 +0100 Subject: [PATCH 11/12] Add tests, docs, improved implementation --- docs/changes.rst | 3 + docs/persistence.rst | 28 ++++- skops/_min_dependencies.py | 1 + skops/io/__init__.py | 3 +- skops/io/_visualize.py | 181 +++++++++++++++++++------------ skops/io/tests/test_visualize.py | 110 ++++++++++++++++--- 6 files changed, 239 insertions(+), 87 deletions(-) diff --git a/docs/changes.rst b/docs/changes.rst index e7d0155c..80e9b409 100644 --- a/docs/changes.rst +++ b/docs/changes.rst @@ -20,6 +20,9 @@ v0.6 `Benjamin Bossan`_. - Fix: skops persistence now also works with many functions from the :mod:`operator` module. :pr:`287` by `Benjamin Bossan`_. +- Add possibility to visualize a skops object and show untrusted types by using + :func:`skops.io.visualize_tree`. Requires to install `rich`: `pip install + rich`. :pr:`317` by `Benjamin Bossan`_. v0.5 ---- diff --git a/docs/persistence.rst b/docs/persistence.rst index 341b8c74..7cedcc24 100644 --- a/docs/persistence.rst +++ b/docs/persistence.rst @@ -134,11 +134,37 @@ For example, to convert all ``.pkl`` flies in the current directory: Further help for the different supported options can be found by calling ``skops convert --help`` in a terminal. +Visualization +############# + +Skops files can be visualized using :func:`skops.io.visualize_tree`. If you have +a skops file called ``my-model.skops``, you can visualize it like this: + +.. code:: python + + import skops.io as sio + sio.visualize_tree("my-model.skops") + +The output could look like this: + +.. code:: + + root: sklearn.preprocessing._data.MinMaxScaler + └── attrs: builtins.dict + ├── feature_range: builtins.tuple + │ ├── content: json-type(-555) + │ └── content: json-type(123) + ├── copy: unsafe_lib.UnsafeType [UNSAFE] + ├── clip: json-type(false) + └── _sklearn_version: json-type("1.2.0") + +``unsafe_lib.UnsafeType`` was recognized as untrusted and marked. There are +various options, like colorizing nodes that are untrusted. Supported libraries ------------------- -Skops intends to support all of **scikit-learn**, that is, not only its +Skops intends to support all of ````scikit-learn**, that is, not only its estimators, but also other classes like cross validation splitters. Furthermore, most types from **numpy** and **scipy** should be supported, such as (sparse) arrays, dtypes, random generators, and ufuncs. diff --git a/skops/_min_dependencies.py b/skops/_min_dependencies.py index ff3483d1..06b65524 100644 --- a/skops/_min_dependencies.py +++ b/skops/_min_dependencies.py @@ -34,6 +34,7 @@ # TODO: remove condition when catboost supports python 3.11 "catboost": ("1.0", "tests", "python_version < '3.11'"), "fairlearn": ("0.7.0", "docs, tests", None), + "rich": ("12", "tests", None), } diff --git a/skops/io/__init__.py b/skops/io/__init__.py index 7990e9eb..d381d49d 100644 --- a/skops/io/__init__.py +++ b/skops/io/__init__.py @@ -1,3 +1,4 @@ from ._persist import dump, dumps, get_untrusted_types, load, loads +from ._visualize import visualize_tree -__all__ = ["dumps", "load", "loads", "dump", "get_untrusted_types"] +__all__ = ["dumps", "load", "loads", "dump", "get_untrusted_types", "visualize_tree"] diff --git a/skops/io/_visualize.py b/skops/io/_visualize.py index 1d23f86e..ca1b9da8 100644 --- a/skops/io/_visualize.py +++ b/skops/io/_visualize.py @@ -16,18 +16,12 @@ @dataclass -class PrintConfig: - tag_safe: str = "" - tag_unsafe: str = "[UNSAFE]" - - use_colors: bool = True - # use rich for coloring - color_safe: str = "green" - color_unsafe: str = "red" - color_child_unsafe: str = "yellow" - - -print_config = PrintConfig() +class NodeInfo: + level: int + key: str # the key to the node + val: str # the value of the node + is_self_safe: bool # whether this specific node is safe + is_safe: bool # whether this node and all of its children are safe def _check_visibility( @@ -35,6 +29,12 @@ def _check_visibility( is_safe: bool, show: Literal["all", "untrusted", "trusted"], ) -> bool: + """Determine visibility of the node. + + Users can indicate if they want to see all nodes, all trusted nodes, or all + untrusted nodes. + + """ if show == "all": return True @@ -45,28 +45,63 @@ def _check_visibility( return is_self_safe -@dataclass -class NodeInfo: - level: int - key: str # the key to the node - val: str # the value of the node - is_self_safe: bool # whether this specific node is safe - is_safe: bool # whether this node and all of its children are safe +def _get_node_label( + node: NodeInfo, + tag_safe: str = "", + tag_unsafe: str = "[UNSAFE]", + use_colors: bool = True, + # use rich for coloring + color_safe: str = "green", + color_unsafe: str = "red", + color_child_unsafe: str = "yellow", +): + """Determine the label of a node. + + Nodes are labeled differently based on how they're trusted. + + """ + # note: when changing the arguments to this function, also update the + # docstring of visualize_tree! + + # add tag if necessary + node_val = node.val + tag = tag_safe if node.is_self_safe else tag_unsafe + if tag: + node_val += f" {tag}" + + # colorize if so desired + if use_colors: + if node.is_safe: + color = color_safe + elif node.is_self_safe: + color = color_child_unsafe + else: + color = color_unsafe + node_val = f"[{color}]{node_val}" + + return node_val def pretty_print_tree( nodes_iter: Iterator[NodeInfo], show: Literal["all", "untrusted", "trusted"], - config: PrintConfig, + **kwargs, ) -> None: - # TODO: print html representation if inside a notebook + # This function loops through the flattened nodes of the tree and creates a + # rich Tree based on the node information. Rich can then create a pretty + # visualization of said tree. rich = import_or_raise("rich", "pretty printing the object") from rich.tree import Tree nodes = list(nodes_iter) + if not nodes: # empty tree, hmm + return + + # start with root node, it is always visible node = nodes.pop(0) - cur_level = 0 - root = tree = Tree(f"{node.key}: {node.val}") + node_label = _get_node_label(node, **kwargs) + cur_level = node.level # should be 0 + root = tree = Tree(f"{node.key}: {node_label}") trace = [tree] # trace keeps track of what is the current node to add to while nodes: @@ -84,29 +119,23 @@ def pretty_print_tree( "report the issue here: https://github.com/skops-dev/skops/issues" ) + # Level diff of -1 means that this node is a child of the previous node. + # E.g. if the current level if 4 and the previous level was 3, the + # current node is a child node of the previous one. Since the previous + # node is already the last node in the trace, there is nothing more that + # needs to be done. Therefore, for a diff of -1, we don't pop from the + # trace. for _ in range(level_diff + 1): + # If the level diff is greater than -1, it means that the current + # node is not the child of the last node, but of a node higher up. + # E.g. if the current level is 2 and previous level was 3, it means + # that we should move up 2 layers of nesting, therefore, we pop + # 3-2+1 = 2 levels. trace.pop(-1) - # If level_diff == -1, we're fine with not popping, as the code - # assumes that the current node is a child node of the previous one, - # which corresponds to a level_diff of -1. - - # add unsafe tag if necessary - node_val = node.val - tag = config.tag_safe if node.is_self_safe else config.tag_unsafe - if tag: - node_val += f" {tag}" - - # colorize if so desired - if config.use_colors: - if node.is_safe: - color = config.color_safe - elif node.is_self_safe: - color = config.color_child_unsafe - else: - color = config.color_unsafe - node_val = f"[{color}]{node_val}" - - text = f"{node.key}: {node_val}" + + # add tag if necessary + node_label = _get_node_label(node, **kwargs) + text = f"{node.key}: {node_label}" tree = trace[-1] trace.append(tree.add(text)) cur_level = node.level @@ -153,9 +182,11 @@ def walk_tree( A dataclass containing the aforementioned information. """ - # helper function to pretty-print the nodes + # key_types is not helpful, as it is artificially added by skops to + # circumvent the fact that json only allows keys to be strings. It is not + # useful to the user and adds a lot of noise, thus skip key_types. + # TODO: check that no funny business is going on in key types if node_name == "key_types": - # _check_key_types_schema(node) return # COMPOSITE TYPES: CHECK ALL ITEMS @@ -199,20 +230,18 @@ def walk_tree( ) # TYPES WHOSE CHILDREN IT MAKES NO SENSE TO VISIT + # TODO: For better security, we should check the schema if we return early, + # otherwise something nefarious could be hidden inside. if isinstance(node, (NdArrayNode, SparseMatrixNode)) and (node.type != "json"): - # _check_array_schema(node) return if isinstance(node, (NdArrayNode, SparseMatrixNode)) and (node.type == "json"): - # _check_array_json_schema(node) return if isinstance(node, FunctionNode): - # _check_function_schema(node) return if isinstance(node, JsonNode): - # _check_json_schema(node) pass # RECURSE @@ -226,9 +255,7 @@ def walk_tree( def visualize_tree( file: Path | str | bytes, show: Literal["all", "untrusted", "trusted"] = "all", - sink: Callable[ - [Iterator[NodeInfo], Literal["all", "untrusted", "trusted"], PrintConfig], None - ] = pretty_print_tree, + sink: Callable[..., None] = pretty_print_tree, **kwargs: Any, ) -> None: """Visualize the contents of a skops file. @@ -237,6 +264,14 @@ def visualize_tree( untrusted nodes. A node is considered untrusted if at least one of its child nodes is untrusted. + Visualizing the tree using the default visualization function requires the + ``rich`` library, which can be installed as: + + python -m pip install rich + + If passing a custom visualization function to ``sink``, ``rich`` is not + required. + Parameters ---------- file: str or pathlib.Path @@ -245,26 +280,36 @@ def visualize_tree( show: "all" or "untrusted" or "trusted" Whether to print all nodes, only untrusted nodes, or only trusted nodes. - sink: function + sink: function (default=:func:`~pretty_print_tree`) - This function should take three arguments, an iterator of - :class:`~NodeInfo` instances, an indicator of what to show, and a config - of :class:`~PrintConfig`. The ``NodeInfo`` contains the information - about the node, namely: + This function should take at least two arguments, an iterator of + :class:`~NodeInfo` instances and an indicator of what to show. The + ``NodeInfo`` contains the information about the node, namely: - the level of nesting (int) - the key of the node (str) - the value of the node as a string representation (str) - the safety of the node and its children - The ``show`` argument is explained above. + The ``show`` argument is explained above. Any additional ``kwargs`` + passed to ``visualize_tree`` will also be passed to ``sink``. + + The default sink is :func:`~pretty_print_tree`, which takes these + additional parameters: - The last argument is a :class:`~PrintConfig` instance, which is a - simple dataclass with attributes that determine how the node should be - visualized, e.g. the ``use_colors`` attribute determines if colors - should be used. + - tag_safe: The tag used to mark trusted nodes (default="", i.e no + tag) + - tag_unsafe: The tag used to mark untrusted nodes + (default="[UNSAFE]") + - use_colors: Whether to colorize the nodes (default=True) + - color_safe: Color to use for trusted nodes (default="green") + - color_unsafe: Color to use for untrusted nodes (default="red") + - color_child_unsafe: Color to use for nodes that are trusted but + that have untrusted child ndoes (default="yellow") - kwargs : TODO + So if you don't want to have colored output, just pass + ``use_colors=False`` to ``visualize_tree``. The colors themselves, such + as "red" and "green", refer to the standard colors used by ``rich``. """ if isinstance(file, bytes): @@ -276,10 +321,6 @@ def visualize_tree( schema = json.loads(zip_file.read("schema.json")) tree = get_tree(schema, load_context=LoadContext(src=zip_file)) - if kwargs: - config = PrintConfig(**kwargs) - else: - config = print_config - nodes = walk_tree(tree) - sink(nodes, show, config) + # TODO: it would be nice to print html representation if inside a notebook + sink(nodes, show, **kwargs) diff --git a/skops/io/tests/test_visualize.py b/skops/io/tests/test_visualize.py index 7d2e8d87..9faf44e5 100644 --- a/skops/io/tests/test_visualize.py +++ b/skops/io/tests/test_visualize.py @@ -1,3 +1,7 @@ +"""Tests for skops.io.visualize""" + +from unittest.mock import Mock, patch + import pytest import sklearn from sklearn.linear_model import LogisticRegression @@ -10,7 +14,6 @@ ) import skops.io as sio -from skops.io._visualize import visualize_tree class TestVisualizeTree: @@ -44,10 +47,10 @@ def unsafe_function(x): @pytest.mark.parametrize("show", ["all", "trusted", "untrusted"]) def test_print_simple(self, simple, show, capsys): file = sio.dumps(simple) - visualize_tree(file, show=show) + sio.visualize_tree(file, show=show) - # output is always the same for "all" and "trusted" because all nodes - # are trusted + # Output is the same for "all" and "trusted" because all nodes are + # trusted. Colors are not recorded by capsys. expected = [ "root: sklearn.preprocessing._data.MinMaxScaler", "└── attrs: builtins.dict", @@ -56,7 +59,7 @@ def test_print_simple(self, simple, show, capsys): " │ └── content: json-type(123)", " ├── copy: json-type(true)", " ├── clip: json-type(false)", - f' └── _sklearn_version: json-type("{sklearn.__version__}")', + ' └── _sklearn_version: json-type("{}")'.format(sklearn.__version__), ] if show == "untrusted": # since no untrusted, only show root @@ -65,9 +68,9 @@ def test_print_simple(self, simple, show, capsys): stdout, _ = capsys.readouterr() assert stdout.strip() == "\n".join(expected) - def test_print_pipelien(self, pipeline, capsys): + def test_print_pipeline(self, pipeline, capsys): file = sio.dumps(pipeline) - visualize_tree(file) + sio.visualize_tree(file) # no point in checking the whole output with > 120 lines expected_start = [ @@ -80,7 +83,7 @@ def test_print_pipelien(self, pipeline, capsys): expected_end = [ " ├── memory: json-type(null)", " ├── verbose: json-type(false)", - f' └── _sklearn_version: json-type("{sklearn.__version__}")', + ' └── _sklearn_version: json-type("{}")'.format(sklearn.__version__), ] stdout, _ = capsys.readouterr() @@ -94,7 +97,7 @@ def test_unsafe_nodes(self, pipeline): def sink(nodes_iter, *args, **kwargs): nodes.extend(nodes_iter) - visualize_tree(file, sink=sink) + sio.visualize_tree(file, sink=sink) nodes_self_unsafe = [node for node in nodes if not node.is_self_safe] nodes_unsafe = [node for node in nodes if not node.is_safe] @@ -111,11 +114,88 @@ def sink(nodes_iter, *args, **kwargs): assert len(nodes_unsafe) > 2 assert any("FunctionTransformer" in node.val for node in nodes_unsafe) - def test_custom_print_config(self): - pass + @pytest.mark.parametrize( + "kwargs", + [ + {}, + {"use_colors": False}, + {"tag_unsafe": "", "color_unsafe": "blue"}, + ], + ) + def test_custom_print_config_passed_to_sink(self, simple, kwargs): + # check that arguments are passed to sink + def my_sink(nodes_iter, show, **sink_kwargs): + for key, val in kwargs.items(): + assert sink_kwargs[key] == val + + file = sio.dumps(simple) + sio.visualize_tree(file, sink=my_sink, **kwargs) + + def test_custom_tags(self, simple, capsys): + class UnsafeType: + pass + + simple.copy = UnsafeType + + file = sio.dumps(simple) + sio.visualize_tree(file, tag_safe="NICE", tag_unsafe="OHNO") + expected = [ + "root: sklearn.preprocessing._data.MinMaxScaler NICE", + "└── attrs: builtins.dict NICE", + " ├── feature_range: builtins.tuple NICE", + " │ ├── content: json-type(-555) NICE", + " │ └── content: json-type(123) NICE", + " ├── copy: test_visualize.UnsafeType OHNO", + " ├── clip: json-type(false) NICE", + ' └── _sklearn_version: json-type("{}") NICE'.format( + sklearn.__version__ + ), + ] + + stdout, _ = capsys.readouterr() + assert stdout.strip() == "\n".join(expected) + + def test_custom_colors(self, simple): + # Colors are not recorded by capsys, so we cannot use it + class UnsafeType: + pass - def test_from_file(self): - pass + simple.copy = UnsafeType - def test_rich_not_installed(self): - pass + file = sio.dumps(simple) + mock_print = Mock() + with patch("rich.print", mock_print): + sio.visualize_tree( + file, + color_safe="black", + color_unsafe="cyan", + color_child_unsafe="orange3", + ) + + assert mock_print.call_count == 1 + + tree = mock_print.call_args_list[0].args[0] + # The root node is indirectly unsafe through child + assert "[orange3]" in tree.label + # feature_range is safe + assert "[black]" in tree.children[0].children[0].label + # copy is unsafe + assert "[cyan]" in tree.children[0].children[1].label + + def test_from_file(self, simple, tmp_path, capsys): + f_name = tmp_path / "estimator.skops" + sio.dump(simple, f_name) + sio.visualize_tree(f_name) + + expected = [ + "root: sklearn.preprocessing._data.MinMaxScaler", + "└── attrs: builtins.dict", + " ├── feature_range: builtins.tuple", + " │ ├── content: json-type(-555)", + " │ └── content: json-type(123)", + " ├── copy: json-type(true)", + " ├── clip: json-type(false)", + ' └── _sklearn_version: json-type("{}")'.format(sklearn.__version__), + ] + stdout, _ = capsys.readouterr() + assert stdout.strip() == "\n".join(expected) From 1853dcfc848c94122f58edb52894d1fcc9db3276 Mon Sep 17 00:00:00 2001 From: Benjamin Bossan Date: Wed, 22 Mar 2023 14:38:11 +0100 Subject: [PATCH 12/12] Address reviewer comments + some more - rename visualize_tree => visualize - more explanatory comments - when encountering 'key_types', check type and safety - possibility to visualize without rich - make rich an extra install - more documentation --- docs/changes.rst | 2 +- docs/persistence.rst | 42 +++++++-- setup.py | 1 + skops/_min_dependencies.py | 7 +- skops/conftest.py | 13 +++ skops/io/__init__.py | 4 +- skops/io/_visualize.py | 150 ++++++++++++++++++------------- skops/io/tests/test_visualize.py | 89 ++++++++++++++---- 8 files changed, 220 insertions(+), 88 deletions(-) diff --git a/docs/changes.rst b/docs/changes.rst index 80e9b409..cf85ae05 100644 --- a/docs/changes.rst +++ b/docs/changes.rst @@ -21,7 +21,7 @@ v0.6 - Fix: skops persistence now also works with many functions from the :mod:`operator` module. :pr:`287` by `Benjamin Bossan`_. - Add possibility to visualize a skops object and show untrusted types by using - :func:`skops.io.visualize_tree`. Requires to install `rich`: `pip install + :func:`skops.io.visualize`. For colored output, install `rich`: `pip install rich`. :pr:`317` by `Benjamin Bossan`_. v0.5 diff --git a/docs/persistence.rst b/docs/persistence.rst index 7cedcc24..266d171b 100644 --- a/docs/persistence.rst +++ b/docs/persistence.rst @@ -137,13 +137,13 @@ Further help for the different supported options can be found by calling Visualization ############# -Skops files can be visualized using :func:`skops.io.visualize_tree`. If you have +Skops files can be visualized using :func:`skops.io.visualize`. If you have a skops file called ``my-model.skops``, you can visualize it like this: .. code:: python import skops.io as sio - sio.visualize_tree("my-model.skops") + sio.visualize("my-model.skops") The output could look like this: @@ -158,13 +158,45 @@ The output could look like this: ├── clip: json-type(false) └── _sklearn_version: json-type("1.2.0") -``unsafe_lib.UnsafeType`` was recognized as untrusted and marked. There are -various options, like colorizing nodes that are untrusted. +``unsafe_lib.UnsafeType`` was recognized as untrusted and marked. + +It's also possible to visualize the object dumped as bytes: + + import skops.io as sio + my_model = ... + sio.visualize(sio.dumps(my_model)) + +There are various options to customize the output. By default, the security of +nodes is color coded if `rich `_ is +installed, otherwise they all have the same color. To install ``rich``, run: + +.. code:: + + python -m pip install rich + +or, when installing skops, install it like this: + + python -m pip install skops[rich] + +To disable colors, even if ``rich`` is installed, pass ``use_colors=False`` to +:func:`skops.io.visualize`. + +It's also possible to change what colors are being used, e.g. by passing +``visualize(..., color_safe="cyan")`` to change the color for trusted nodes from +green to cyan. The ``rich`` docs list the `supported standard colors +`_. + +Note that the visualization feature is intended to help understand the structure +of the object, e.g. what attributes are identified as untrusted. It is not a +replacement for a proper security check. In particular, just because an object's +visualization looks innocent does *not* mean you can just call `sio.load(, +trusted=True)` on this object -- only pass the types you really trust to the +``trusted`` argument. Supported libraries ------------------- -Skops intends to support all of ````scikit-learn**, that is, not only its +Skops intends to support all of **scikit-learn**, that is, not only its estimators, but also other classes like cross validation splitters. Furthermore, most types from **numpy** and **scipy** should be supported, such as (sparse) arrays, dtypes, random generators, and ufuncs. diff --git a/setup.py b/setup.py index 658861c8..3cc6aa6a 100644 --- a/setup.py +++ b/setup.py @@ -77,6 +77,7 @@ def setup_package(): extras_require={ "docs": min_deps.tag_to_packages["docs"], "tests": min_deps.tag_to_packages["tests"], + "rich": min_deps.tag_to_packages["rich"], }, include_package_data=True, ) diff --git a/skops/_min_dependencies.py b/skops/_min_dependencies.py index 06b65524..d66b554f 100644 --- a/skops/_min_dependencies.py +++ b/skops/_min_dependencies.py @@ -6,7 +6,8 @@ # 'build' and 'install' is included to have structured metadata for CI. # It will NOT be included in setup's extras_require # The values are (version_spec, comma separated tags, condition) -# tags can be: 'build', 'install', 'docs', 'examples', 'tests', 'benchmark' +# tags can be: 'build', 'install', 'docs', 'examples', 'tests', 'benchmark', +# 'rich' # example: # "tomli": ("1.1.0", "install", "python_full_version < '3.11.0a7'"), dependent_packages = { @@ -34,14 +35,14 @@ # TODO: remove condition when catboost supports python 3.11 "catboost": ("1.0", "tests", "python_version < '3.11'"), "fairlearn": ("0.7.0", "docs, tests", None), - "rich": ("12", "tests", None), + "rich": ("12", "tests, rich", None), } # create inverse mapping for setuptools tag_to_packages: dict = { extra: [] - for extra in ["build", "install", "docs", "examples", "tests", "benchmark"] + for extra in ["build", "install", "docs", "examples", "tests", "benchmark", "rich"] } for package, (min_version, extras, condition) in dependent_packages.items(): for extra in extras.split(", "): diff --git a/skops/conftest.py b/skops/conftest.py index 9ee0a4db..93430f83 100644 --- a/skops/conftest.py +++ b/skops/conftest.py @@ -43,3 +43,16 @@ def mock_import(name, *args, **kwargs): yield import matplotlib # noqa + + +@pytest.fixture +def rich_not_installed(): + orig_import = builtins.__import__ + + def mock_import(name, *args, **kwargs): + if name == "rich": + raise ImportError + return orig_import(name, *args, **kwargs) + + with patch("builtins.__import__", side_effect=mock_import): + yield diff --git a/skops/io/__init__.py b/skops/io/__init__.py index d381d49d..c60c99c6 100644 --- a/skops/io/__init__.py +++ b/skops/io/__init__.py @@ -1,4 +1,4 @@ from ._persist import dump, dumps, get_untrusted_types, load, loads -from ._visualize import visualize_tree +from ._visualize import visualize -__all__ = ["dumps", "load", "loads", "dump", "get_untrusted_types", "visualize_tree"] +__all__ = ["dumps", "load", "loads", "dump", "get_untrusted_types", "visualize"] diff --git a/skops/io/_visualize.py b/skops/io/_visualize.py index ca1b9da8..953345e7 100644 --- a/skops/io/_visualize.py +++ b/skops/io/_visualize.py @@ -7,9 +7,8 @@ from typing import Any, Callable, Iterator, Literal from zipfile import ZipFile -from ..utils.importutils import import_or_raise from ._audit import Node, get_tree -from ._general import FunctionNode, JsonNode +from ._general import FunctionNode, JsonNode, ListNode from ._numpy import NdArrayNode from ._scipy import SparseMatrixNode from ._utils import LoadContext @@ -17,11 +16,25 @@ @dataclass class NodeInfo: + """Information pertinent for visualizatoin, extracted from ``Node``s. + + This class contains all information necessary for visualizing nodes. This + way, we can have separate functions for: + + - visiting nodes and determining their safety + - visualizing the nodes + + The visualization function will only receive the ``NodeInfo`` and does not + have to concern itself with how to discover children or determine safety. + + """ + level: int key: str # the key to the node val: str # the value of the node is_self_safe: bool # whether this specific node is safe is_safe: bool # whether this node and all of its children are safe + is_last: bool # whether this is the last child of parent node def _check_visibility( @@ -61,7 +74,13 @@ def _get_node_label( """ # note: when changing the arguments to this function, also update the - # docstring of visualize_tree! + # docstring of visualize! + + if use_colors: + try: + import rich # noqa + except ImportError: + use_colors = False # add tag if necessary node_val = node.val @@ -69,7 +88,7 @@ def _get_node_label( if tag: node_val += f" {tag}" - # colorize if so desired + # colorize if so desired and if rich is installed if use_colors: if node.is_safe: color = color_safe @@ -88,29 +107,31 @@ def pretty_print_tree( **kwargs, ) -> None: # This function loops through the flattened nodes of the tree and creates a - # rich Tree based on the node information. Rich can then create a pretty - # visualization of said tree. - rich = import_or_raise("rich", "pretty printing the object") - from rich.tree import Tree + # tree visualization based on the node information. If rich is installed, + # nodes can be colored. - nodes = list(nodes_iter) - if not nodes: # empty tree, hmm - return + print_ = print + try: + import rich + + # use rich for printing if available + print_ = rich.print # type: ignore + except ImportError: + pass - # start with root node, it is always visible - node = nodes.pop(0) - node_label = _get_node_label(node, **kwargs) - cur_level = node.level # should be 0 - root = tree = Tree(f"{node.key}: {node_label}") - trace = [tree] # trace keeps track of what is the current node to add to + # start with root node + node = next(nodes_iter) + label = _get_node_label(node, **kwargs) + print_(f"{node.key}: {label}") + prev_level = node.level # should be 0 + prefix = "" - while nodes: - node = nodes.pop(0) + for node in nodes_iter: visible = _check_visibility(node.is_self_safe, node.is_safe, show=show) if not visible: continue - level_diff = cur_level - node.level + level_diff = prev_level - node.level if level_diff < -1: # this would mean it is a "(great-)grandchild" node raise ValueError( @@ -121,39 +142,44 @@ def pretty_print_tree( # Level diff of -1 means that this node is a child of the previous node. # E.g. if the current level if 4 and the previous level was 3, the - # current node is a child node of the previous one. Since the previous - # node is already the last node in the trace, there is nothing more that - # needs to be done. Therefore, for a diff of -1, we don't pop from the - # trace. + # current node is a child node of the previous one. Since the prefix for + # a child node was already added, there is nothing more left to do. + for _ in range(level_diff + 1): - # If the level diff is greater than -1, it means that the current - # node is not the child of the last node, but of a node higher up. - # E.g. if the current level is 2 and previous level was 3, it means - # that we should move up 2 layers of nesting, therefore, we pop - # 3-2+1 = 2 levels. - trace.pop(-1) + # This loop is entered if the current node is at the same level as, + # or higher than, the previous node. This means the prefix has to be + # truncated according to the level difference. E.g. if the current + # level is 2 and previous level was 3, it means that we should move + # up 2 layers of nesting, therefore, we trunce 3-2+1 = 2 times. + prefix = prefix[:-4] + + print_(prefix, end="") + if node.is_last: + print_("└──", end="") + prefix += " " + else: + print_("├──", end="") + prefix += "│ " - # add tag if necessary - node_label = _get_node_label(node, **kwargs) - text = f"{node.key}: {node_label}" - tree = trace[-1] - trace.append(tree.add(text)) - cur_level = node.level + label = _get_node_label(node, **kwargs) + print_(f" {node.key}: {label}") - rich.print(root) + prev_level = node.level def walk_tree( node: Node | dict[str, Node] | list[Node], node_name: str = "root", level: int = 0, + is_last: bool = False, ) -> Iterator[NodeInfo]: """Visit all nodes of the tree and yield their important attributes. This function visits all nodes of the object tree and determines: - level: how nested the node is - - key: the key of the node, e.g. the key of a dict. + - key: the key of the node. E.g. if the node is an attribute of an object, + the key would be the name of the attribute. - val: the value of the node, e.g. builtins.list - safety: whether it, and its children, are trusted @@ -176,6 +202,9 @@ def walk_tree( level: int (default=0) The current level of nesting. + is_last: bool (default=False) + Whether this is the last node among its sibling nodes. + Yields ------ :class:`~NodeInfo`: @@ -185,34 +214,40 @@ def walk_tree( # key_types is not helpful, as it is artificially added by skops to # circumvent the fact that json only allows keys to be strings. It is not # useful to the user and adds a lot of noise, thus skip key_types. - # TODO: check that no funny business is going on in key types if node_name == "key_types": - return + if isinstance(node, ListNode) and node.is_safe(): + return + raise ValueError( + "An invalid 'key_types' node was encountered, please report the issue " + "here: https://github.com/skops-dev/skops/issues" + ) - # COMPOSITE TYPES: CHECK ALL ITEMS if isinstance(node, dict): - for key, val in node.items(): + num_nodes = len(node) + for i, (key, val) in enumerate(node.items(), start=1): yield from walk_tree( val, node_name=key, level=level, + is_last=i == num_nodes, ) return if isinstance(node, (list, tuple)): - # shouldn't be tuple, but check just to be sure - for val in node: + num_nodes = len(node) + for i, val in enumerate(node, start=1): yield from walk_tree( val, node_name=node_name, level=level, + is_last=i == num_nodes, ) return # NO MATCH: RAISE ERROR if not isinstance(node, Node): raise TypeError( - f"Cannot deal with {type(node)}, please report the issue here " + f"Cannot deal with {type(node)}, please report the issue here: " "https://github.com/skops-dev/skops/issues" ) @@ -227,24 +262,16 @@ def walk_tree( val=node.format(), is_self_safe=node.is_self_safe(), is_safe=node.is_safe(), + is_last=is_last, ) # TYPES WHOSE CHILDREN IT MAKES NO SENSE TO VISIT # TODO: For better security, we should check the schema if we return early, - # otherwise something nefarious could be hidden inside. - if isinstance(node, (NdArrayNode, SparseMatrixNode)) and (node.type != "json"): + # otherwise something nefarious could be hidden inside (however, if there + # is, the node should be marked as unsafe) + if isinstance(node, (NdArrayNode, SparseMatrixNode, FunctionNode, JsonNode)): return - if isinstance(node, (NdArrayNode, SparseMatrixNode)) and (node.type == "json"): - return - - if isinstance(node, FunctionNode): - return - - if isinstance(node, JsonNode): - pass - - # RECURSE yield from walk_tree( node.children, node_name=node_name, @@ -252,7 +279,7 @@ def walk_tree( ) -def visualize_tree( +def visualize( file: Path | str | bytes, show: Literal["all", "untrusted", "trusted"] = "all", sink: Callable[..., None] = pretty_print_tree, @@ -292,7 +319,7 @@ def visualize_tree( - the safety of the node and its children The ``show`` argument is explained above. Any additional ``kwargs`` - passed to ``visualize_tree`` will also be passed to ``sink``. + passed to ``visualize`` will also be passed to ``sink``. The default sink is :func:`~pretty_print_tree`, which takes these additional parameters: @@ -301,14 +328,15 @@ def visualize_tree( tag) - tag_unsafe: The tag used to mark untrusted nodes (default="[UNSAFE]") - - use_colors: Whether to colorize the nodes (default=True) + - use_colors: Whether to colorize the nodes (default=True). Colors + requires the ``rich`` package to be installed. - color_safe: Color to use for trusted nodes (default="green") - color_unsafe: Color to use for untrusted nodes (default="red") - color_child_unsafe: Color to use for nodes that are trusted but that have untrusted child ndoes (default="yellow") So if you don't want to have colored output, just pass - ``use_colors=False`` to ``visualize_tree``. The colors themselves, such + ``use_colors=False`` to ``visualize``. The colors themselves, such as "red" and "green", refer to the standard colors used by ``rich``. """ diff --git a/skops/io/tests/test_visualize.py b/skops/io/tests/test_visualize.py index 9faf44e5..62aa51b0 100644 --- a/skops/io/tests/test_visualize.py +++ b/skops/io/tests/test_visualize.py @@ -47,7 +47,7 @@ def unsafe_function(x): @pytest.mark.parametrize("show", ["all", "trusted", "untrusted"]) def test_print_simple(self, simple, show, capsys): file = sio.dumps(simple) - sio.visualize_tree(file, show=show) + sio.visualize(file, show=show) # Output is the same for "all" and "trusted" because all nodes are # trusted. Colors are not recorded by capsys. @@ -70,7 +70,7 @@ def test_print_simple(self, simple, show, capsys): def test_print_pipeline(self, pipeline, capsys): file = sio.dumps(pipeline) - sio.visualize_tree(file) + sio.visualize(file) # no point in checking the whole output with > 120 lines expected_start = [ @@ -97,7 +97,7 @@ def test_unsafe_nodes(self, pipeline): def sink(nodes_iter, *args, **kwargs): nodes.extend(nodes_iter) - sio.visualize_tree(file, sink=sink) + sio.visualize(file, sink=sink) nodes_self_unsafe = [node for node in nodes if not node.is_self_safe] nodes_unsafe = [node for node in nodes if not node.is_safe] @@ -129,7 +129,7 @@ def my_sink(nodes_iter, show, **sink_kwargs): assert sink_kwargs[key] == val file = sio.dumps(simple) - sio.visualize_tree(file, sink=my_sink, **kwargs) + sio.visualize(file, sink=my_sink, **kwargs) def test_custom_tags(self, simple, capsys): class UnsafeType: @@ -138,7 +138,7 @@ class UnsafeType: simple.copy = UnsafeType file = sio.dumps(simple) - sio.visualize_tree(file, tag_safe="NICE", tag_unsafe="OHNO") + sio.visualize(file, tag_safe="NICE", tag_unsafe="OHNO") expected = [ "root: sklearn.preprocessing._data.MinMaxScaler NICE", "└── attrs: builtins.dict NICE", @@ -156,36 +156,93 @@ class UnsafeType: assert stdout.strip() == "\n".join(expected) def test_custom_colors(self, simple): - # Colors are not recorded by capsys, so we cannot use it + # test that custom colors are used in node representation, requires rich + # to work + pytest.importorskip("rich") + class UnsafeType: pass simple.copy = UnsafeType - file = sio.dumps(simple) + + # Colors are not recorded by capsys, so we cannot use it and must mock + # printing mock_print = Mock() with patch("rich.print", mock_print): - sio.visualize_tree( + sio.visualize( file, color_safe="black", color_unsafe="cyan", color_child_unsafe="orange3", ) - assert mock_print.call_count == 1 + mock_print.assert_called() - tree = mock_print.call_args_list[0].args[0] + calls = mock_print.call_args_list # The root node is indirectly unsafe through child - assert "[orange3]" in tree.label - # feature_range is safe - assert "[black]" in tree.children[0].children[0].label - # copy is unsafe - assert "[cyan]" in tree.children[0].children[1].label + assert ( + calls[0].args[0] + == "root: [orange3]sklearn.preprocessing._data.MinMaxScaler" + ) + # 'feature_range' is safe + assert calls[6].args[0] == " feature_range: [black]builtins.tuple" + # 'copy' is unsafe + assert calls[15].args[0] == " copy: [cyan]test_visualize.UnsafeType [UNSAFE]" + + @pytest.mark.usefixtures("rich_not_installed") + def test_no_colors_if_rich_not_installed(self, simple): + # this test is similar to the previous one, except that we test that the + # colors are *not* used if rich is not installed + file = sio.dumps(simple) + + # don't use capsys, because it wouldn't capture the colors, thus need to + # use mock + mock_print = Mock() + with patch("builtins.print", mock_print): + sio.visualize( + file, + color_safe="black", + color_unsafe="cyan", + color_child_unsafe="orange3", + ) + mock_print.assert_called() + + # check that none of the colors are being used + colors = ("black", "cyan", "orange3") + for call in mock_print.call_args_list: + for color in colors: + assert color not in call.args[0] + + def test_no_colors_if_use_colors_false(self, simple): + # this test is similar to the previous one, except that we test that the + # colors are *not* used, even if rich is installed, when passing + # use_colors=False + file = sio.dumps(simple) + + # don't use capsys, because it wouldn't capture the colors, thus need to + # use mock + mock_print = Mock() + with patch("rich.print", mock_print): + sio.visualize( + file, + color_safe="black", + color_unsafe="cyan", + color_child_unsafe="orange3", + use_colors=False, + ) + mock_print.assert_called() + + # check that none of the colors are being used + colors = ("black", "cyan", "orange3") + for call in mock_print.call_args_list: + for color in colors: + assert color not in call.args[0] def test_from_file(self, simple, tmp_path, capsys): f_name = tmp_path / "estimator.skops" sio.dump(simple, f_name) - sio.visualize_tree(f_name) + sio.visualize(f_name) expected = [ "root: sklearn.preprocessing._data.MinMaxScaler",