From b2c038c5febabbfeeb536d8664b2aee1a43993f8 Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Wed, 5 Mar 2025 22:03:28 +0100 Subject: [PATCH 1/7] FIX use rich.tree instead of rich.print --- skops/io/_visualize.py | 111 +++++++++++++++++++++-------------------- 1 file changed, 57 insertions(+), 54 deletions(-) diff --git a/skops/io/_visualize.py b/skops/io/_visualize.py index 7c4965d3..2adc0122 100644 --- a/skops/io/_visualize.py +++ b/skops/io/_visualize.py @@ -74,9 +74,9 @@ def _get_node_label( tag_unsafe: str = "[UNSAFE]", use_colors: bool = True, # use rich for coloring - color_safe: str = "green", - color_unsafe: str = "red", - color_child_unsafe: str = "yellow", + color_safe: str = "cyan", + color_unsafe: str = "orange1", + color_child_unsafe: str = "magenta", ): """Determine the label of a node. @@ -101,40 +101,24 @@ def _get_node_label( # colorize if so desired and if rich is installed if use_colors: if node.is_safe: - color = color_safe + style = f"bold {color_safe}" elif node.is_self_safe: - color = color_child_unsafe + style = f"bold {color_child_unsafe}" else: - color = color_unsafe - node_val = f"[{color}]{node_val}" + style = f"bold {color_unsafe}" + node_val = f"[{style}]{node_val}[/{style}]" return node_val -def pretty_print_tree( - nodes_iter: Iterator[NodeInfo], - show: Literal["all", "untrusted", "trusted"], - **kwargs, -) -> None: - # This function loops through the flattened nodes of the tree and creates a - # tree visualization based on the node information. If rich is installed, - # nodes can be colored. - - print_ = print - try: - import rich - - # use rich for printing if available - print_ = rich.print # type: ignore - except ImportError: - pass - +def _traverse_tree(nodes_iter, show, **kwargs): + """Common tree traversal logic used by both rich and fallback display methods.""" # start with root node node = next(nodes_iter) label = _get_node_label(node, **kwargs) - print_(f"{node.key}: {label}") + yield node, label, 0, True # node, label, level, is_first_node + prev_level = node.level # should be 0 - prefix = "" for node in nodes_iter: visible = _check_visibility(node.is_self_safe, node.is_safe, show=show) @@ -150,33 +134,52 @@ def pretty_print_tree( "report the issue here: https://github.com/skops-dev/skops/issues" ) - # Level diff of -1 means that this node is a child of the previous node. - # E.g. if the current level if 4 and the previous level was 3, the - # current node is a child node of the previous one. Since the prefix for - # a child node was already added, there is nothing more left to do. - - for _ in range(level_diff + 1): - # This loop is entered if the current node is at the same level as, - # or higher than, the previous node. This means the prefix has to be - # truncated according to the level difference. E.g. if the current - # level is 2 and previous level was 3, it means that we should move - # up 2 layers of nesting, therefore, we trunce 3-2+1 = 2 times. - prefix = prefix[:-4] - - print_(prefix, end="") - if node.is_last: - print_("└──", end="") - prefix += " " - else: - print_("├──", end="") - prefix += "│ " - label = _get_node_label(node, **kwargs) - print_(f" {node.key}: {label}") - + yield node, label, level_diff, False prev_level = node.level +def pretty_print_tree(nodes_iter, show, **kwargs): + try: + from rich.tree import Tree + from rich.console import Console + console = Console() + + for node, label, level_diff, is_first_node in _traverse_tree(nodes_iter, show): + if is_first_node: + tree = Tree(f"{node.key}: {label}", guide_style="gray50") + trees = {0: tree} + continue + + parent_level = node.level - 1 + parent_tree = trees[parent_level] + current_tree = parent_tree.add(f"{node.key}: {label}") + trees[node.level] = current_tree + + console.print(tree) + + except ImportError: + prefix = "" + for node, label, level_diff, is_first_node in _traverse_tree(nodes_iter, show): + if is_first_node: + print(f"{node.key}: {label}") + continue + + # Update prefix based on level difference + for _ in range(level_diff + 1): + prefix = prefix[:-4] + + print(prefix, end="") + if node.is_last: + print("└──", end="") + prefix += " " + else: + print("├──", end="") + prefix += "│ " + + print(f" {node.key}: {label}") + + def walk_tree( node: VALID_NODE_CHILD_TYPES | dict[str, VALID_NODE_CHILD_TYPES], node_name: str = "root", @@ -342,14 +345,14 @@ def visualize( (default="[UNSAFE]") - use_colors: Whether to colorize the nodes (default=True). Colors requires the ``rich`` package to be installed. - - color_safe: Color to use for trusted nodes (default="green") - - color_unsafe: Color to use for untrusted nodes (default="red") + - color_safe: Color to use for trusted nodes (default="orange1") + - color_unsafe: Color to use for untrusted nodes (default="cyan") - color_child_unsafe: Color to use for nodes that are trusted but - that have untrusted child ndoes (default="yellow") + that have untrusted child ndoes (default="magenta") So if you don't want to have colored output, just pass ``use_colors=False`` to ``visualize``. The colors themselves, such - as "red" and "green", refer to the standard colors used by ``rich``. + as "orange1" and "cyan", refer to the standard colors used by ``rich``. """ if isinstance(file, bytes): From 25c4919d6af3a0c13c298eefd2435ec12b07d500 Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Wed, 5 Mar 2025 22:10:20 +0100 Subject: [PATCH 2/7] iter --- skops/io/_visualize.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/skops/io/_visualize.py b/skops/io/_visualize.py index 2adc0122..a5438c56 100644 --- a/skops/io/_visualize.py +++ b/skops/io/_visualize.py @@ -165,8 +165,12 @@ def pretty_print_tree(nodes_iter, show, **kwargs): print(f"{node.key}: {label}") continue - # Update prefix based on level difference for _ in range(level_diff + 1): + # This loop is entered if the current node is at the same level as, + # or higher than, the previous node. This means the prefix has to be + # truncated according to the level difference. E.g. if the current + # level is 2 and previous level was 3, it means that we should move + # up 2 layers of nesting, therefore, we trunce 3-2+1 = 2 times. prefix = prefix[:-4] print(prefix, end="") From 4f147322d5f1581aee253b6f8c3290eab5081930 Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Wed, 5 Mar 2025 22:10:45 +0100 Subject: [PATCH 3/7] iter --- skops/io/_visualize.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/skops/io/_visualize.py b/skops/io/_visualize.py index a5438c56..f5eedb74 100644 --- a/skops/io/_visualize.py +++ b/skops/io/_visualize.py @@ -165,6 +165,10 @@ def pretty_print_tree(nodes_iter, show, **kwargs): print(f"{node.key}: {label}") continue + # Level diff of -1 means that this node is a child of the previous node. + # E.g. if the current level if 4 and the previous level was 3, the + # current node is a child node of the previous one. Since the prefix for + # a child node was already added, there is nothing more left to do. for _ in range(level_diff + 1): # This loop is entered if the current node is at the same level as, # or higher than, the previous node. This means the prefix has to be From 718e282cb5b3cf95afc91943f643036032830975 Mon Sep 17 00:00:00 2001 From: adrinjalali Date: Thu, 6 Mar 2025 08:00:11 +0100 Subject: [PATCH 4/7] linter --- skops/io/_visualize.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/skops/io/_visualize.py b/skops/io/_visualize.py index f5eedb74..cd1321af 100644 --- a/skops/io/_visualize.py +++ b/skops/io/_visualize.py @@ -141,8 +141,9 @@ def _traverse_tree(nodes_iter, show, **kwargs): def pretty_print_tree(nodes_iter, show, **kwargs): try: - from rich.tree import Tree from rich.console import Console + from rich.tree import Tree + console = Console() for node, label, level_diff, is_first_node in _traverse_tree(nodes_iter, show): From 0c3a110c03ed8254bd2d46e1033978fcdd08c17d Mon Sep 17 00:00:00 2001 From: adrinjalali Date: Thu, 6 Mar 2025 09:15:16 +0100 Subject: [PATCH 5/7] few fixes --- skops/conftest.py | 2 +- skops/io/tests/test_visualize.py | 28 +++++++++++++++++++--------- 2 files changed, 20 insertions(+), 10 deletions(-) diff --git a/skops/conftest.py b/skops/conftest.py index 93430f83..27cbe375 100644 --- a/skops/conftest.py +++ b/skops/conftest.py @@ -50,7 +50,7 @@ def rich_not_installed(): orig_import = builtins.__import__ def mock_import(name, *args, **kwargs): - if name == "rich": + if name == "rich" or name.startswith("rich."): raise ImportError return orig_import(name, *args, **kwargs) diff --git a/skops/io/tests/test_visualize.py b/skops/io/tests/test_visualize.py index 85ee55dd..7e7918f7 100644 --- a/skops/io/tests/test_visualize.py +++ b/skops/io/tests/test_visualize.py @@ -19,6 +19,18 @@ from skops.io import get_untrusted_types +def get_Tree_str(tree): + """Get the string representation of a tree.""" + from io import StringIO + + from rich.console import Console + + with StringIO() as output: + console = Console(file=output) + console.print(tree) + return output.getvalue() + + class TestVisualizeTree: @pytest.fixture def simple(self): @@ -177,7 +189,7 @@ class UnsafeType: # Colors are not recorded by capsys, so we cannot use it and must mock # printing mock_print = Mock() - with patch("rich.print", mock_print): + with patch("rich.console.Console.print", mock_print): sio.visualize( file, color_safe="black", @@ -188,15 +200,13 @@ class UnsafeType: mock_print.assert_called() calls = mock_print.call_args_list + tree_repr = get_Tree_str(calls[0].args[0]) # The root node is indirectly unsafe through child - assert ( - calls[0].args[0] - == "root: [orange3]sklearn.preprocessing._data.MinMaxScaler" - ) + assert "root: [orange3]sklearn.preprocessing._data.MinMaxScaler" in tree_repr # 'feature_range' is safe - assert calls[6].args[0] == " feature_range: [black]builtins.tuple" + assert " feature_range: [black]builtins.tuple" in tree_repr # 'copy' is unsafe - assert calls[15].args[0] == " copy: [cyan]test_visualize.UnsafeType [UNSAFE]" + assert " copy: [cyan]test_visualize.UnsafeType [UNSAFE]" in tree_repr @pytest.mark.usefixtures("rich_not_installed") def test_no_colors_if_rich_not_installed(self, simple): @@ -231,7 +241,7 @@ def test_no_colors_if_use_colors_false(self, simple): # don't use capsys, because it wouldn't capture the colors, thus need to # use mock mock_print = Mock() - with patch("rich.print", mock_print): + with patch("rich.console.Console.print", mock_print): sio.visualize( file, color_safe="black", @@ -245,7 +255,7 @@ def test_no_colors_if_use_colors_false(self, simple): colors = ("black", "cyan", "orange3") for call in mock_print.call_args_list: for color in colors: - assert color not in call.args[0] + assert color not in get_Tree_str(call.args[0]) def test_from_file(self, simple, tmp_path, capsys): f_name = tmp_path / "estimator.skops" From 8e20c06acf58fc4930f9083dc407aa9625772945 Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Fri, 7 Mar 2025 19:43:55 +0100 Subject: [PATCH 6/7] fix tags and adapt test color rich --- skops/io/_visualize.py | 14 +++++++++----- skops/io/tests/test_visualize.py | 26 ++++++++++++++------------ 2 files changed, 23 insertions(+), 17 deletions(-) diff --git a/skops/io/_visualize.py b/skops/io/_visualize.py index cd1321af..afd606c0 100644 --- a/skops/io/_visualize.py +++ b/skops/io/_visualize.py @@ -101,11 +101,11 @@ def _get_node_label( # colorize if so desired and if rich is installed if use_colors: if node.is_safe: - style = f"bold {color_safe}" + style = f"{color_safe}" elif node.is_self_safe: - style = f"bold {color_child_unsafe}" + style = f"{color_child_unsafe}" else: - style = f"bold {color_unsafe}" + style = f"{color_unsafe}" node_val = f"[{style}]{node_val}[/{style}]" return node_val @@ -146,7 +146,9 @@ def pretty_print_tree(nodes_iter, show, **kwargs): console = Console() - for node, label, level_diff, is_first_node in _traverse_tree(nodes_iter, show): + for node, label, level_diff, is_first_node in _traverse_tree( + nodes_iter, show, **kwargs + ): if is_first_node: tree = Tree(f"{node.key}: {label}", guide_style="gray50") trees = {0: tree} @@ -161,7 +163,9 @@ def pretty_print_tree(nodes_iter, show, **kwargs): except ImportError: prefix = "" - for node, label, level_diff, is_first_node in _traverse_tree(nodes_iter, show): + for node, label, level_diff, is_first_node in _traverse_tree( + nodes_iter, show, **kwargs + ): if is_first_node: print(f"{node.key}: {label}") continue diff --git a/skops/io/tests/test_visualize.py b/skops/io/tests/test_visualize.py index 7e7918f7..a88b4bb7 100644 --- a/skops/io/tests/test_visualize.py +++ b/skops/io/tests/test_visualize.py @@ -20,15 +20,15 @@ def get_Tree_str(tree): - """Get the string representation of a tree.""" - from io import StringIO - + """Get the string representation of a tree in Rich's markup syntax.""" from rich.console import Console + from rich.text import Text - with StringIO() as output: - console = Console(file=output) + console = Console() + with console.capture() as capture: console.print(tree) - return output.getvalue() + text = Text.from_ansi(capture.get()) + return text.markup class TestVisualizeTree: @@ -202,11 +202,14 @@ class UnsafeType: calls = mock_print.call_args_list tree_repr = get_Tree_str(calls[0].args[0]) # The root node is indirectly unsafe through child - assert "root: [orange3]sklearn.preprocessing._data.MinMaxScaler" in tree_repr + # orange3 is color(5) + assert "root: [color(5)]sklearn.preprocessing._data.MinMaxScaler" in tree_repr # 'feature_range' is safe - assert " feature_range: [black]builtins.tuple" in tree_repr + # black is color(6) + assert "feature_range: [color(6)]builtins.tuple" in tree_repr # 'copy' is unsafe - assert " copy: [cyan]test_visualize.UnsafeType [UNSAFE]" in tree_repr + # cyan is color(214) + assert "copy: [color(214)]test_visualize.UnsafeType [UNSAFE]" in tree_repr @pytest.mark.usefixtures("rich_not_installed") def test_no_colors_if_rich_not_installed(self, simple): @@ -227,10 +230,9 @@ def test_no_colors_if_rich_not_installed(self, simple): mock_print.assert_called() # check that none of the colors are being used - colors = ("black", "cyan", "orange3") for call in mock_print.call_args_list: - for color in colors: - assert color not in call.args[0] + # check for the color markers + assert "[color(" not in call.args[0] def test_no_colors_if_use_colors_false(self, simple): # this test is similar to the previous one, except that we test that the From d01e876e1701c33209b04c938ce076a7404fcaa3 Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Sat, 8 Mar 2025 11:54:52 +0100 Subject: [PATCH 7/7] fix --- skops/io/_visualize.py | 6 +++++- skops/io/tests/test_visualize.py | 16 +++++++++------- 2 files changed, 14 insertions(+), 8 deletions(-) diff --git a/skops/io/_visualize.py b/skops/io/_visualize.py index afd606c0..3c3c2c39 100644 --- a/skops/io/_visualize.py +++ b/skops/io/_visualize.py @@ -139,7 +139,11 @@ def _traverse_tree(nodes_iter, show, **kwargs): prev_level = node.level -def pretty_print_tree(nodes_iter, show, **kwargs): +def pretty_print_tree( + nodes_iter: Iterator[NodeInfo], + show: Literal["all", "untrusted", "trusted"], + **kwargs: Any, +) -> None: try: from rich.console import Console from rich.tree import Tree diff --git a/skops/io/tests/test_visualize.py b/skops/io/tests/test_visualize.py index a88b4bb7..469b0e0b 100644 --- a/skops/io/tests/test_visualize.py +++ b/skops/io/tests/test_visualize.py @@ -24,7 +24,9 @@ def get_Tree_str(tree): from rich.console import Console from rich.text import Text - console = Console() + # force the color system to check that we have the right colors across + # platforms and terminals + console = Console(force_terminal=True, color_system="truecolor") with console.capture() as capture: console.print(tree) text = Text.from_ansi(capture.get()) @@ -202,14 +204,14 @@ class UnsafeType: calls = mock_print.call_args_list tree_repr = get_Tree_str(calls[0].args[0]) # The root node is indirectly unsafe through child - # orange3 is color(5) - assert "root: [color(5)]sklearn.preprocessing._data.MinMaxScaler" in tree_repr + # orange3 is color(172) + assert "root: [color(172)]sklearn.preprocessing._data.MinMaxScaler" in tree_repr # 'feature_range' is safe - # black is color(6) - assert "feature_range: [color(6)]builtins.tuple" in tree_repr + # black is color(0) + assert "feature_range: [color(0)]builtins.tuple" in tree_repr # 'copy' is unsafe - # cyan is color(214) - assert "copy: [color(214)]test_visualize.UnsafeType [UNSAFE]" in tree_repr + # cyan is color(6) + assert "copy: [color(6)]test_visualize.UnsafeType [UNSAFE]" in tree_repr @pytest.mark.usefixtures("rich_not_installed") def test_no_colors_if_rich_not_installed(self, simple):