From cde9c5927748ee1c1c2d3226a108ee090812ec72 Mon Sep 17 00:00:00 2001 From: Benjamin Bossan Date: Mon, 27 Mar 2023 17:43:03 +0200 Subject: [PATCH 1/6] [WIP] Refactor trusted argument usage See discussion: https://discord.com/channels/879548962464493619/1047505653603774584/1085550033207836783 Right now, all nodes are always initialized with trusted=False, so they are not "aware" of what is trusted and what isn't (which means the argument might as well not exist). This refactor would pass the actual trusted argument the user provides to the nodes. As a consequence, the constructed tree is aware of what nodes are considered trusted or not. Some unit tests are still failing, so this is WIP. --- skops/io/_audit.py | 32 ++++++++++++++------------------ skops/io/_general.py | 36 ++++++++++++++++++++++++------------ skops/io/_numpy.py | 19 ++++++++++++------- skops/io/_persist.py | 13 ++++++------- skops/io/_sklearn.py | 10 ++++++---- skops/io/_visualize.py | 12 +++++++++--- skops/io/tests/test_audit.py | 16 ++++++++++------ 7 files changed, 81 insertions(+), 57 deletions(-) diff --git a/skops/io/_audit.py b/skops/io/_audit.py index 067b13c5..a48d2440 100644 --- a/skops/io/_audit.py +++ b/skops/io/_audit.py @@ -9,7 +9,7 @@ from ._utils import LoadContext, get_module, get_type_paths from .exceptions import UntrustedTypesFoundException -NODE_TYPE_MAPPING: dict[tuple[str, int], Node] = {} +NODE_TYPE_MAPPING: dict[tuple[str, int], Type[Node]] = {} def check_type( @@ -41,7 +41,7 @@ def check_type( return module_name + "." + type_name in trusted -def audit_tree(tree: Node, trusted: bool | Sequence[str]) -> None: +def audit_tree(tree: Node) -> None: """Audit a tree of nodes. A tree is safe if it only contains trusted types. Audit is skipped if @@ -52,24 +52,12 @@ def audit_tree(tree: Node, trusted: bool | Sequence[str]) -> None: tree : skops.io._dispatch.Node The tree to audit. - trusted : True, or list of str - If ``True``, the tree is considered safe. 
Otherwise trusted has to be - a list of trusted types names. - - An entry in the list is typically of the form - ``skops.io._utils.get_module(obj) + "." + obj.__class__.__name__``. - Raises ------ UntrustedTypesFoundException If the tree contains an untrusted type. """ - if trusted is True: - return - unsafe = tree.get_unsafe_set() - if isinstance(trusted, (list, set)): - unsafe -= set(trusted) if unsafe: raise UntrustedTypesFoundException(unsafe) @@ -191,6 +179,8 @@ def _get_trusted( ) -> Literal[True] | list[str]: """Return a trusted list, or True. + TODO + If ``trusted`` is ``False``, we return the ``default``, otherwise the ``trusted`` value is used. @@ -205,7 +195,7 @@ def _get_trusted( return get_type_paths(default) # otherwise, we trust the given list - return get_type_paths(trusted) + return get_type_paths(trusted) + get_type_paths(default) def is_self_safe(self) -> bool: """True only if the node's type is considered safe. @@ -316,10 +306,14 @@ def _construct(self): return self.cached.construct() -NODE_TYPE_MAPPING[("CachedNode", PROTOCOL)] = CachedNode # type: ignore +NODE_TYPE_MAPPING[("CachedNode", PROTOCOL)] = CachedNode -def get_tree(state: dict[str, Any], load_context: LoadContext) -> Node: +def get_tree( + state: dict[str, Any], + load_context: LoadContext, + trusted: bool | Sequence[str] = False, +) -> Node: """Get the tree of nodes. This function returns the root node of the tree of nodes. The tree is @@ -338,6 +332,8 @@ def get_tree(state: dict[str, Any], load_context: LoadContext) -> Node: load_context : LoadContext The context of the loading process. + trusted : TODO + Returns ------- loaded_tree : Node @@ -374,5 +370,5 @@ def get_tree(state: dict[str, Any], load_context: LoadContext) -> Node: f"protocol {protocol}." 
) - loaded_tree = node_cls(state, load_context, trusted=False) # type: ignore + loaded_tree = node_cls(state, load_context, trusted=trusted) return loaded_tree diff --git a/skops/io/_general.py b/skops/io/_general.py index 5a5f3479..df4ea858 100644 --- a/skops/io/_general.py +++ b/skops/io/_general.py @@ -59,9 +59,9 @@ def __init__( super().__init__(state, load_context, trusted) self.trusted = self._get_trusted(trusted, [dict]) self.children = { - "key_types": get_tree(state["key_types"], load_context), + "key_types": get_tree(state["key_types"], load_context, trusted=trusted), "content": { - key: get_tree(value, load_context) + key: get_tree(value, load_context, trusted=trusted) for key, value in state["content"].items() }, } @@ -96,7 +96,10 @@ def __init__( super().__init__(state, load_context, trusted) self.trusted = self._get_trusted(trusted, [list]) self.children = { - "content": [get_tree(value, load_context) for value in state["content"]] + "content": [ + get_tree(value, load_context, trusted=trusted) + for value in state["content"] + ] } def _construct(self): @@ -125,7 +128,10 @@ def __init__( super().__init__(state, load_context, trusted) self.trusted = self._get_trusted(trusted, [set]) self.children = { - "content": [get_tree(value, load_context) for value in state["content"]] + "content": [ + get_tree(value, load_context, trusted=trusted) + for value in state["content"] + ] } def _construct(self): @@ -154,7 +160,10 @@ def __init__( super().__init__(state, load_context, trusted) self.trusted = self._get_trusted(trusted, [tuple]) self.children = { - "content": [get_tree(value, load_context) for value in state["content"]] + "content": [ + get_tree(value, load_context, trusted=trusted) + for value in state["content"] + ] } def _construct(self): @@ -241,10 +250,12 @@ def __init__( # TODO: should we trust anything? 
self.trusted = self._get_trusted(trusted, []) self.children = { - "func": get_tree(state["content"]["func"], load_context), - "args": get_tree(state["content"]["args"], load_context), - "kwds": get_tree(state["content"]["kwds"], load_context), - "namespace": get_tree(state["content"]["namespace"], load_context), + "func": get_tree(state["content"]["func"], load_context, trusted=trusted), + "args": get_tree(state["content"]["args"], load_context, trusted=trusted), + "kwds": get_tree(state["content"]["kwds"], load_context, trusted=trusted), + "namespace": get_tree( + state["content"]["namespace"], load_context, trusted=trusted + ), } def _construct(self): @@ -375,7 +386,7 @@ def __init__( content = state.get("content") if content is not None: - attrs = get_tree(content, load_context) + attrs = get_tree(content, load_context, trusted=trusted) else: attrs = None @@ -432,7 +443,7 @@ def __init__( ) -> None: super().__init__(state, load_context, trusted) self.children = { - "obj": get_tree(state["content"]["obj"], load_context), + "obj": get_tree(state["content"]["obj"], load_context, trusted=trusted), "func": state["content"]["func"], } # TODO: what do we trust? @@ -458,6 +469,7 @@ def __init__( super().__init__(state, load_context, trusted) self.content = state["content"] self.children = {} + self.trusted = self._get_trusted(trusted, [int, float, str, type(None)]) def is_safe(self) -> bool: # JsonNode is always considered safe. 
@@ -552,7 +564,7 @@ def __init__( ) -> None: super().__init__(state, load_context, trusted) self.trusted = self._get_trusted(trusted, []) - self.children["attrs"] = get_tree(state["attrs"], load_context) + self.children["attrs"] = get_tree(state["attrs"], load_context, trusted=trusted) def _construct(self): op = getattr(operator, self.class_name) diff --git a/skops/io/_numpy.py b/skops/io/_numpy.py index 1f2cef82..e6662c78 100644 --- a/skops/io/_numpy.py +++ b/skops/io/_numpy.py @@ -67,10 +67,10 @@ def __init__( } elif self.type == "json": self.children = { - "content": [ # type: ignore - get_tree(o, load_context) for o in state["content"] # type: ignore + "content": [ + get_tree(o, load_context, trusted=trusted) for o in state["content"] ], - "shape": get_tree(state["shape"], load_context), + "shape": get_tree(state["shape"], load_context, trusted=trusted), } else: raise ValueError(f"Unknown type {self.type}.") @@ -126,8 +126,8 @@ def __init__( super().__init__(state, load_context, trusted) self.trusted = self._get_trusted(trusted, [np.ma.MaskedArray]) self.children = { - "data": get_tree(state["content"]["data"], load_context), - "mask": get_tree(state["content"]["mask"], load_context), + "data": get_tree(state["content"]["data"], load_context, trusted=trusted), + "mask": get_tree(state["content"]["mask"], load_context, trusted=trusted), } def _construct(self): @@ -155,7 +155,10 @@ def __init__( trusted: bool | Sequence[str] = False, ) -> None: super().__init__(state, load_context, trusted) - self.children = {"content": get_tree(state["content"], load_context)} + # TODO + self.children = { + "content": get_tree(state["content"], load_context, trusted=trusted) + } self.trusted = self._get_trusted(trusted, [np.random.RandomState]) def _construct(self): @@ -218,7 +221,9 @@ def __init__( trusted: bool | Sequence[str] = False, ) -> None: super().__init__(state, load_context, trusted) - self.children = {"content": get_tree(state["content"], load_context)} + 
self.children = { + "content": get_tree(state["content"], load_context, trusted=trusted) + } # TODO: what should we trust? self.trusted = self._get_trusted(trusted, []) diff --git a/skops/io/_persist.py b/skops/io/_persist.py index 12796e05..bf41f615 100644 --- a/skops/io/_persist.py +++ b/skops/io/_persist.py @@ -127,8 +127,8 @@ def load(file: str | Path, trusted: bool | Sequence[str] = False) -> Any: with ZipFile(file, "r") as input_zip: schema = json.loads(input_zip.read("schema.json")) load_context = LoadContext(src=input_zip, protocol=schema["protocol"]) - tree = get_tree(schema, load_context) - audit_tree(tree, trusted) + tree = get_tree(schema, load_context, trusted=trusted) + audit_tree(tree) instance = tree.construct() return instance @@ -167,8 +167,8 @@ def loads(data: bytes, trusted: bool | Sequence[str] = False) -> Any: with ZipFile(io.BytesIO(data), "r") as zip_file: schema = json.loads(zip_file.read("schema.json")) load_context = LoadContext(src=zip_file, protocol=schema["protocol"]) - tree = get_tree(schema, load_context) - audit_tree(tree, trusted) + tree = get_tree(schema, load_context, trusted=trusted) + audit_tree(tree) instance = tree.construct() return instance @@ -210,9 +210,8 @@ def get_untrusted_types( with ZipFile(content, "r") as zip_file: schema = json.loads(zip_file.read("schema.json")) - tree = get_tree( - schema, load_context=LoadContext(src=zip_file, protocol=schema["protocol"]) - ) + load_context = LoadContext(src=zip_file, protocol=schema["protocol"]) + tree = get_tree(schema, load_context=load_context, trusted=False) untrusted_types = tree.get_unsafe_set() return sorted(untrusted_types) diff --git a/skops/io/_sklearn.py b/skops/io/_sklearn.py index 4d302267..b5e4d638 100644 --- a/skops/io/_sklearn.py +++ b/skops/io/_sklearn.py @@ -102,8 +102,8 @@ def __init__( super().__init__(state, load_context, trusted) reduce = state["__reduce__"] self.children = { - "attrs": get_tree(state["content"], load_context), - "args": 
get_tree(reduce["args"], load_context), + "attrs": get_tree(state["content"], load_context, trusted=trusted), + "args": get_tree(reduce["args"], load_context, trusted=trusted), "constructor": constructor, } @@ -210,9 +210,11 @@ def __init__( get_module(_DictWithDeprecatedKeysNode) + "._DictWithDeprecatedKeys" ] self.children = { - "main": get_tree(state["content"]["main"], load_context), + "main": get_tree(state["content"]["main"], load_context, trusted=trusted), "_deprecated_key_to_new_key": get_tree( - state["content"]["_deprecated_key_to_new_key"], load_context + state["content"]["_deprecated_key_to_new_key"], + load_context, + trusted=trusted, ), } diff --git a/skops/io/_visualize.py b/skops/io/_visualize.py index 75269bf0..d9fce9a6 100644 --- a/skops/io/_visualize.py +++ b/skops/io/_visualize.py @@ -4,7 +4,7 @@ import json from dataclasses import dataclass from pathlib import Path -from typing import Any, Callable, Iterator, Literal +from typing import Any, Callable, Iterator, Literal, Sequence from zipfile import ZipFile from ._audit import Node, get_tree @@ -282,6 +282,7 @@ def walk_tree( def visualize( file: Path | str | bytes, show: Literal["all", "untrusted", "trusted"] = "all", + trusted: bool | Sequence[str] = False, sink: Callable[..., None] = pretty_print_tree, **kwargs: Any, ) -> None: @@ -307,8 +308,13 @@ def visualize( show: "all" or "untrusted" or "trusted" Whether to print all nodes, only untrusted nodes, or only trusted nodes. - sink: function (default=:func:`~pretty_print_tree`) + trusted: bool, or list of str, default=False + If ``True``, all nodes will be treated as trusted. If ``False``, only + default types are trusted. If a list of strings, where those strings + describe the trusted types, these types are trusted on top of the + default trusted types. + sink: function (default=:func:`~pretty_print_tree`) This function should take at least two arguments, an iterator of :class:`~NodeInfo` instances and an indicator of what to show. 
The ``NodeInfo`` contains the information about the node, namely: @@ -348,7 +354,7 @@ def visualize( with zf as zip_file: schema = json.loads(zip_file.read("schema.json")) load_context = LoadContext(src=zip_file, protocol=schema["protocol"]) - tree = get_tree(schema, load_context=load_context) + tree = get_tree(schema, load_context=load_context, trusted=trusted) nodes = walk_tree(tree) # TODO: it would be nice to print html representation if inside a notebook diff --git a/skops/io/tests/test_audit.py b/skops/io/tests/test_audit.py index 8bd68867..0f3a8b2b 100644 --- a/skops/io/tests/test_audit.py +++ b/skops/io/tests/test_audit.py @@ -41,31 +41,35 @@ def test_check_type(module_name, type_name, trusted, expected): def test_audit_tree_untrusted(): var = {"a": CustomType(1), 2: CustomType(2)} state = dict_get_state(var, SaveContext(None, 0, {})) - node = DictNode(state, LoadContext(None, -1), trusted=False) + load_context = LoadContext(None, -1) + + node = DictNode(state, load_context, trusted=False) with pytest.raises( TypeError, match=re.escape( "Untrusted types found in the file: ['test_audit.CustomType']." ), ): - audit_tree(node, trusted=False) + audit_tree(node) # there shouldn't be an error with trusted=True - audit_tree(node, trusted=True) + node = DictNode(state, LoadContext(None, -1), trusted=True) + audit_tree(node) untrusted_list = get_untrusted_types(data=dumps(var)) assert untrusted_list == ["test_audit.CustomType"] # passing the type would fix it. 
- audit_tree(node, trusted=untrusted_list) + node = DictNode(state, LoadContext(None, -1), trusted=untrusted_list) + audit_tree(node) def test_audit_tree_defaults(): # test that the default types are trusted var = {"a": 1, 2: "b"} state = dict_get_state(var, SaveContext(None, 0, {})) - node = DictNode(state, LoadContext(None, -1), trusted=False) - audit_tree(node, trusted=[]) + node = DictNode(state, LoadContext(None, -1), trusted=[]) + audit_tree(node) @pytest.mark.parametrize( From 91a4e74e525af34ffb68afedd52b9a6708658e3f Mon Sep 17 00:00:00 2001 From: Benjamin Bossan Date: Mon, 3 Apr 2023 11:59:11 +0200 Subject: [PATCH 2/6] Some clean up --- skops/io/_audit.py | 10 +++++++++- skops/io/_general.py | 4 ++-- skops/io/_persist.py | 4 ++-- 3 files changed, 13 insertions(+), 5 deletions(-) diff --git a/skops/io/_audit.py b/skops/io/_audit.py index a48d2440..3c6c0d8a 100644 --- a/skops/io/_audit.py +++ b/skops/io/_audit.py @@ -57,6 +57,9 @@ def audit_tree(tree: Node) -> None: UntrustedTypesFoundException If the tree contains an untrusted type. """ + if tree.trusted is True: + return + unsafe = tree.get_unsafe_set() if unsafe: raise UntrustedTypesFoundException(unsafe) @@ -332,7 +335,12 @@ def get_tree( load_context : LoadContext The context of the loading process. - trusted : TODO + trusted : bool, or list of str, default=False + If ``True``, the object will be loaded without any security checks. If + ``False``, the object will be loaded only if there are only trusted + objects in the dumped file. If a list of strings, the object will be + loaded only if there are only trusted objects and objects of types + listed in ``trusted`` in the dumped file. 
Returns ------- diff --git a/skops/io/_general.py b/skops/io/_general.py index df4ea858..22340f57 100644 --- a/skops/io/_general.py +++ b/skops/io/_general.py @@ -444,14 +444,14 @@ def __init__( super().__init__(state, load_context, trusted) self.children = { "obj": get_tree(state["content"]["obj"], load_context, trusted=trusted), - "func": state["content"]["func"], } + self.func_name = state["content"]["func"] # TODO: what do we trust? self.trusted = self._get_trusted(trusted, []) def _construct(self): loaded_obj = self.children["obj"].construct() - method = getattr(loaded_obj, self.children["func"]) + method = getattr(loaded_obj, self.func_name) return method diff --git a/skops/io/_persist.py b/skops/io/_persist.py index f06bf73b..507b61d9 100644 --- a/skops/io/_persist.py +++ b/skops/io/_persist.py @@ -116,7 +116,7 @@ def load(file: str | Path, trusted: bool | Sequence[str] = False) -> Any: ``False``, the object will be loaded only if there are only trusted objects in the dumped file. If a list of strings, the object will be loaded only if there are only trusted objects and objects of types - listed in ``trusted`` are in the dumped file. + listed in ``trusted`` in the dumped file. Returns ------- @@ -154,7 +154,7 @@ def loads(data: bytes, trusted: bool | Sequence[str] = False) -> Any: ``False``, the object will be loaded only if there are only trusted objects in the dumped file. If a list of strings, the object will be loaded only if there are only trusted objects and objects of types - listed in ``trusted`` are in the dumped file. + listed in ``trusted`` in the dumped file. 
Returns ------- From a3350e6f5996ce17a0068e19dacd1ef1d13e33c3 Mon Sep 17 00:00:00 2001 From: Benjamin Bossan Date: Mon, 3 Apr 2023 12:09:24 +0200 Subject: [PATCH 3/6] Update docstring, revert unnecessary change --- skops/io/_audit.py | 10 +++++----- skops/io/tests/test_audit.py | 2 +- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/skops/io/_audit.py b/skops/io/_audit.py index 3c6c0d8a..4eeabbf9 100644 --- a/skops/io/_audit.py +++ b/skops/io/_audit.py @@ -182,12 +182,12 @@ def _get_trusted( ) -> Literal[True] | list[str]: """Return a trusted list, or True. - TODO - - If ``trusted`` is ``False``, we return the ``default``, otherwise the - ``trusted`` value is used. + If ``trusted`` is ``False``, we return the ``default``. If a list of + types are being passed, those types, as well as default trusted types, + are returned. This is a convenience method called by child classes. + """ if trusted is True: # if trusted is True, we trust the node @@ -197,7 +197,7 @@ def _get_trusted( # if trusted is False, we only trust the defaults return get_type_paths(default) - # otherwise, we trust the given list + # otherwise, we trust the given list and default trusted types return get_type_paths(trusted) + get_type_paths(default) def is_self_safe(self) -> bool: diff --git a/skops/io/tests/test_audit.py b/skops/io/tests/test_audit.py index 0f3a8b2b..cece09ce 100644 --- a/skops/io/tests/test_audit.py +++ b/skops/io/tests/test_audit.py @@ -68,7 +68,7 @@ def test_audit_tree_defaults(): # test that the default types are trusted var = {"a": 1, 2: "b"} state = dict_get_state(var, SaveContext(None, 0, {})) - node = DictNode(state, LoadContext(None, -1), trusted=[]) + node = DictNode(state, LoadContext(None, -1), trusted=False) audit_tree(node) From 63ae2bff176c1ab25149c8ffb698f58a312a5023 Mon Sep 17 00:00:00 2001 From: Benjamin Bossan Date: Mon, 3 Apr 2023 12:23:44 +0200 Subject: [PATCH 4/6] Add a test for visualize with trusted argument --- skops/io/tests/test_visualize.py | 
13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/skops/io/tests/test_visualize.py b/skops/io/tests/test_visualize.py index 91dab864..33a00436 100644 --- a/skops/io/tests/test_visualize.py +++ b/skops/io/tests/test_visualize.py @@ -114,6 +114,19 @@ def sink(nodes_iter, *args, **kwargs): assert len(nodes_unsafe) > 2 assert any("FunctionTransformer" in node.val for node in nodes_unsafe) + @pytest.mark.parametrize( + "trusted", [True, ["numpy.int64", "test_visualize.unsafe_function"]] + ) + def test_all_nodes_trusted(self, pipeline, trusted, capsys): + # The pipeline contains untrusted type(s), but if we pass trusted=True, + # it is not considered untrusted anymore + # TODO: remove numpy.int64 from trusted once it's trusted by default + file = sio.dumps(pipeline) + sio.visualize(file, show="untrusted", trusted=trusted) + expected = "root: sklearn.pipeline.Pipeline" + stdout, _ = capsys.readouterr() + assert stdout.strip() == expected + @pytest.mark.parametrize( "kwargs", [ From 967c4735305721b7cfb2d923571842cf164155c1 Mon Sep 17 00:00:00 2001 From: Benjamin Bossan Date: Mon, 3 Apr 2023 17:57:07 +0200 Subject: [PATCH 5/6] Address reviewer comments - make all but 1st arg to visualize kw only - don't set default for trusted in get_tree - use PRIMITIVE_TYPE_NAMES --- skops/io/_audit.py | 4 ++-- skops/io/_general.py | 2 +- skops/io/_numpy.py | 2 +- skops/io/_visualize.py | 1 + skops/io/tests/test_audit.py | 4 +++- skops/io/tests/test_persist.py | 2 +- 6 files changed, 9 insertions(+), 6 deletions(-) diff --git a/skops/io/_audit.py b/skops/io/_audit.py index 4eeabbf9..78b2baf6 100644 --- a/skops/io/_audit.py +++ b/skops/io/_audit.py @@ -315,7 +315,7 @@ def _construct(self): def get_tree( state: dict[str, Any], load_context: LoadContext, - trusted: bool | Sequence[str] = False, + trusted: bool | Sequence[str], ) -> Node: """Get the tree of nodes. @@ -335,7 +335,7 @@ def get_tree( load_context : LoadContext The context of the loading process. 
- trusted : bool, or list of str, default=False + trusted : bool, or list of str If ``True``, the object will be loaded without any security checks. If ``False``, the object will be loaded only if there are only trusted objects in the dumped file. If a list of strings, the object will be diff --git a/skops/io/_general.py b/skops/io/_general.py index 22340f57..546b1509 100644 --- a/skops/io/_general.py +++ b/skops/io/_general.py @@ -469,7 +469,7 @@ def __init__( super().__init__(state, load_context, trusted) self.content = state["content"] self.children = {} - self.trusted = self._get_trusted(trusted, [int, float, str, type(None)]) + self.trusted = self._get_trusted(trusted, PRIMITIVE_TYPE_NAMES) def is_safe(self) -> bool: # JsonNode is always considered safe. diff --git a/skops/io/_numpy.py b/skops/io/_numpy.py index a62294c4..1a0aa413 100644 --- a/skops/io/_numpy.py +++ b/skops/io/_numpy.py @@ -188,7 +188,7 @@ def __init__( super().__init__(state, load_context, trusted) self.children = { "bit_generator_state": get_tree( - state["content"]["bit_generator"], load_context + state["content"]["bit_generator"], load_context, trusted=trusted ) } self.trusted = self._get_trusted(trusted, [np.random.Generator]) diff --git a/skops/io/_visualize.py b/skops/io/_visualize.py index d9fce9a6..2173515a 100644 --- a/skops/io/_visualize.py +++ b/skops/io/_visualize.py @@ -281,6 +281,7 @@ def walk_tree( def visualize( file: Path | str | bytes, + *, show: Literal["all", "untrusted", "trusted"] = "all", trusted: bool | Sequence[str] = False, sink: Callable[..., None] = pretty_print_tree, diff --git a/skops/io/tests/test_audit.py b/skops/io/tests/test_audit.py index cece09ce..7cf4ffc6 100644 --- a/skops/io/tests/test_audit.py +++ b/skops/io/tests/test_audit.py @@ -101,7 +101,9 @@ def test_list_safety(values, is_safe): with ZipFile(io.BytesIO(content), "r") as zip_file: schema = json.loads(zip_file.read("schema.json")) - tree = get_tree(schema, load_context=LoadContext(src=zip_file, 
protocol=-1)) + tree = get_tree( + schema, load_context=LoadContext(src=zip_file, protocol=-1), trusted=False + ) assert tree.is_safe() == is_safe diff --git a/skops/io/tests/test_persist.py b/skops/io/tests/test_persist.py index 1724a4c1..bacd2552 100644 --- a/skops/io/tests/test_persist.py +++ b/skops/io/tests/test_persist.py @@ -664,7 +664,7 @@ def test_get_tree_unknown_type_error_msg(): state["__loader__"] = "this_get_tree_does_not_exist" msg = "Can't find loader this_get_tree_does_not_exist for type builtins.tuple." with pytest.raises(TypeError, match=msg): - get_tree(state, LoadContext(None, -1)) + get_tree(state, LoadContext(None, -1), trusted=False) class _BoundMethodHolder: From 634601b4bafef507fc6009435321bdc4583bd182 Mon Sep 17 00:00:00 2001 From: Benjamin Bossan Date: Thu, 6 Apr 2023 14:34:20 +0200 Subject: [PATCH 6/6] Revert change to MethodNode --- skops/io/_general.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/skops/io/_general.py b/skops/io/_general.py index 546b1509..b6b00d11 100644 --- a/skops/io/_general.py +++ b/skops/io/_general.py @@ -444,14 +444,14 @@ def __init__( super().__init__(state, load_context, trusted) self.children = { "obj": get_tree(state["content"]["obj"], load_context, trusted=trusted), + "func": state["content"]["func"], } - self.func_name = state["content"]["func"] # TODO: what do we trust? self.trusted = self._get_trusted(trusted, []) def _construct(self): loaded_obj = self.children["obj"].construct() - method = getattr(loaded_obj, self.func_name) + method = getattr(loaded_obj, self.children["func"]) return method