Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion skops/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def rich_not_installed():
orig_import = builtins.__import__

def mock_import(name, *args, **kwargs):
if name == "rich":
if name == "rich" or name.startswith("rich."):
raise ImportError
return orig_import(name, *args, **kwargs)

Expand Down
128 changes: 74 additions & 54 deletions skops/io/_visualize.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,9 +74,9 @@ def _get_node_label(
tag_unsafe: str = "[UNSAFE]",
use_colors: bool = True,
# use rich for coloring
color_safe: str = "green",
color_unsafe: str = "red",
color_child_unsafe: str = "yellow",
color_safe: str = "cyan",
color_unsafe: str = "orange1",
color_child_unsafe: str = "magenta",
):
"""Determine the label of a node.

Expand All @@ -101,40 +101,24 @@ def _get_node_label(
# colorize if so desired and if rich is installed
if use_colors:
if node.is_safe:
color = color_safe
style = f"{color_safe}"
elif node.is_self_safe:
color = color_child_unsafe
style = f"{color_child_unsafe}"
else:
color = color_unsafe
node_val = f"[{color}]{node_val}"
style = f"{color_unsafe}"
node_val = f"[{style}]{node_val}[/{style}]"

return node_val


def pretty_print_tree(
nodes_iter: Iterator[NodeInfo],
show: Literal["all", "untrusted", "trusted"],
**kwargs,
) -> None:
# This function loops through the flattened nodes of the tree and creates a
# tree visualization based on the node information. If rich is installed,
# nodes can be colored.

print_ = print
try:
import rich

# use rich for printing if available
print_ = rich.print # type: ignore
except ImportError:
pass

def _traverse_tree(nodes_iter, show, **kwargs):
"""Common tree traversal logic used by both rich and fallback display methods."""
# start with root node
node = next(nodes_iter)
label = _get_node_label(node, **kwargs)
print_(f"{node.key}: {label}")
yield node, label, 0, True # node, label, level, is_first_node

prev_level = node.level # should be 0
prefix = ""

for node in nodes_iter:
visible = _check_visibility(node.is_self_safe, node.is_safe, show=show)
Expand All @@ -150,33 +134,69 @@ def pretty_print_tree(
"report the issue here: https://github.com/skops-dev/skops/issues"
)

# Level diff of -1 means that this node is a child of the previous node.
# E.g. if the current level if 4 and the previous level was 3, the
# current node is a child node of the previous one. Since the prefix for
# a child node was already added, there is nothing more left to do.

for _ in range(level_diff + 1):
# This loop is entered if the current node is at the same level as,
# or higher than, the previous node. This means the prefix has to be
# truncated according to the level difference. E.g. if the current
# level is 2 and previous level was 3, it means that we should move
# up 2 layers of nesting, therefore, we trunce 3-2+1 = 2 times.
prefix = prefix[:-4]

print_(prefix, end="")
if node.is_last:
print_("└──", end="")
prefix += " "
else:
print_("├──", end="")
prefix += "│ "

label = _get_node_label(node, **kwargs)
print_(f" {node.key}: {label}")

yield node, label, level_diff, False
prev_level = node.level


def pretty_print_tree(
nodes_iter: Iterator[NodeInfo],
show: Literal["all", "untrusted", "trusted"],
**kwargs: Any,
) -> None:
try:
from rich.console import Console
from rich.tree import Tree

console = Console()

for node, label, level_diff, is_first_node in _traverse_tree(
nodes_iter, show, **kwargs
):
if is_first_node:
tree = Tree(f"{node.key}: {label}", guide_style="gray50")
trees = {0: tree}
continue

parent_level = node.level - 1
parent_tree = trees[parent_level]
current_tree = parent_tree.add(f"{node.key}: {label}")
trees[node.level] = current_tree

console.print(tree)

except ImportError:
prefix = ""
for node, label, level_diff, is_first_node in _traverse_tree(
nodes_iter, show, **kwargs
):
if is_first_node:
print(f"{node.key}: {label}")
continue

# Level diff of -1 means that this node is a child of the previous node.
# E.g. if the current level if 4 and the previous level was 3, the
# current node is a child node of the previous one. Since the prefix for
# a child node was already added, there is nothing more left to do.
for _ in range(level_diff + 1):
# This loop is entered if the current node is at the same level as,
# or higher than, the previous node. This means the prefix has to be
# truncated according to the level difference. E.g. if the current
# level is 2 and previous level was 3, it means that we should move
# up 2 layers of nesting, therefore, we trunce 3-2+1 = 2 times.
prefix = prefix[:-4]

print(prefix, end="")
if node.is_last:
print("└──", end="")
prefix += " "
else:
print("├──", end="")
prefix += "│ "

print(f" {node.key}: {label}")


def walk_tree(
node: VALID_NODE_CHILD_TYPES | dict[str, VALID_NODE_CHILD_TYPES],
node_name: str = "root",
Expand Down Expand Up @@ -342,14 +362,14 @@ def visualize(
(default="[UNSAFE]")
- use_colors: Whether to colorize the nodes (default=True). Colors
requires the ``rich`` package to be installed.
- color_safe: Color to use for trusted nodes (default="green")
- color_unsafe: Color to use for untrusted nodes (default="red")
- color_safe: Color to use for trusted nodes (default="orange1")
- color_unsafe: Color to use for untrusted nodes (default="cyan")
- color_child_unsafe: Color to use for nodes that are trusted but
that have untrusted child ndoes (default="yellow")
that have untrusted child ndoes (default="magenta")

So if you don't want to have colored output, just pass
``use_colors=False`` to ``visualize``. The colors themselves, such
as "red" and "green", refer to the standard colors used by ``rich``.
as "orange1" and "cyan", refer to the standard colors used by ``rich``.

"""
if isinstance(file, bytes):
Expand Down
38 changes: 26 additions & 12 deletions skops/io/tests/test_visualize.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,20 @@
from skops.io import get_untrusted_types


