diff --git a/skops/io/_audit.py b/skops/io/_audit.py index d2426473..5ae9fc28 100644 --- a/skops/io/_audit.py +++ b/skops/io/_audit.py @@ -8,7 +8,7 @@ from ._utils import LoadContext, get_module, get_type_paths from .exceptions import UntrustedTypesFoundException -NODE_TYPE_MAPPING: dict[tuple[str, int], Node] = {} +NODE_TYPE_MAPPING: dict[tuple[str, int], Type[Node]] = {} VALID_NODE_CHILD_TYPES = Optional[ Union["Node", List["Node"], Dict[str, "Node"], Type, str, io.BytesIO] ] @@ -43,7 +43,7 @@ def check_type( return module_name + "." + type_name in trusted -def audit_tree(tree: Node, trusted: bool | Sequence[str]) -> None: +def audit_tree(tree: Node) -> None: """Audit a tree of nodes. A tree is safe if it only contains trusted types. Audit is skipped if @@ -54,24 +54,15 @@ def audit_tree(tree: Node, trusted: bool | Sequence[str]) -> None: tree : skops.io._dispatch.Node The tree to audit. - trusted : True, or list of str - If ``True``, the tree is considered safe. Otherwise trusted has to be - a list of trusted types names. - - An entry in the list is typically of the form - ``skops.io._utils.get_module(obj) + "." + obj.__class__.__name__``. - Raises ------ UntrustedTypesFoundException If the tree contains an untrusted type. """ - if trusted is True: + if tree.trusted is True: return unsafe = tree.get_unsafe_set() - if isinstance(trusted, (list, set)): - unsafe -= set(trusted) if unsafe: raise UntrustedTypesFoundException(unsafe) @@ -193,10 +184,12 @@ def _get_trusted( ) -> Literal[True] | list[str]: """Return a trusted list, or True. - If ``trusted`` is ``False``, we return the ``default``, otherwise the - ``trusted`` value is used. + If ``trusted`` is ``False``, we return the ``default``. If a list of + types are being passed, those types, as well as default trusted types, + are returned. This is a convenience method called by child classes. 
+ """ if trusted is True: # if trusted is True, we trust the node @@ -206,8 +199,8 @@ def _get_trusted( # if trusted is False, we only trust the defaults return get_type_paths(default) - # otherwise, we trust the given list - return get_type_paths(trusted) + # otherwise, we trust the given list and default trusted types + return get_type_paths(trusted) + get_type_paths(default) def is_self_safe(self) -> bool: """True only if the node's type is considered safe. @@ -314,10 +307,14 @@ def _construct(self): return self.cached.construct() -NODE_TYPE_MAPPING[("CachedNode", PROTOCOL)] = CachedNode # type: ignore +NODE_TYPE_MAPPING[("CachedNode", PROTOCOL)] = CachedNode -def get_tree(state: dict[str, Any], load_context: LoadContext) -> Node: +def get_tree( + state: dict[str, Any], + load_context: LoadContext, + trusted: bool | Sequence[str], +) -> Node: """Get the tree of nodes. This function returns the root node of the tree of nodes. The tree is @@ -336,6 +333,13 @@ def get_tree(state: dict[str, Any], load_context: LoadContext) -> Node: load_context : LoadContext The context of the loading process. + trusted : bool, or list of str + If ``True``, the object will be loaded without any security checks. If + ``False``, the object will be loaded only if there are only trusted + objects in the dumped file. If a list of strings, the object will be + loaded only if there are only trusted objects and objects of types + listed in ``trusted`` in the dumped file. + Returns ------- loaded_tree : Node @@ -372,5 +376,5 @@ def get_tree(state: dict[str, Any], load_context: LoadContext) -> Node: f"protocol {protocol}." 
) - loaded_tree = node_cls(state, load_context, trusted=False) # type: ignore + loaded_tree = node_cls(state, load_context, trusted=trusted) return loaded_tree diff --git a/skops/io/_general.py b/skops/io/_general.py index 5a5f3479..b6b00d11 100644 --- a/skops/io/_general.py +++ b/skops/io/_general.py @@ -59,9 +59,9 @@ def __init__( super().__init__(state, load_context, trusted) self.trusted = self._get_trusted(trusted, [dict]) self.children = { - "key_types": get_tree(state["key_types"], load_context), + "key_types": get_tree(state["key_types"], load_context, trusted=trusted), "content": { - key: get_tree(value, load_context) + key: get_tree(value, load_context, trusted=trusted) for key, value in state["content"].items() }, } @@ -96,7 +96,10 @@ def __init__( super().__init__(state, load_context, trusted) self.trusted = self._get_trusted(trusted, [list]) self.children = { - "content": [get_tree(value, load_context) for value in state["content"]] + "content": [ + get_tree(value, load_context, trusted=trusted) + for value in state["content"] + ] } def _construct(self): @@ -125,7 +128,10 @@ def __init__( super().__init__(state, load_context, trusted) self.trusted = self._get_trusted(trusted, [set]) self.children = { - "content": [get_tree(value, load_context) for value in state["content"]] + "content": [ + get_tree(value, load_context, trusted=trusted) + for value in state["content"] + ] } def _construct(self): @@ -154,7 +160,10 @@ def __init__( super().__init__(state, load_context, trusted) self.trusted = self._get_trusted(trusted, [tuple]) self.children = { - "content": [get_tree(value, load_context) for value in state["content"]] + "content": [ + get_tree(value, load_context, trusted=trusted) + for value in state["content"] + ] } def _construct(self): @@ -241,10 +250,12 @@ def __init__( # TODO: should we trust anything? 
self.trusted = self._get_trusted(trusted, []) self.children = { - "func": get_tree(state["content"]["func"], load_context), - "args": get_tree(state["content"]["args"], load_context), - "kwds": get_tree(state["content"]["kwds"], load_context), - "namespace": get_tree(state["content"]["namespace"], load_context), + "func": get_tree(state["content"]["func"], load_context, trusted=trusted), + "args": get_tree(state["content"]["args"], load_context, trusted=trusted), + "kwds": get_tree(state["content"]["kwds"], load_context, trusted=trusted), + "namespace": get_tree( + state["content"]["namespace"], load_context, trusted=trusted + ), } def _construct(self): @@ -375,7 +386,7 @@ def __init__( content = state.get("content") if content is not None: - attrs = get_tree(content, load_context) + attrs = get_tree(content, load_context, trusted=trusted) else: attrs = None @@ -432,7 +443,7 @@ def __init__( ) -> None: super().__init__(state, load_context, trusted) self.children = { - "obj": get_tree(state["content"]["obj"], load_context), + "obj": get_tree(state["content"]["obj"], load_context, trusted=trusted), "func": state["content"]["func"], } # TODO: what do we trust? @@ -458,6 +469,7 @@ def __init__( super().__init__(state, load_context, trusted) self.content = state["content"] self.children = {} + self.trusted = self._get_trusted(trusted, PRIMITIVE_TYPE_NAMES) def is_safe(self) -> bool: # JsonNode is always considered safe. 
@@ -552,7 +564,7 @@ def __init__( ) -> None: super().__init__(state, load_context, trusted) self.trusted = self._get_trusted(trusted, []) - self.children["attrs"] = get_tree(state["attrs"], load_context) + self.children["attrs"] = get_tree(state["attrs"], load_context, trusted=trusted) def _construct(self): op = getattr(operator, self.class_name) diff --git a/skops/io/_numpy.py b/skops/io/_numpy.py index d5243f15..1a0aa413 100644 --- a/skops/io/_numpy.py +++ b/skops/io/_numpy.py @@ -67,10 +67,10 @@ def __init__( } elif self.type == "json": self.children = { - "content": [ # type: ignore - get_tree(o, load_context) for o in state["content"] # type: ignore + "content": [ + get_tree(o, load_context, trusted=trusted) for o in state["content"] ], - "shape": get_tree(state["shape"], load_context), + "shape": get_tree(state["shape"], load_context, trusted=trusted), } else: raise ValueError(f"Unknown type {self.type}.") @@ -126,8 +126,8 @@ def __init__( super().__init__(state, load_context, trusted) self.trusted = self._get_trusted(trusted, [np.ma.MaskedArray]) self.children = { - "data": get_tree(state["content"]["data"], load_context), - "mask": get_tree(state["content"]["mask"], load_context), + "data": get_tree(state["content"]["data"], load_context, trusted=trusted), + "mask": get_tree(state["content"]["mask"], load_context, trusted=trusted), } def _construct(self): @@ -155,7 +155,10 @@ def __init__( trusted: bool | Sequence[str] = False, ) -> None: super().__init__(state, load_context, trusted) - self.children = {"content": get_tree(state["content"], load_context)} + # TODO + self.children = { + "content": get_tree(state["content"], load_context, trusted=trusted) + } self.trusted = self._get_trusted(trusted, [np.random.RandomState]) def _construct(self): @@ -185,7 +188,7 @@ def __init__( super().__init__(state, load_context, trusted) self.children = { "bit_generator_state": get_tree( - state["content"]["bit_generator"], load_context + 
state["content"]["bit_generator"], load_context, trusted=trusted ) } self.trusted = self._get_trusted(trusted, [np.random.Generator]) @@ -224,7 +227,9 @@ def __init__( trusted: bool | Sequence[str] = False, ) -> None: super().__init__(state, load_context, trusted) - self.children = {"content": get_tree(state["content"], load_context)} + self.children = { + "content": get_tree(state["content"], load_context, trusted=trusted) + } # TODO: what should we trust? self.trusted = self._get_trusted(trusted, []) diff --git a/skops/io/_persist.py b/skops/io/_persist.py index a90cb1fc..507b61d9 100644 --- a/skops/io/_persist.py +++ b/skops/io/_persist.py @@ -116,7 +116,7 @@ def load(file: str | Path, trusted: bool | Sequence[str] = False) -> Any: ``False``, the object will be loaded only if there are only trusted objects in the dumped file. If a list of strings, the object will be loaded only if there are only trusted objects and objects of types - listed in ``trusted`` are in the dumped file. + listed in ``trusted`` in the dumped file. Returns ------- @@ -127,8 +127,8 @@ def load(file: str | Path, trusted: bool | Sequence[str] = False) -> Any: with ZipFile(file, "r") as input_zip: schema = json.loads(input_zip.read("schema.json")) load_context = LoadContext(src=input_zip, protocol=schema["protocol"]) - tree = get_tree(schema, load_context) - audit_tree(tree, trusted) + tree = get_tree(schema, load_context, trusted=trusted) + audit_tree(tree) instance = tree.construct() return instance @@ -154,7 +154,7 @@ def loads(data: bytes, trusted: bool | Sequence[str] = False) -> Any: ``False``, the object will be loaded only if there are only trusted objects in the dumped file. If a list of strings, the object will be loaded only if there are only trusted objects and objects of types - listed in ``trusted`` are in the dumped file. + listed in ``trusted`` in the dumped file. 
Returns ------- @@ -167,8 +167,8 @@ def loads(data: bytes, trusted: bool | Sequence[str] = False) -> Any: with ZipFile(io.BytesIO(data), "r") as zip_file: schema = json.loads(zip_file.read("schema.json")) load_context = LoadContext(src=zip_file, protocol=schema["protocol"]) - tree = get_tree(schema, load_context) - audit_tree(tree, trusted) + tree = get_tree(schema, load_context, trusted=trusted) + audit_tree(tree) instance = tree.construct() return instance @@ -210,9 +210,8 @@ def get_untrusted_types( with ZipFile(content, "r") as zip_file: schema = json.loads(zip_file.read("schema.json")) - tree = get_tree( - schema, load_context=LoadContext(src=zip_file, protocol=schema["protocol"]) - ) + load_context = LoadContext(src=zip_file, protocol=schema["protocol"]) + tree = get_tree(schema, load_context=load_context, trusted=False) untrusted_types = tree.get_unsafe_set() return sorted(untrusted_types) diff --git a/skops/io/_sklearn.py b/skops/io/_sklearn.py index ce9c3969..37745afe 100644 --- a/skops/io/_sklearn.py +++ b/skops/io/_sklearn.py @@ -102,8 +102,8 @@ def __init__( super().__init__(state, load_context, trusted) reduce = state["__reduce__"] self.children = { - "attrs": get_tree(state["content"], load_context), - "args": get_tree(reduce["args"], load_context), + "attrs": get_tree(state["content"], load_context, trusted=trusted), + "args": get_tree(reduce["args"], load_context, trusted=trusted), "constructor": constructor, } @@ -210,9 +210,11 @@ def __init__( get_module(_DictWithDeprecatedKeysNode) + "._DictWithDeprecatedKeys" ] self.children = { - "main": get_tree(state["content"]["main"], load_context), + "main": get_tree(state["content"]["main"], load_context, trusted=trusted), "_deprecated_key_to_new_key": get_tree( - state["content"]["_deprecated_key_to_new_key"], load_context + state["content"]["_deprecated_key_to_new_key"], + load_context, + trusted=trusted, ), } diff --git a/skops/io/_visualize.py b/skops/io/_visualize.py index 8e31c2fc..1aa3dca2 100644 
--- a/skops/io/_visualize.py +++ b/skops/io/_visualize.py @@ -4,7 +4,7 @@ import json from dataclasses import dataclass from pathlib import Path -from typing import Any, Callable, Iterator, Literal +from typing import Any, Callable, Iterator, Literal, Sequence from zipfile import ZipFile from ._audit import VALID_NODE_CHILD_TYPES, Node, get_tree @@ -281,7 +281,9 @@ def walk_tree( def visualize( file: Path | str | bytes, + *, show: Literal["all", "untrusted", "trusted"] = "all", + trusted: bool | Sequence[str] = False, sink: Callable[..., None] = pretty_print_tree, **kwargs: Any, ) -> None: @@ -307,8 +309,13 @@ def visualize( show: "all" or "untrusted" or "trusted" Whether to print all nodes, only untrusted nodes, or only trusted nodes. - sink: function (default=:func:`~pretty_print_tree`) + trusted: bool, or list of str, default=False + If ``True``, all nodes will be treated as trusted. If ``False``, only + default types are trusted. If a list of strings, where those strings + describe the trusted types, these types are trusted on top of the + default trusted types. + sink: function (default=:func:`~pretty_print_tree`) This function should take at least two arguments, an iterator of :class:`~NodeInfo` instances and an indicator of what to show. 
The ``NodeInfo`` contains the information about the node, namely: @@ -348,7 +355,7 @@ def visualize( with zf as zip_file: schema = json.loads(zip_file.read("schema.json")) load_context = LoadContext(src=zip_file, protocol=schema["protocol"]) - tree = get_tree(schema, load_context=load_context) + tree = get_tree(schema, load_context=load_context, trusted=trusted) nodes = walk_tree(tree) # TODO: it would be nice to print html representation if inside a notebook diff --git a/skops/io/tests/test_audit.py b/skops/io/tests/test_audit.py index 8bd68867..7cf4ffc6 100644 --- a/skops/io/tests/test_audit.py +++ b/skops/io/tests/test_audit.py @@ -41,23 +41,27 @@ def test_check_type(module_name, type_name, trusted, expected): def test_audit_tree_untrusted(): var = {"a": CustomType(1), 2: CustomType(2)} state = dict_get_state(var, SaveContext(None, 0, {})) - node = DictNode(state, LoadContext(None, -1), trusted=False) + load_context = LoadContext(None, -1) + + node = DictNode(state, load_context, trusted=False) with pytest.raises( TypeError, match=re.escape( "Untrusted types found in the file: ['test_audit.CustomType']." ), ): - audit_tree(node, trusted=False) + audit_tree(node) # there shouldn't be an error with trusted=True - audit_tree(node, trusted=True) + node = DictNode(state, LoadContext(None, -1), trusted=True) + audit_tree(node) untrusted_list = get_untrusted_types(data=dumps(var)) assert untrusted_list == ["test_audit.CustomType"] # passing the type would fix it. 
- audit_tree(node, trusted=untrusted_list) + node = DictNode(state, LoadContext(None, -1), trusted=untrusted_list) + audit_tree(node) def test_audit_tree_defaults(): @@ -65,7 +69,7 @@ def test_audit_tree_defaults(): var = {"a": 1, 2: "b"} state = dict_get_state(var, SaveContext(None, 0, {})) node = DictNode(state, LoadContext(None, -1), trusted=False) - audit_tree(node, trusted=[]) + audit_tree(node) @pytest.mark.parametrize( @@ -97,7 +101,9 @@ def test_list_safety(values, is_safe): with ZipFile(io.BytesIO(content), "r") as zip_file: schema = json.loads(zip_file.read("schema.json")) - tree = get_tree(schema, load_context=LoadContext(src=zip_file, protocol=-1)) + tree = get_tree( + schema, load_context=LoadContext(src=zip_file, protocol=-1), trusted=False + ) assert tree.is_safe() == is_safe diff --git a/skops/io/tests/test_persist.py b/skops/io/tests/test_persist.py index eb7c5107..9876c590 100644 --- a/skops/io/tests/test_persist.py +++ b/skops/io/tests/test_persist.py @@ -671,7 +671,7 @@ def test_get_tree_unknown_type_error_msg(): state["__loader__"] = "this_get_tree_does_not_exist" msg = "Can't find loader this_get_tree_does_not_exist for type builtins.tuple." 
with pytest.raises(TypeError, match=msg): - get_tree(state, LoadContext(None, -1)) + get_tree(state, LoadContext(None, -1), trusted=False) class _BoundMethodHolder: diff --git a/skops/io/tests/test_visualize.py b/skops/io/tests/test_visualize.py index 91dab864..33a00436 100644 --- a/skops/io/tests/test_visualize.py +++ b/skops/io/tests/test_visualize.py @@ -114,6 +114,19 @@ def sink(nodes_iter, *args, **kwargs): assert len(nodes_unsafe) > 2 assert any("FunctionTransformer" in node.val for node in nodes_unsafe) + @pytest.mark.parametrize( + "trusted", [True, ["numpy.int64", "test_visualize.unsafe_function"]] + ) + def test_all_nodes_trusted(self, pipeline, trusted, capsys): + # The pipeline contains untrusted type(s), but if we pass trusted=True, + # it is not considered untrusted anymore + # TODO: remove numpy.int64 from trusted once it's trusted by default + file = sio.dumps(pipeline) + sio.visualize(file, show="untrusted", trusted=trusted) + expected = "root: sklearn.pipeline.Pipeline" + stdout, _ = capsys.readouterr() + assert stdout.strip() == expected + @pytest.mark.parametrize( "kwargs", [