diff --git a/skops/conftest.py b/skops/conftest.py index 93430f83..27cbe375 100644 --- a/skops/conftest.py +++ b/skops/conftest.py @@ -50,7 +50,7 @@ def rich_not_installed(): orig_import = builtins.__import__ def mock_import(name, *args, **kwargs): - if name == "rich": + if name == "rich" or name.startswith("rich."): raise ImportError return orig_import(name, *args, **kwargs) diff --git a/skops/io/_visualize.py b/skops/io/_visualize.py index 7c4965d3..3c3c2c39 100644 --- a/skops/io/_visualize.py +++ b/skops/io/_visualize.py @@ -74,9 +74,9 @@ def _get_node_label( tag_unsafe: str = "[UNSAFE]", use_colors: bool = True, # use rich for coloring - color_safe: str = "green", - color_unsafe: str = "red", - color_child_unsafe: str = "yellow", + color_safe: str = "cyan", + color_unsafe: str = "orange1", + color_child_unsafe: str = "magenta", ): """Determine the label of a node. @@ -101,40 +101,24 @@ def _get_node_label( # colorize if so desired and if rich is installed if use_colors: if node.is_safe: - color = color_safe + style = f"{color_safe}" elif node.is_self_safe: - color = color_child_unsafe + style = f"{color_child_unsafe}" else: - color = color_unsafe - node_val = f"[{color}]{node_val}" + style = f"{color_unsafe}" + node_val = f"[{style}]{node_val}[/{style}]" return node_val -def pretty_print_tree( - nodes_iter: Iterator[NodeInfo], - show: Literal["all", "untrusted", "trusted"], - **kwargs, -) -> None: - # This function loops through the flattened nodes of the tree and creates a - # tree visualization based on the node information. If rich is installed, - # nodes can be colored. - - print_ = print - try: - import rich - - # use rich for printing if available - print_ = rich.print # type: ignore - except ImportError: - pass - +def _traverse_tree(nodes_iter, show, **kwargs): + """Common tree traversal logic used by both rich and fallback display methods.""" # start with root node node = next(nodes_iter) label = _get_node_label(node, **kwargs) - print_(f"{node.key}: {label}") + yield node, label, 0, True # node, label, level, is_first_node + prev_level = node.level # should be 0 - prefix = "" for node in nodes_iter: visible = _check_visibility(node.is_self_safe, node.is_safe, show=show) @@ -150,33 +134,69 @@ def pretty_print_tree( "report the issue here: https://github.com/skops-dev/skops/issues" ) - # Level diff of -1 means that this node is a child of the previous node. - # E.g. if the current level if 4 and the previous level was 3, the - # current node is a child node of the previous one. Since the prefix for - # a child node was already added, there is nothing more left to do. - - for _ in range(level_diff + 1): - # This loop is entered if the current node is at the same level as, - # or higher than, the previous node. This means the prefix has to be - # truncated according to the level difference. E.g. if the current - # level is 2 and previous level was 3, it means that we should move - # up 2 layers of nesting, therefore, we trunce 3-2+1 = 2 times. - prefix = prefix[:-4] - - print_(prefix, end="") - if node.is_last: - print_("└──", end="") - prefix += " " - else: - print_("├──", end="") - prefix += "│ " - label = _get_node_label(node, **kwargs) - print_(f" {node.key}: {label}") - + yield node, label, level_diff, False prev_level = node.level +def pretty_print_tree( + nodes_iter: Iterator[NodeInfo], + show: Literal["all", "untrusted", "trusted"], + **kwargs: Any, +) -> None: + try: + from rich.console import Console + from rich.tree import Tree + + console = Console() + + for node, label, level_diff, is_first_node in _traverse_tree( + nodes_iter, show, **kwargs + ): + if is_first_node: + tree = Tree(f"{node.key}: {label}", guide_style="gray50") + trees = {0: tree} + continue + + parent_level = node.level - 1 + parent_tree = trees[parent_level] + current_tree = parent_tree.add(f"{node.key}: {label}") + trees[node.level] = current_tree + + console.print(tree) + + except ImportError: + prefix = "" + for node, label, level_diff, is_first_node in _traverse_tree( + nodes_iter, show, **kwargs + ): + if is_first_node: + print(f"{node.key}: {label}") + continue + + # Level diff of -1 means that this node is a child of the previous node. + # E.g. if the current level if 4 and the previous level was 3, the + # current node is a child node of the previous one. Since the prefix for + # a child node was already added, there is nothing more left to do. + for _ in range(level_diff + 1): + # This loop is entered if the current node is at the same level as, + # or higher than, the previous node. This means the prefix has to be + # truncated according to the level difference. E.g. if the current + # level is 2 and previous level was 3, it means that we should move + # up 2 layers of nesting, therefore, we trunce 3-2+1 = 2 times. + prefix = prefix[:-4] + + print(prefix, end="") + if node.is_last: + print("└──", end="") + prefix += " " + else: + print("├──", end="") + prefix += "│ " + + print(f" {node.key}: {label}") + + def walk_tree( node: VALID_NODE_CHILD_TYPES | dict[str, VALID_NODE_CHILD_TYPES], node_name: str = "root", @@ -342,14 +362,14 @@ def visualize( (default="[UNSAFE]") - use_colors: Whether to colorize the nodes (default=True). Colors requires the ``rich`` package to be installed. - - color_safe: Color to use for trusted nodes (default="green") - - color_unsafe: Color to use for untrusted nodes (default="red") + - color_safe: Color to use for trusted nodes (default="orange1") + - color_unsafe: Color to use for untrusted nodes (default="cyan") - color_child_unsafe: Color to use for nodes that are trusted but - that have untrusted child ndoes (default="yellow") + that have untrusted child ndoes (default="magenta") So if you don't want to have colored output, just pass ``use_colors=False`` to ``visualize``. The colors themselves, such - as "red" and "green", refer to the standard colors used by ``rich``. + as "orange1" and "cyan", refer to the standard colors used by ``rich``. """ if isinstance(file, bytes): diff --git a/skops/io/tests/test_visualize.py b/skops/io/tests/test_visualize.py index 85ee55dd..469b0e0b 100644 --- a/skops/io/tests/test_visualize.py +++ b/skops/io/tests/test_visualize.py @@ -19,6 +19,20 @@ from skops.io import get_untrusted_types +def get_Tree_str(tree): + """Get the string representation of a tree in Rich's markup syntax.""" + from rich.console import Console + from rich.text import Text + + # force the color system to check that we have the right colors across + # platforms and terminals + console = Console(force_terminal=True, color_system="truecolor") + with console.capture() as capture: + console.print(tree) + text = Text.from_ansi(capture.get()) + return text.markup + + class TestVisualizeTree: @pytest.fixture def simple(self): @@ -177,7 +191,7 @@ class UnsafeType: # Colors are not recorded by capsys, so we cannot use it and must mock # printing mock_print = Mock() - with patch("rich.print", mock_print): + with patch("rich.console.Console.print", mock_print): sio.visualize( file, color_safe="black", @@ -188,15 +202,16 @@ class UnsafeType: mock_print.assert_called() calls = mock_print.call_args_list + tree_repr = get_Tree_str(calls[0].args[0]) # The root node is indirectly unsafe through child - assert ( - calls[0].args[0] - == "root: [orange3]sklearn.preprocessing._data.MinMaxScaler" - ) + # orange3 is color(172) + assert "root: [color(172)]sklearn.preprocessing._data.MinMaxScaler" in tree_repr # 'feature_range' is safe - assert calls[6].args[0] == " feature_range: [black]builtins.tuple" + # black is color(0) + assert "feature_range: [color(0)]builtins.tuple" in tree_repr # 'copy' is unsafe - assert calls[15].args[0] == " copy: [cyan]test_visualize.UnsafeType [UNSAFE]" + # cyan is color(6) + assert "copy: [color(6)]test_visualize.UnsafeType [UNSAFE]" in tree_repr @pytest.mark.usefixtures("rich_not_installed") def test_no_colors_if_rich_not_installed(self, simple): @@ -217,10 +232,9 @@ def test_no_colors_if_rich_not_installed(self, simple): mock_print.assert_called() # check that none of the colors are being used - colors = ("black", "cyan", "orange3") for call in mock_print.call_args_list: - for color in colors: - assert color not in call.args[0] + # check for the color markers + assert "[color(" not in call.args[0] def test_no_colors_if_use_colors_false(self, simple): # this test is similar to the previous one, except that we test that the @@ -231,7 +245,7 @@ def test_no_colors_if_use_colors_false(self, simple): # don't use capsys, because it wouldn't capture the colors, thus need to # use mock mock_print = Mock() - with patch("rich.print", mock_print): + with patch("rich.console.Console.print", mock_print): sio.visualize( file, color_safe="black", @@ -245,7 +259,7 @@ def test_no_colors_if_use_colors_false(self, simple): colors = ("black", "cyan", "orange3") for call in mock_print.call_args_list: for color in colors: - assert color not in call.args[0] + assert color not in get_Tree_str(call.args[0]) def test_from_file(self, simple, tmp_path, capsys): f_name = tmp_path / "estimator.skops"