def get_Tree_str(tree):
"""Get the string representation of a tree in Rich's markup syntax."""
from rich.console import Console
from rich.text import Text

# force the color system to check that we have the right colors across
# platforms and terminals
console = Console(force_terminal=True, color_system="truecolor")
with console.capture() as capture:
console.print(tree)
text = Text.from_ansi(capture.get())
return text.markup


class TestVisualizeTree:
@pytest.fixture
def simple(self):
Expand Down Expand Up @@ -177,7 +191,7 @@ class UnsafeType:
# Colors are not recorded by capsys, so we cannot use it and must mock
# printing
mock_print = Mock()
with patch("rich.print", mock_print):
with patch("rich.console.Console.print", mock_print):
sio.visualize(
file,
color_safe="black",
Expand All @@ -188,15 +202,16 @@ class UnsafeType:
mock_print.assert_called()

calls = mock_print.call_args_list
tree_repr = get_Tree_str(calls[0].args[0])
# The root node is indirectly unsafe through child
assert (
calls[0].args[0]
== "root: [orange3]sklearn.preprocessing._data.MinMaxScaler"
)
# orange3 is color(172)
assert "root: [color(172)]sklearn.preprocessing._data.MinMaxScaler" in tree_repr
# 'feature_range' is safe
assert calls[6].args[0] == " feature_range: [black]builtins.tuple"
# black is color(0)
assert "feature_range: [color(0)]builtins.tuple" in tree_repr
# 'copy' is unsafe
assert calls[15].args[0] == " copy: [cyan]test_visualize.UnsafeType [UNSAFE]"
# cyan is color(6)
assert "copy: [color(6)]test_visualize.UnsafeType [UNSAFE]" in tree_repr

@pytest.mark.usefixtures("rich_not_installed")
def test_no_colors_if_rich_not_installed(self, simple):
Expand All @@ -217,10 +232,9 @@ def test_no_colors_if_rich_not_installed(self, simple):
mock_print.assert_called()

# check that none of the colors are being used
colors = ("black", "cyan", "orange3")
for call in mock_print.call_args_list:
for color in colors:
assert color not in call.args[0]
# check for the color markers
assert "[color(" not in call.args[0]

def test_no_colors_if_use_colors_false(self, simple):
# this test is similar to the previous one, except that we test that the
Expand All @@ -231,7 +245,7 @@ def test_no_colors_if_use_colors_false(self, simple):
# don't use capsys, because it wouldn't capture the colors, thus need to
# use mock
mock_print = Mock()
with patch("rich.print", mock_print):
with patch("rich.console.Console.print", mock_print):
sio.visualize(
file,
color_safe="black",
Expand All @@ -245,7 +259,7 @@ def test_no_colors_if_use_colors_false(self, simple):
colors = ("black", "cyan", "orange3")
for call in mock_print.call_args_list:
for color in colors:
assert color not in call.args[0]
assert color not in get_Tree_str(call.args[0])

def test_from_file(self, simple, tmp_path, capsys):
f_name = tmp_path / "estimator.skops"
Expand Down