From 227917200259fd9531954586587683807c4eed6a Mon Sep 17 00:00:00 2001 From: Benjamin Bossan Date: Mon, 28 Nov 2022 12:55:45 +0100 Subject: [PATCH 1/8] Add light type annotation to skops.io This is in line with the rest of the code base. On top of adding those type annotations, these changes were made: 1. Moved all code from _dispatch.py to _audit.py. _dispatch.py didn't make much sense after the refactor and there were also problems with circular imports. 2. Small refactor in FunctionNode: get_unsafe_set used to call is_safe, which is the inverse of how all other Nodes do it. Changed it to be consistent. 3. Fixed a bug in ReduceNode: Exception would access self.constructor, which does not exist. 4. Add a few comments to Node.__init__ about what subclasses should implement. 5. Add _construct method to Node, raise NotImplementedError 6. Add TODO comment to CachedNode as it doesn't deal with unknown id yet. 7. gettype now raises error when type could not be determined, cannot return None anymore 8. 
get_untrusted_types now deals with neither data nor file being passed --- skops/io/_audit.py | 303 ++++++++++++++++++++++++++++++++- skops/io/_dispatch.py | 263 ---------------------------- skops/io/_general.py | 138 ++++++++++----- skops/io/_numpy.py | 49 ++++-- skops/io/_persist.py | 27 +-- skops/io/_scipy.py | 13 +- skops/io/_sklearn.py | 44 +++-- skops/io/_utils.py | 17 +- skops/io/tests/test_audit.py | 17 +- skops/io/tests/test_persist.py | 2 +- 10 files changed, 512 insertions(+), 361 deletions(-) delete mode 100644 skops/io/_dispatch.py diff --git a/skops/io/_audit.py b/skops/io/_audit.py index 9bda258c..1556fe02 100644 --- a/skops/io/_audit.py +++ b/skops/io/_audit.py @@ -1,7 +1,19 @@ -from skops.io.exceptions import UntrustedTypesFoundException +from __future__ import annotations +import io +from contextlib import contextmanager +from typing import Any, Generator, Literal, Sequence -def check_type(module_name, type_name, trusted): +from ._trusted_types import PRIMITIVE_TYPE_NAMES +from ._utils import LoadContext, get_module +from .exceptions import UntrustedTypesFoundException + +NODE_TYPE_MAPPING = {} # type: ignore + + +def check_type( + module_name: str, type_name: str, trusted: Literal[True] | Sequence[str] +) -> bool: """Check if a type is safe to load. A type is safe to load only if it's present in the trusted list. @@ -14,7 +26,7 @@ def check_type(module_name, type_name, trusted): type_name : str The class name of the type. - trusted : bool, or list of str + trusted : True, or list of str If ``True``, the tree is considered safe. Otherwise trusted has to be a list of trusted types. @@ -28,7 +40,7 @@ def check_type(module_name, type_name, trusted): return module_name + "." + type_name in trusted -def audit_tree(tree, trusted): +def audit_tree(tree: Node, trusted: bool | Sequence[str]) -> None: """Audit a tree of nodes. A tree is safe if it only contains trusted types. 
Audit is skipped if @@ -39,7 +51,7 @@ def audit_tree(tree, trusted): tree : skops.io._dispatch.Node The tree to audit. - trusted : bool, or list of str + trusted : True, or list of str If ``True``, the tree is considered safe. Otherwise trusted has to be a list of trusted types names. @@ -59,3 +71,284 @@ def audit_tree(tree, trusted): unsafe -= set(trusted) if unsafe: raise UntrustedTypesFoundException(unsafe) + + +class UNINITIALIZED: + """Sentinel value to indicate that a value has not been initialized yet.""" + + +@contextmanager +def temp_setattr(obj: Any, **kwargs: Any) -> Generator[None, None, None]: + """Context manager to temporarily set attributes on an object.""" + existing_attrs = {k for k in kwargs.keys() if hasattr(obj, k)} + previous_values = {k: getattr(obj, k, None) for k in kwargs} + for k, v in kwargs.items(): + setattr(obj, k, v) + try: + yield + finally: + for k, v in previous_values.items(): + if k in existing_attrs: + setattr(obj, k, v) + else: + delattr(obj, k) + + +class Node: + """A node in the tree of objects. + + This class is a parent class for all nodes in the tree of objects. Each + type of object (e.g. dict, list, etc.) has its own subclass of Node. + + Each child class has to implement two methods: ``__init__`` and + ``_construct``. + + ``__init__`` takes care of traversing the state tree and to create the + corresponding ``Node`` objects. It has access to the ``load_context`` which + in turn has access to the source zip file. The child class's ``__init__`` + must load attributes into the ``children`` attribute, which is a + dictionary of ``{child_name: unloaded_value/Node/list/etc}``. The + ``get_unsafe_set`` should be able to parse and validate the values set + under the ``children`` attribute. Note that primitives are persisted as a + ``JsonNode``. + + ``_construct`` takes care of constructing the object. It is only called + once and the result is cached in ``construct`` which is implemented in this + class. 
All required data to construct an instance should be loaded during + ``__init__``. + + The separation of ``__init__`` and ``_construct`` is necessary because + audit methods are called after ``__init__`` and before ``construct``. + Therefore ``__init__`` should avoid creating any instances or importing + any modules, to avoid running potentially untrusted code. + + Parameters + ---------- + state : dict + A dict representing the state of the dumped object. + + load_context : LoadContext + The context of the loading process. + + trusted : bool or list of str, default=False + If ``True``, the object will be loaded without any security checks. If + ``False``, the object will be loaded only if there are only trusted + objects in the dumped file. If a list of strings, the object will be + loaded only if all of its required types are listed in ``trusted`` + or are trusted by default. + + memoize : bool, default=True + If ``True``, the object will be memoized in the load context, if it has + the ``__id__`` set. This is used to avoid loading the same object + multiple times. + """ + + def __init__( + self, + state: dict[str, Any], + load_context: LoadContext, + trusted: bool | Sequence[str] = False, + memoize: bool = True, + ) -> None: + self.class_name, self.module_name = state["__class__"], state["__module__"] + self._is_safe = None + self._constructed = UNINITIALIZED + saved_id = state.get("__id__") + if saved_id and memoize: + # hold reference to obj in case same instance encountered again in + # save state + load_context.memoize(self, saved_id) + + # subclasses should always: + # 1. call super().__init__() + # 2. set self.trusted = self._get_trusted(trusted, ...) where ... is a + # list of appropriate trusted types + # 3. set self.children, where children are states of child nodes; do not + # construct the children objects yet + self.trusted = self._get_trusted(trusted, []) + self.children: dict[str, Any] = {} + + def construct(self) -> Any: + """Construct the object. 
+ + We only construct the object once, and then cache the result. + """ + if self._constructed is not UNINITIALIZED: + return self._constructed + self._constructed = self._construct() + return self._constructed + + def _construct(self) -> Any: + raise NotImplementedError( + f"{self.__class__.__name__} should implement a 'construct' method" + ) + + @staticmethod + def _get_trusted( + trusted: bool | Sequence[str], default: list[str] + ) -> Literal[True] | list[str]: + """Return a trusted list, or True. + + If ``trusted`` is ``False``, we return the ``default``, otherwise the + ``trusted`` value is used. + + This is a convenience method called by child classes. + """ + if trusted is True: + # if trusted is True, we trust the node + return True + + if trusted is False: + # if trusted is False, we only trust the defaults + return default + + # otherwise we trust the given list, call list in case it's a tuple + return list(trusted) + + def is_self_safe(self) -> bool: + """True only if the node's type is considered safe. + + This property only checks the type of the node, not its children. + """ + return check_type(self.module_name, self.class_name, self.trusted) + + def is_safe(self) -> bool: + """True only if the node and all its children are safe.""" + # if trusted is set to True, we don't do any safety checks. + if self.trusted is True: + return True + + return len(self.get_unsafe_set()) == 0 + + def get_unsafe_set(self) -> set[str]: + """Get the set of unsafe types. + + This method returns all types which are not trusted, including this + node and all its children. + + Returns + ------- + unsafe_set : set + A set of unsafe types. + """ + if hasattr(self, "_computing_unsafe_set"): + # this means we're already computing this node's unsafe set, so we + # return an empty set and let the computation of the parent node + # continue. This is to avoid infinite recursion. 
+ return set() + + with temp_setattr(self, _computing_unsafe_set=True): + res = set() + if not self.is_self_safe(): + res.add(self.module_name + "." + self.class_name) + + for child in self.children.values(): + if child is None: + continue + + # Get the safety set based on the type of the child. In most cases + # other than ListNode and DictNode, children are all of type Node. + if isinstance(child, list): + # iterate through the list + for value in child: + res.update(value.get_unsafe_set()) + elif isinstance(child, dict): + # iterate through the values of the dict only + # TODO: should we check the types of the keys? + for value in child.values(): + res.update(value.get_unsafe_set()) + elif isinstance(child, Node): + # delegate to the child Node + res.update(child.get_unsafe_set()) + elif type(child) is type: + # the if condition bellow is not merged with the previous + # one because if the above condition is True, the following + # conditions about BytesIO, etc should be ignored. + if not check_type(get_module(child), child.__name__, self.trusted): + # if the child is a type, we check its safety + res.add(get_module(child) + "." + child.__name__) + elif isinstance(child, io.BytesIO): + # We trust BytesIO objects, which are read by other + # libraries such as numpy, scipy. + continue + elif check_type( + get_module(child), child.__class__.__name__, PRIMITIVE_TYPE_NAMES + ): + # if the child is a primitive type, we don't need to check its + # safety. + continue + else: + raise ValueError( + f"Cannot determine the safety of type {type(child)}. Please" + " open an issue at https://github.com/skops-dev/skops/issues" + " for us to fix the issue." + ) + + return res + + +class CachedNode(Node): + def __init__( + self, + state: dict[str, Any], + load_context: LoadContext, + trusted: bool = False, + ): + # we pass memoize as False because we don't want to memoize the cached + # node. 
+ super().__init__(state, load_context, trusted, memoize=False) + self.trusted = True + # TODO: deal with case that __id__ is unknown or prevent it from + # happening + self.cached = load_context.get_object(state.get("__id__")) # type: ignore + self.children = {} # type: ignore + + def _construct(self): + # TODO: FIXME This causes a recursion error when loading a cached + # object if we call the cached object's `construct``. Some refactoring + # is needed to fix this. + return self.cached.construct() + + +NODE_TYPE_MAPPING["CachedNode"] = CachedNode + + +def get_tree(state: dict[str, Any], load_context: LoadContext) -> Node: + """Get the tree of nodes. + + This function returns the root node of the tree of nodes. The tree is + constructed recursively by traversing the state tree. No instances are + created during this process. One would need to call ``construct`` on the + root node to create the instances. + + This function also handles memoization of the nodes. If a node has already + been created, it is returned instead of creating a new one. + + Parameters + ---------- + state : dict + The state of the dumped object. + + load_context : LoadContext + The context of the loading process. + """ + saved_id = state.get("__id__") + if saved_id in load_context.memo: + # This means the node is already loaded, so we return it. Note that the + # node is not constructed at this point. It will be constructed when + # the parent node's ``construct`` method is called, and for this node + # it'll be called more than once. But that's not an issue since the + # node's ``construct`` method caches the instance. + return load_context.get_object(saved_id) + + try: + node_cls = NODE_TYPE_MAPPING[state["__loader__"]] + except KeyError: + type_name = f"{state['__module__']}.{state['__class__']}" + raise TypeError( + f" Can't find loader {state['__loader__']} for type {type_name}." 
+ ) + + loaded_tree = node_cls(state, load_context, trusted=False) # type: ignore + + return loaded_tree diff --git a/skops/io/_dispatch.py b/skops/io/_dispatch.py deleted file mode 100644 index 179bd4b2..00000000 --- a/skops/io/_dispatch.py +++ /dev/null @@ -1,263 +0,0 @@ -from __future__ import annotations - -import io -from contextlib import contextmanager - -from ._audit import check_type -from ._trusted_types import PRIMITIVE_TYPE_NAMES -from ._utils import LoadContext, get_module - -NODE_TYPE_MAPPING = {} # type: ignore - - -class UNINITIALIZED: - """Sentinel value to indicate that a value has not been initialized yet.""" - - -@contextmanager -def temp_setattr(obj, **kwargs): - """Context manager to temporarily set attributes on an object.""" - existing_attrs = {k for k in kwargs.keys() if hasattr(obj, k)} - previous_values = {k: getattr(obj, k, None) for k in kwargs} - for k, v in kwargs.items(): - setattr(obj, k, v) - try: - yield - finally: - for k, v in previous_values.items(): - if k in existing_attrs: - setattr(obj, k, v) - else: - delattr(obj, k) - - -class Node: - """A node in the tree of objects. - - This class is a parent class for all nodes in the tree of objects. Each - type of object (e.g. dict, list, etc.) has its own subclass of Node. - - Each child class has to implement two methods: ``__init__`` and - ``_construct``. - - ``__init__`` takes care of traversing the state tree and to create the - corresponding ``Node`` objects. It has access to the ``load_context`` which - in turn has access to the source zip file. The child class's ``__init__`` - must load attributes into the ``children`` attribute, which is a - dictionary of ``{child_name: unloaded_value/Node/list/etc}``. The - ``get_unsafe_set`` should be able to parse and validate the values set - under the ``children`` attribute. Note that primitives are persisted as a - ``JsonNode``. - - ``_construct`` takes care of constructing the object. 
It is only called - once and the result is cached in ``construct`` which is implemented in this - class. All required data to construct an instance should be loaded during - ``__init__``. - - The separation of ``__init__`` and ``_construct`` is necessary because - audit methods are called after ``__init__`` and before ``construct``. - Therefore ``__init__`` should avoid creating any instances or importing - any modules, to avoid running potentially untrusted code. - - Parameters - ---------- - state : dict - A dict representing the state of the dumped object. - - load_context : LoadContext - The context of the loading process. - - trusted : bool or list of str, default=False - If ``True``, the object will be loaded without any security checks. If - ``False``, the object will be loaded only if there are only trusted - objects in the dumped file. If a list of strings, the object will be - loaded only if all of its required types are listed in ``trusted`` - or are trusted by default. - - memoize : bool, default=True - If ``True``, the object will be memoized in the load context, if it has - the ``__id__`` set. This is used to avoid loading the same object - multiple times. - """ - - def __init__(self, state, load_context: LoadContext, trusted=False, memoize=True): - self.class_name, self.module_name = state["__class__"], state["__module__"] - self.trusted = trusted - self._is_safe = None - self._constructed = UNINITIALIZED - saved_id = state.get("__id__") - if saved_id and memoize: - # hold reference to obj in case same instance encountered again in - # save state - load_context.memoize(self, saved_id) - - def construct(self): - """Construct the object. - - We only construct the object once, and then cache the result. - """ - if self._constructed is not UNINITIALIZED: - return self._constructed - self._constructed = self._construct() - return self._constructed - - @staticmethod - def _get_trusted(trusted, default): - """Return a trusted list, or True. 
- - If ``trusted`` is ``False``, we return the ``default``, otherwise the - ``trusted`` value is used. - - This is a convenience method called by child classes. - """ - if trusted is True: - # if trusted is True, we trust the node - return True - - if trusted is False: - # if trusted is False, we only trust the defaults - return default - - # otherwise we trust the given list - return trusted - - def is_self_safe(self): - """True only if the node's type is considered safe. - - This property only checks the type of the node, not its children. - """ - return check_type(self.module_name, self.class_name, self.trusted) - - def is_safe(self): - """True only if the node and all its children are safe.""" - # if trusted is set to True, we don't do any safety checks. - if self.trusted is True: - return True - - return len(self.get_unsafe_set()) == 0 - - def get_unsafe_set(self): - """Get the set of unsafe types. - - This method returns all types which are not trusted, including this - node and all its children. - - Returns - ------- - unsafe_set : set - A set of unsafe types. - """ - if hasattr(self, "_computing_unsafe_set"): - # this means we're already computing this node's unsafe set, so we - # return an empty set and let the computation of the parent node - # continue. This is to avoid infinite recursion. - return set() - - with temp_setattr(self, _computing_unsafe_set=True): - res = set() - if not self.is_self_safe(): - res.add(self.module_name + "." + self.class_name) - - for child in self.children.values(): - if child is None: - continue - - # Get the safety set based on the type of the child. In most cases - # other than ListNode and DictNode, children are all of type Node. - if isinstance(child, list): - # iterate through the list - for value in child: - res.update(value.get_unsafe_set()) - elif isinstance(child, dict): - # iterate through the values of the dict only - # TODO: should we check the types of the keys? 
- for value in child.values(): - res.update(value.get_unsafe_set()) - elif isinstance(child, Node): - # delegate to the child Node - res.update(child.get_unsafe_set()) - elif type(child) is type: - # the if condition bellow is not merged with the previous - # one because if the above condition is True, the following - # conditions about BytesIO, etc should be ignored. - if not check_type(get_module(child), child.__name__, self.trusted): - # if the child is a type, we check its safety - res.add(get_module(child) + "." + child.__name__) - elif isinstance(child, io.BytesIO): - # We trust BytesIO objects, which are read by other - # libraries such as numpy, scipy. - continue - elif check_type( - get_module(child), child.__class__.__name__, PRIMITIVE_TYPE_NAMES - ): - # if the child is a primitive type, we don't need to check its - # safety. - continue - else: - raise ValueError( - f"Cannot determine the safety of type {type(child)}. Please" - " open an issue at https://github.com/skops-dev/skops/issues" - " for us to fix the issue." - ) - - return res - - -class CachedNode(Node): - def __init__(self, state, load_context: LoadContext, trusted=False): - # we pass memoize as False because we don't want to memoize the cached - # node. - super().__init__(state, load_context, trusted, memoize=False) - self.trusted = True - self.cached = load_context.get_object(state.get("__id__")) - self.children = {} # type: ignore - - def _construct(self): - # TODO: FIXME This causes a recursion error when loading a cached - # object if we call the cached object's `construct``. Some refactoring - # is needed to fix this. - return self.cached.construct() - - -NODE_TYPE_MAPPING["CachedNode"] = CachedNode - - -def get_tree(state, load_context: LoadContext): - """Get the tree of nodes. - - This function returns the root node of the tree of nodes. The tree is - constructed recursively by traversing the state tree. No instances are - created during this process. 
One would need to call ``construct`` on the - root node to create the instances. - - This function also handles memoization of the nodes. If a node has already - been created, it is returned instead of creating a new one. - - Parameters - ---------- - state : dict - The state of the dumped object. - - load_context : LoadContext - The context of the loading process. - """ - saved_id = state.get("__id__") - if saved_id in load_context.memo: - # This means the node is already loaded, so we return it. Note that the - # node is not constructed at this point. It will be constructed when - # the parent node's ``construct`` method is called, and for this node - # it'll be called more than once. But that's not an issue since the - # node's ``construct`` method caches the instance. - return load_context.get_object(saved_id) - - try: - node_cls = NODE_TYPE_MAPPING[state["__loader__"]] - except KeyError: - type_name = f"{state['__module__']}.{state['__class__']}" - raise TypeError( - f" Can't find loader {state['__loader__']} for type {type_name}." 
- ) - - loaded_tree = node_cls(state, load_context, trusted=False) # type: ignore - - return loaded_tree diff --git a/skops/io/_general.py b/skops/io/_general.py index b6449f0d..e90627e5 100644 --- a/skops/io/_general.py +++ b/skops/io/_general.py @@ -3,11 +3,11 @@ import json from functools import partial from types import FunctionType, MethodType -from typing import Any +from typing import Any, Sequence import numpy as np -from ._dispatch import Node, get_tree +from ._audit import Node, get_tree from ._trusted_types import PRIMITIVE_TYPE_NAMES from ._utils import ( LoadContext, @@ -42,7 +42,12 @@ def dict_get_state(obj: Any, save_context: SaveContext) -> dict[str, Any]: class DictNode(Node): - def __init__(self, state, load_context: LoadContext, trusted=False): + def __init__( + self, + state: dict[str, Any], + load_context: LoadContext, + trusted: bool | Sequence[str] = False, + ) -> None: super().__init__(state, load_context, trusted) self.trusted = self._get_trusted(trusted, ["builtins.dict"]) self.children = { @@ -53,7 +58,7 @@ def __init__(self, state, load_context: LoadContext, trusted=False): }, } - def _construct(self): + def _construct(self) -> Any: content = gettype(self.module_name, self.class_name)() key_types = self.children["key_types"].construct() for k_type, (key, val) in zip(key_types, self.children["content"].items()): @@ -74,7 +79,12 @@ def list_get_state(obj: Any, save_context: SaveContext) -> dict[str, Any]: class ListNode(Node): - def __init__(self, state, load_context: LoadContext, trusted=False): + def __init__( + self, + state: dict[str, Any], + load_context: LoadContext, + trusted: bool | Sequence[str] = False, + ) -> None: super().__init__(state, load_context, trusted) self.trusted = self._get_trusted(trusted, ["builtins.list"]) self.children = { @@ -98,14 +108,19 @@ def set_get_state(obj: Any, save_context: SaveContext) -> dict[str, Any]: class SetNode(Node): - def __init__(self, state, load_context: LoadContext, trusted=False): + def 
__init__( + self, + state: dict[str, Any], + load_context: LoadContext, + trusted: bool | Sequence[str] = False, + ) -> None: super().__init__(state, load_context, trusted) self.trusted = self._get_trusted(trusted, ["builtins.set"]) self.children = { "content": [get_tree(value, load_context) for value in state["content"]] } - def _construct(self): + def _construct(self) -> Any: content_type = gettype(self.module_name, self.class_name) return content_type([item.construct() for item in self.children["content"]]) @@ -122,14 +137,19 @@ def tuple_get_state(obj: Any, save_context: SaveContext) -> dict[str, Any]: class TupleNode(Node): - def __init__(self, state, load_context: LoadContext, trusted=False): + def __init__( + self, + state: dict[str, Any], + load_context: LoadContext, + trusted: bool | Sequence[str] = False, + ) -> None: super().__init__(state, load_context, trusted) self.trusted = self._get_trusted(trusted, ["builtins.tuple"]) self.children = { "content": [get_tree(value, load_context) for value in state["content"]] } - def _construct(self): + def _construct(self) -> Any: # Returns a tuple or a namedtuple instance. cls = gettype(self.module_name, self.class_name) @@ -139,7 +159,7 @@ def _construct(self): return cls(*content) return content - def isnamedtuple(self, t): + def isnamedtuple(self, t) -> bool: # This is needed since namedtuples need to have the args when # initialized. b = t.__bases__ @@ -165,30 +185,32 @@ def function_get_state(obj: Any, save_context: SaveContext) -> dict[str, Any]: class FunctionNode(Node): - def __init__(self, state, load_context: LoadContext, trusted=False): + def __init__( + self, + state: dict[str, Any], + load_context: LoadContext, + trusted: bool | Sequence[str] = False, + ) -> None: super().__init__(state, load_context, trusted) # TODO: what do we trust? 
self.trusted = self._get_trusted(trusted, []) self.children = {"content": state["content"]} - def _construct(self): + def _construct(self) -> Any: return _import_obj( self.children["content"]["module_path"], self.children["content"]["function"], ) - def _get_function_name(self): + def _get_function_name(self) -> str: return ( self.children["content"]["module_path"] + "." + self.children["content"]["function"] ) - def is_safe(self): - return self._get_function_name() in self.trusted - - def get_unsafe_set(self): - if self.is_safe(): + def get_unsafe_set(self) -> set[str]: + if self.trusted is True: return set() return {self._get_function_name()} @@ -211,7 +233,12 @@ def partial_get_state(obj: Any, save_context: SaveContext) -> dict[str, Any]: class PartialNode(Node): - def __init__(self, state, load_context: LoadContext, trusted=False): + def __init__( + self, + state: dict[str, Any], + load_context: LoadContext, + trusted: bool | Sequence[str] = False, + ) -> None: super().__init__(state, load_context, trusted) # TODO: should we trust anything? 
self.trusted = self._get_trusted(trusted, []) @@ -222,13 +249,14 @@ def __init__(self, state, load_context: LoadContext, trusted=False): "namespace": get_tree(state["content"]["namespace"], load_context), } - def _construct(self): + def _construct(self) -> Any: func = self.children["func"].construct() args = self.children["args"].construct() kwds = self.children["kwds"].construct() namespace = self.children["namespace"].construct() instance = partial(func, *args, **kwds) # always use partial, not a subclass - instance.__setstate__((func, args, kwds, namespace)) + # partial always has __setstate__ + instance.__setstate__((func, args, kwds, namespace)) # type: ignore return instance @@ -244,15 +272,20 @@ def type_get_state(obj: Any, save_context: SaveContext) -> dict[str, Any]: class TypeNode(Node): - def __init__(self, state, load_context: LoadContext, trusted=False): + def __init__( + self, + state: dict[str, Any], + load_context: LoadContext, + trusted: bool | Sequence[str] = False, + ) -> None: super().__init__(state, load_context, trusted) # TODO: what do we trust? self.trusted = self._get_trusted(trusted, PRIMITIVE_TYPE_NAMES) # We use a bare Node type here since a Node only checks the type in the # dict using __class__ and __module__ keys. 
- self.children = {} # type: ignore + self.children = {} - def _construct(self): + def _construct(self) -> Any: return _import_obj(self.module_name, self.class_name) @@ -271,7 +304,12 @@ def slice_get_state(obj: Any, save_context: SaveContext) -> dict[str, Any]: class SliceNode(Node): - def __init__(self, state, load_context: LoadContext, trusted=False): + def __init__( + self, + state: dict[str, Any], + load_context: LoadContext, + trusted: bool | Sequence[str] = False, + ) -> None: super().__init__(state, load_context, trusted) self.trusted = self._get_trusted(trusted, ["builtins.slice"]) self.children = { @@ -280,7 +318,7 @@ def __init__(self, state, load_context: LoadContext, trusted=False): "step": state["content"]["step"], } - def _construct(self): + def _construct(self) -> Any: return slice( self.children["start"], self.children["stop"], self.children["step"] ) @@ -329,11 +367,17 @@ def object_get_state(obj: Any, save_context: SaveContext) -> dict[str, Any]: class ObjectNode(Node): - def __init__(self, state, load_context: LoadContext, trusted=False): + def __init__( + self, + state: dict[str, Any], + load_context: LoadContext, + trusted: bool | Sequence[str] = False, + ) -> None: super().__init__(state, load_context, trusted) - if "content" in state: - attrs = get_tree(state.get("content"), load_context) + content = state.get("content") + if content is not None: + attrs = get_tree(content, load_context) else: attrs = None @@ -341,13 +385,14 @@ def __init__(self, state, load_context: LoadContext, trusted=False): # TODO: what do we trust? self.trusted = self._get_trusted(trusted, []) - def _construct(self): + def _construct(self) -> Any: cls = gettype(self.module_name, self.class_name) # Instead of simply constructing the instance, we use __new__, which - # bypasses the __init__, and then we set the attributes. This solves - # the issue of required init arguments. - instance = cls.__new__(cls) + # bypasses the __init__, and then we set the attributes. 
This solves the + # issue of required init arguments. Note that the instance created here + # might not be valid until all its attributes have been set below. + instance = cls.__new__(cls) # type: ignore if not self.children["attrs"]: # nothing more to do @@ -362,7 +407,7 @@ def _construct(self): return instance -def method_get_state(obj: Any, save_context: SaveContext): +def method_get_state(obj: Any, save_context: SaveContext) -> dict[str, Any]: # This method is used to persist bound methods, which are # dependent on a specific instance of an object. # It stores the state of the object the method is bound to, @@ -380,7 +425,12 @@ def method_get_state(obj: Any, save_context: SaveContext): class MethodNode(Node): - def __init__(self, state, load_context: LoadContext, trusted=False): + def __init__( + self, + state: dict[str, Any], + load_context: LoadContext, + trusted: bool | Sequence[str] = False, + ) -> None: super().__init__(state, load_context, trusted) self.children = { "obj": get_tree(state["content"]["obj"], load_context), @@ -389,7 +439,7 @@ def __init__(self, state, load_context: LoadContext, trusted=False): # TODO: what do we trust? self.trusted = self._get_trusted(trusted, []) - def _construct(self): + def _construct(self) -> Any: loaded_obj = self.children["obj"].construct() method = getattr(loaded_obj, self.children["func"]) return method @@ -400,22 +450,28 @@ def unsupported_get_state(obj: Any, save_context: SaveContext) -> dict[str, Any] class JsonNode(Node): - def __init__(self, state, load_context: LoadContext, trusted=False): + def __init__( + self, + state: dict[str, Any], + load_context: LoadContext, + trusted: bool | Sequence[str] = False, + ) -> None: super().__init__(state, load_context, trusted) self.content = state["content"] + self.children = {} - def is_safe(self): + def is_safe(self) -> bool: # JsonNode is always considered safe. # TODO: should we consider a JsonNode always safe? 
return True - def is_self_safe(self): + def is_self_safe(self) -> bool: return True - def get_unsafe_set(self): + def get_unsafe_set(self) -> set[str]: return set() - def _construct(self): + def _construct(self) -> Any: return json.loads(self.content) diff --git a/skops/io/_numpy.py b/skops/io/_numpy.py index cc31e6c8..274daea5 100644 --- a/skops/io/_numpy.py +++ b/skops/io/_numpy.py @@ -1,11 +1,11 @@ from __future__ import annotations import io -from typing import Any +from typing import Any, Sequence import numpy as np -from ._dispatch import Node, get_tree +from ._audit import Node, get_tree from ._utils import LoadContext, SaveContext, get_module, get_state, gettype from .exceptions import UnsupportedTypeException @@ -50,7 +50,12 @@ def ndarray_get_state(obj: Any, save_context: SaveContext) -> dict[str, Any]: class NdArrayNode(Node): - def __init__(self, state, load_context: LoadContext, trusted=False): + def __init__( + self, + state: dict[str, Any], + load_context: LoadContext, + trusted: bool | Sequence[str] = False, + ) -> None: super().__init__(state, load_context, trusted) self.type = state["type"] self.trusted = self._get_trusted(trusted, ["numpy.ndarray"]) @@ -68,7 +73,7 @@ def __init__(self, state, load_context: LoadContext, trusted=False): else: raise ValueError(f"Unknown type {self.type}.") - def _construct(self): + def _construct(self) -> Any: # Dealing with a regular numpy array, where dtype != object if self.type == "numpy": content = np.load(self.children["content"], allow_pickle=False) @@ -110,7 +115,12 @@ def maskedarray_get_state(obj: Any, save_context: SaveContext) -> dict[str, Any] class MaskedArrayNode(Node): - def __init__(self, state, load_context: LoadContext, trusted=False): + def __init__( + self, + state: dict[str, Any], + load_context: LoadContext, + trusted: bool | Sequence[str] = False, + ) -> None: super().__init__(state, load_context, trusted) self.trusted = self._get_trusted(trusted, ["numpy.ma.MaskedArray"]) self.children = { 
@@ -118,7 +128,7 @@ def __init__(self, state, load_context: LoadContext, trusted=False): "mask": get_tree(state["content"]["mask"], load_context), } - def _construct(self): + def _construct(self) -> Any: data = self.children["data"].construct() mask = self.children["mask"].construct() return np.ma.MaskedArray(data, mask) @@ -136,12 +146,17 @@ def random_state_get_state(obj: Any, save_context: SaveContext) -> dict[str, Any class RandomStateNode(Node): - def __init__(self, state, load_context: LoadContext, trusted=False): + def __init__( + self, + state: dict[str, Any], + load_context: LoadContext, + trusted: bool | Sequence[str] = False, + ) -> None: super().__init__(state, load_context, trusted) self.children = {"content": get_tree(state["content"], load_context)} self.trusted = self._get_trusted(trusted, ["numpy.random.RandomState"]) - def _construct(self): + def _construct(self) -> Any: random_state = gettype(self.module_name, self.class_name)() random_state.set_state(self.children["content"].construct()) return random_state @@ -159,12 +174,17 @@ def random_generator_get_state(obj: Any, save_context: SaveContext) -> dict[str, class RandomGeneratorNode(Node): - def __init__(self, state, load_context: LoadContext, trusted=False): + def __init__( + self, + state: dict[str, Any], + load_context: LoadContext, + trusted: bool | Sequence[str] = False, + ) -> None: super().__init__(state, load_context, trusted) self.children = {"bit_generator_state": state["content"]["bit_generator"]} self.trusted = self._get_trusted(trusted, ["numpy.random.Generator"]) - def _construct(self): + def _construct(self) -> Any: # first restore the state of the bit generator bit_generator = gettype( "numpy.random", self.children["bit_generator_state"]["bit_generator"] @@ -205,13 +225,18 @@ def dtype_get_state(obj: Any, save_context: SaveContext) -> dict[str, Any]: class DTypeNode(Node): - def __init__(self, state, load_context: LoadContext, trusted=False): + def __init__( + self, + state: 
dict[str, Any], + load_context: LoadContext, + trusted: bool | Sequence[str] = False, + ) -> None: super().__init__(state, load_context, trusted) self.children = {"content": get_tree(state["content"], load_context)} # TODO: what should we trust? self.trusted = self._get_trusted(trusted, []) - def _construct(self): + def _construct(self) -> Any: # we use numpy's internal save mechanism to store the dtype by # saving/loading an empty array with that dtype. return self.children["content"].construct().dtype diff --git a/skops/io/_persist.py b/skops/io/_persist.py index 4d039936..69952161 100644 --- a/skops/io/_persist.py +++ b/skops/io/_persist.py @@ -3,12 +3,13 @@ import importlib import io import json +from pathlib import Path +from typing import Any, Sequence from zipfile import ZipFile import skops -from ._audit import audit_tree -from ._dispatch import NODE_TYPE_MAPPING, get_tree +from ._audit import NODE_TYPE_MAPPING, audit_tree, get_tree from ._utils import LoadContext, SaveContext, _get_state, get_state # We load the dispatch functions from the corresponding modules and register @@ -23,7 +24,7 @@ NODE_TYPE_MAPPING.update(module.NODE_TYPE_MAPPING) -def _save(obj): +def _save(obj: Any) -> io.BytesIO: buffer = io.BytesIO() with ZipFile(buffer, "w") as zip_file: @@ -38,7 +39,7 @@ def _save(obj): return buffer -def dump(obj, file): +def dump(obj: Any, file: str) -> None: """Save an object using the skops persistence format. Skops aims at providing a secure persistence feature that does not rely on @@ -68,7 +69,7 @@ def dump(obj, file): f.write(buffer.getbuffer()) -def dumps(obj): +def dumps(obj: Any) -> bytes: """Save an object using the skops persistence format as a bytes object. .. warning:: @@ -88,7 +89,7 @@ def dumps(obj): return buffer.getbuffer().tobytes() -def load(file, trusted=False): +def load(file: str | Path, trusted: bool | Sequence[str] = False) -> Any: """Load an object saved with the skops persistence format. 
Skops aims at providing a secure persistence feature that does not rely on @@ -104,7 +105,7 @@ def load(file, trusted=False): Parameters ---------- - file: str + file: str or pathlib.Path The file name of the object to be loaded. trusted: bool, or list of str, default=False @@ -130,7 +131,7 @@ def load(file, trusted=False): return instance -def loads(data, trusted=False): +def loads(data: bytes, trusted: bool | Sequence[str] = False) -> Any: """Load an object saved with the skops persistence format from a bytes object. @@ -171,7 +172,9 @@ def loads(data, trusted=False): return instance -def get_untrusted_types(*, data=None, file=None): +def get_untrusted_types( + *, data: bytes | None = None, file: str | Path | None = None +) -> list[str]: """Get a list of untrusted types in a skops dump. Parameters @@ -193,11 +196,15 @@ def get_untrusted_types(*, data=None, file=None): """ if data and file: raise ValueError("Only one of data or file should be passed.") + if not data and not file: + raise ValueError("Exactly one of data or file should be passed.") + content: io.BytesIO | str | Path if data: content = io.BytesIO(data) else: - content = file + # mypy doesn't understand that file cannot be None here, thus ignore + content = file # type: ignore with ZipFile(content, "r") as zip_file: schema = json.loads(zip_file.read("schema.json")) diff --git a/skops/io/_scipy.py b/skops/io/_scipy.py index 223be95d..a348a2fb 100644 --- a/skops/io/_scipy.py +++ b/skops/io/_scipy.py @@ -1,11 +1,11 @@ from __future__ import annotations import io -from typing import Any +from typing import Any, Sequence from scipy.sparse import load_npz, save_npz, spmatrix -from ._dispatch import Node +from ._audit import Node from ._utils import LoadContext, SaveContext, get_module @@ -33,7 +33,12 @@ def sparse_matrix_get_state(obj: Any, save_context: SaveContext) -> dict[str, An class SparseMatrixNode(Node): - def __init__(self, state, load_context: LoadContext, trusted=False): + def __init__( + self, + 
state: dict[str, Any], + load_context: LoadContext, + trusted: bool | Sequence[str] = False, + ) -> None: super().__init__(state, load_context, trusted) type = state["type"] self.trusted = self._get_trusted(trusted, ["scipy.sparse.spmatrix"]) @@ -44,7 +49,7 @@ def __init__(self, state, load_context: LoadContext, trusted=False): self.children = {"content": io.BytesIO(load_context.src.read(state["file"]))} - def _construct(self): + def _construct(self) -> Any: # scipy load_npz uses numpy.save with allow_pickle=False under the # hood, so we're safe using it return load_npz(self.children["content"]) diff --git a/skops/io/_sklearn.py b/skops/io/_sklearn.py index 4a57eea0..c28542f0 100644 --- a/skops/io/_sklearn.py +++ b/skops/io/_sklearn.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import Any +from typing import Any, Sequence, Type from sklearn.cluster import Birch @@ -22,7 +22,7 @@ ) from sklearn.tree._tree import Tree -from ._dispatch import Node, get_tree +from ._audit import Node, get_tree from ._general import unsupported_get_state from ._utils import LoadContext, SaveContext, get_module, get_state, gettype from .exceptions import UnsupportedTypeException @@ -88,7 +88,13 @@ def reduce_get_state(obj: Any, save_context: SaveContext) -> dict[str, Any]: class ReduceNode(Node): - def __init__(self, state, load_context: LoadContext, constructor, trusted=False): + def __init__( + self, + state: dict[str, Any], + load_context: LoadContext, + constructor: Type[Any], + trusted: bool | Sequence[str] = False, + ) -> None: super().__init__(state, load_context, trusted) reduce = state["__reduce__"] self.children = { @@ -97,9 +103,10 @@ def __init__(self, state, load_context: LoadContext, constructor, trusted=False) "constructor": constructor, } - def _construct(self): + def _construct(self) -> Any: args = self.children["args"].construct() - instance = self.children["constructor"](*args) + constructor = self.children["constructor"] + instance = 
constructor(*args) attrs = self.children["attrs"].construct() if not attrs: # nothing more to do @@ -107,7 +114,7 @@ def _construct(self): if isinstance(args, tuple) and not hasattr(instance, "__setstate__"): raise UnsupportedTypeException( - f"Objects of type {self.constructor} are not supported yet" + f"Objects of type {constructor} are not supported yet" ) if hasattr(instance, "__setstate__"): @@ -125,7 +132,12 @@ def tree_get_state(obj: Any, save_context: SaveContext) -> dict[str, Any]: class TreeNode(ReduceNode): - def __init__(self, state, load_context: LoadContext, trusted=False): + def __init__( + self, + state: dict[str, Any], + load_context: LoadContext, + trusted: bool | Sequence[str] = False, + ) -> None: super().__init__(state, load_context, constructor=Tree, trusted=trusted) self.trusted = self._get_trusted(trusted, [get_module(Tree) + ".Tree"]) @@ -137,12 +149,17 @@ def sgd_loss_get_state(obj: Any, save_context: SaveContext) -> dict[str, Any]: class SGDNode(ReduceNode): - def __init__(self, state, load_context: LoadContext, trusted=False): + def __init__( + self, + state: dict[str, Any], + load_context: LoadContext, + trusted: bool | Sequence[str] = False, + ) -> None: # TODO: make sure trusted here makes sense and used. super().__init__( state, load_context, - constructor=gettype(state.get("__module__"), state.get("__class__")), + constructor=gettype(state["__module__"], state["__class__"]), trusted=False, ) self.trusted = self._get_trusted( @@ -173,7 +190,12 @@ def _DictWithDeprecatedKeys_get_state( # TODO: remove once support for sklearn<1.2 is dropped. 
class _DictWithDeprecatedKeysNode(Node): # _DictWithDeprecatedKeys is just a wrapper for dict - def __init__(self, state, load_context: LoadContext, trusted=False): + def __init__( + self, + state: dict[str, Any], + load_context: LoadContext, + trusted: bool | Sequence[str] = False, + ) -> None: super().__init__(state, load_context, trusted) self.trusted = [ get_module(_DictWithDeprecatedKeysNode) + "._DictWithDeprecatedKeys" @@ -185,7 +207,7 @@ def __init__(self, state, load_context: LoadContext, trusted=False): ), } - def _construct(self): + def _construct(self) -> Any: instance = _DictWithDeprecatedKeys(**self.children["main"].construct()) instance._deprecated_key_to_new_key = self.children[ "_deprecated_key_to_new_key" diff --git a/skops/io/_utils.py b/skops/io/_utils.py index b1219233..37b036eb 100644 --- a/skops/io/_utils.py +++ b/skops/io/_utils.py @@ -4,7 +4,7 @@ import sys from dataclasses import dataclass, field from functools import singledispatch -from typing import Any +from typing import Any, Type from zipfile import ZipFile @@ -28,7 +28,7 @@ def _getattribute(obj, name): # This function is particularly used to detect the path of functions such as # ufuncs. It returns the full path, instead of returning the module name. 
-def whichmodule(obj, name): +def whichmodule(obj: Any, name: str) -> str: """Find the module an object belong to.""" module_name = getattr(obj, "__module__", None) if module_name is not None: @@ -53,17 +53,18 @@ def whichmodule(obj, name): # --------------------------------------------------------------------- -def _import_obj(module, cls_or_func, package=None): +def _import_obj(module: str, cls_or_func: str, package: str | None = None) -> Any: return getattr(importlib.import_module(module, package=package), cls_or_func) -def gettype(module_name, cls_or_func): +def gettype(module_name: str, cls_or_func: str) -> Type[Any]: if module_name and cls_or_func: return _import_obj(module_name, cls_or_func) - return None + raise ValueError(f"Object {cls_or_func} of module {module_name} is unknown") -def get_module(obj): + +def get_module(obj: Any) -> str: """Find module for given object If the module cannot be identified, it's assumed to be "__main__". @@ -144,14 +145,14 @@ def get_object(self, id: int) -> Any: @singledispatch -def _get_state(obj, save_context): +def _get_state(obj, save_context: SaveContext): # This function should never be called directly. Instead, it is used to # dispatch to the correct implementation of get_state for the given type of # its first argument. raise TypeError(f"Getting the state of type {type(obj)} is not supported yet") -def get_state(value, save_context): +def get_state(value, save_context: SaveContext) -> dict[str, Any]: # This is a helper function to try to get the state of an object. If it # fails with `get_state`, we try with json.dumps, if that fails, we raise # the original error alongside the json error. 
diff --git a/skops/io/tests/test_audit.py b/skops/io/tests/test_audit.py index aff04cff..35902b86 100644 --- a/skops/io/tests/test_audit.py +++ b/skops/io/tests/test_audit.py @@ -11,8 +11,7 @@ from sklearn.preprocessing import FunctionTransformer, StandardScaler from skops.io import dumps, get_untrusted_types -from skops.io._audit import audit_tree, check_type -from skops.io._dispatch import Node, get_tree, temp_setattr +from skops.io._audit import Node, audit_tree, check_type, get_tree, temp_setattr from skops.io._general import DictNode, dict_get_state from skops.io._utils import LoadContext, SaveContext, gettype @@ -102,9 +101,14 @@ def test_list_safety(values, is_safe): assert tree.is_safe() == is_safe -def test_gettype(): - # return None if one argument is None - assert gettype(module_name="test", cls_or_func=None) is None +def test_gettype_error(): + msg = "Object None of module test is unknown" + with pytest.raises(ValueError, match=msg): + gettype(module_name="test", cls_or_func=None) + + msg = "Object test of module None is unknown" + with pytest.raises(ValueError, match=msg): + gettype(module_name=None, cls_or_func="test") # ImportError if the module cannot be imported with pytest.raises(ImportError): @@ -115,9 +119,10 @@ def test_gettype(): "data, file, exception, message", [ ("not-none", "not-none", ValueError, "Only one of data or file"), + (None, None, ValueError, "Exactly one of data or file should be passed"), ("string", None, TypeError, "a bytes-like object is required, not 'str'"), ], - ids=["both", "string-data"], + ids=["both", "neither", "string-data"], ) def test_get_untrusted_types_validation(data, file, exception, message): with pytest.raises(exception, match=message): diff --git a/skops/io/tests/test_persist.py b/skops/io/tests/test_persist.py index 7c9f5056..f833620a 100644 --- a/skops/io/tests/test_persist.py +++ b/skops/io/tests/test_persist.py @@ -54,7 +54,7 @@ import skops from skops.io import dump, dumps, get_untrusted_types, load, 
loads -from skops.io._dispatch import NODE_TYPE_MAPPING, get_tree +from skops.io._audit import NODE_TYPE_MAPPING, get_tree from skops.io._sklearn import UNSUPPORTED_TYPES from skops.io._utils import LoadContext, SaveContext, _get_state, get_state from skops.io.exceptions import UnsupportedTypeException From 1d610f94f376dca77340291815e1fb39666c36c3 Mon Sep 17 00:00:00 2001 From: Benjamin Bossan Date: Mon, 28 Nov 2022 15:24:05 +0100 Subject: [PATCH 2/8] Python 3.7: Import Literal from fixes --- skops/io/_audit.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/skops/io/_audit.py b/skops/io/_audit.py index 1556fe02..7160e725 100644 --- a/skops/io/_audit.py +++ b/skops/io/_audit.py @@ -2,8 +2,9 @@ import io from contextlib import contextmanager -from typing import Any, Generator, Literal, Sequence +from typing import Any, Generator, Sequence +from ..utils.fixes import Literal from ._trusted_types import PRIMITIVE_TYPE_NAMES from ._utils import LoadContext, get_module from .exceptions import UntrustedTypesFoundException From 975283d84f1edbc5eb1431985ba1e0392786321d Mon Sep 17 00:00:00 2001 From: Benjamin Bossan Date: Tue, 29 Nov 2022 12:06:34 +0100 Subject: [PATCH 3/8] Update skops/io/_audit.py Co-authored-by: Adrin Jalali --- skops/io/_audit.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/skops/io/_audit.py b/skops/io/_audit.py index 7160e725..ed52c6f2 100644 --- a/skops/io/_audit.py +++ b/skops/io/_audit.py @@ -181,7 +181,7 @@ def construct(self) -> Any: def _construct(self) -> Any: raise NotImplementedError( - f"{self.__class__.__name__} should implement a 'construct' method" + f"{self.__class__.__name__} should implement a '_construct' method" ) @staticmethod From b9847271a4c56cea39a576fcc450781bb96255d6 Mon Sep 17 00:00:00 2001 From: Benjamin Bossan Date: Tue, 29 Nov 2022 12:13:46 +0100 Subject: [PATCH 4/8] Type annotation: constructor can be a function too --- skops/io/_sklearn.py | 4 ++-- 1 file changed, 2 insertions(+), 2 
deletions(-) diff --git a/skops/io/_sklearn.py b/skops/io/_sklearn.py index c28542f0..ee6912de 100644 --- a/skops/io/_sklearn.py +++ b/skops/io/_sklearn.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import Any, Sequence, Type +from typing import Any, Callable, Sequence, Type from sklearn.cluster import Birch @@ -92,7 +92,7 @@ def __init__( self, state: dict[str, Any], load_context: LoadContext, - constructor: Type[Any], + constructor: Type[Any] | Callable[..., Any], trusted: bool | Sequence[str] = False, ) -> None: super().__init__(state, load_context, trusted) From b79032382ebf327752cc002def4ec7bc743a442f Mon Sep 17 00:00:00 2001 From: Benjamin Bossan Date: Tue, 29 Nov 2022 12:14:47 +0100 Subject: [PATCH 5/8] Add returns part of get_tree docstring --- skops/io/_audit.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/skops/io/_audit.py b/skops/io/_audit.py index ed52c6f2..35eb6b5f 100644 --- a/skops/io/_audit.py +++ b/skops/io/_audit.py @@ -332,6 +332,11 @@ def get_tree(state: dict[str, Any], load_context: LoadContext) -> Node: load_context : LoadContext The context of the loading process. + + Returns + ------- + loaded_tree : Node + The tree containing all its (non-instantiated) child nodes. """ saved_id = state.get("__id__") if saved_id in load_context.memo: From 8da4bddf66eebd7b5ca6701735fc948c7b534926 Mon Sep 17 00:00:00 2001 From: Benjamin Bossan Date: Tue, 29 Nov 2022 12:16:56 +0100 Subject: [PATCH 6/8] Fix an error in persistence documentation --- docs/persistence.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/persistence.rst b/docs/persistence.rst index 7866ea9d..7fced29b 100644 --- a/docs/persistence.rst +++ b/docs/persistence.rst @@ -85,7 +85,7 @@ In contrast to ``pickle``, skops cannot persist arbitrary Python code. This means if you have custom functions (say, a custom function to be used with :class:`sklearn.preprocessing.FunctionTransformer`), it will not work. 
However, most ``numpy`` and ``scipy`` functions should work. Therefore, you can actually -save built-in functions like``numpy.sqrt``. +save built-in functions like ``numpy.sqrt``. Roadmap ------- From 662b59397c528eaf20e544947d121be633e5e84e Mon Sep 17 00:00:00 2001 From: Benjamin Bossan Date: Wed, 30 Nov 2022 11:46:17 +0100 Subject: [PATCH 7/8] Remove Any return type annotation from _construct It is not adding value. --- skops/io/_audit.py | 4 ++-- skops/io/_general.py | 20 ++++++++++---------- skops/io/_numpy.py | 10 +++++----- skops/io/_scipy.py | 2 +- skops/io/_sklearn.py | 4 ++-- 5 files changed, 20 insertions(+), 20 deletions(-) diff --git a/skops/io/_audit.py b/skops/io/_audit.py index 35eb6b5f..33d271fc 100644 --- a/skops/io/_audit.py +++ b/skops/io/_audit.py @@ -169,7 +169,7 @@ def __init__( self.trusted = self._get_trusted(trusted, []) self.children: dict[str, Any] = {} - def construct(self) -> Any: + def construct(self): """Construct the object. We only construct the object once, and then cache the result. 
@@ -179,7 +179,7 @@ def construct(self) -> Any: self._constructed = self._construct() return self._constructed - def _construct(self) -> Any: + def _construct(self): raise NotImplementedError( f"{self.__class__.__name__} should implement a '_construct' method" ) diff --git a/skops/io/_general.py b/skops/io/_general.py index e90627e5..940f8273 100644 --- a/skops/io/_general.py +++ b/skops/io/_general.py @@ -58,7 +58,7 @@ def __init__( }, } - def _construct(self) -> Any: + def _construct(self): content = gettype(self.module_name, self.class_name)() key_types = self.children["key_types"].construct() for k_type, (key, val) in zip(key_types, self.children["content"].items()): @@ -120,7 +120,7 @@ def __init__( "content": [get_tree(value, load_context) for value in state["content"]] } - def _construct(self) -> Any: + def _construct(self): content_type = gettype(self.module_name, self.class_name) return content_type([item.construct() for item in self.children["content"]]) @@ -149,7 +149,7 @@ def __init__( "content": [get_tree(value, load_context) for value in state["content"]] } - def _construct(self) -> Any: + def _construct(self): # Returns a tuple or a namedtuple instance. cls = gettype(self.module_name, self.class_name) @@ -196,7 +196,7 @@ def __init__( self.trusted = self._get_trusted(trusted, []) self.children = {"content": state["content"]} - def _construct(self) -> Any: + def _construct(self): return _import_obj( self.children["content"]["module_path"], self.children["content"]["function"], @@ -249,7 +249,7 @@ def __init__( "namespace": get_tree(state["content"]["namespace"], load_context), } - def _construct(self) -> Any: + def _construct(self): func = self.children["func"].construct() args = self.children["args"].construct() kwds = self.children["kwds"].construct() @@ -285,7 +285,7 @@ def __init__( # dict using __class__ and __module__ keys. 
self.children = {} - def _construct(self) -> Any: + def _construct(self): return _import_obj(self.module_name, self.class_name) @@ -318,7 +318,7 @@ def __init__( "step": state["content"]["step"], } - def _construct(self) -> Any: + def _construct(self): return slice( self.children["start"], self.children["stop"], self.children["step"] ) @@ -385,7 +385,7 @@ def __init__( # TODO: what do we trust? self.trusted = self._get_trusted(trusted, []) - def _construct(self) -> Any: + def _construct(self): cls = gettype(self.module_name, self.class_name) # Instead of simply constructing the instance, we use __new__, which @@ -439,7 +439,7 @@ def __init__( # TODO: what do we trust? self.trusted = self._get_trusted(trusted, []) - def _construct(self) -> Any: + def _construct(self): loaded_obj = self.children["obj"].construct() method = getattr(loaded_obj, self.children["func"]) return method @@ -471,7 +471,7 @@ def is_self_safe(self) -> bool: def get_unsafe_set(self) -> set[str]: return set() - def _construct(self) -> Any: + def _construct(self): return json.loads(self.content) diff --git a/skops/io/_numpy.py b/skops/io/_numpy.py index 274daea5..4676e5f0 100644 --- a/skops/io/_numpy.py +++ b/skops/io/_numpy.py @@ -73,7 +73,7 @@ def __init__( else: raise ValueError(f"Unknown type {self.type}.") - def _construct(self) -> Any: + def _construct(self): # Dealing with a regular numpy array, where dtype != object if self.type == "numpy": content = np.load(self.children["content"], allow_pickle=False) @@ -128,7 +128,7 @@ def __init__( "mask": get_tree(state["content"]["mask"], load_context), } - def _construct(self) -> Any: + def _construct(self): data = self.children["data"].construct() mask = self.children["mask"].construct() return np.ma.MaskedArray(data, mask) @@ -156,7 +156,7 @@ def __init__( self.children = {"content": get_tree(state["content"], load_context)} self.trusted = self._get_trusted(trusted, ["numpy.random.RandomState"]) - def _construct(self) -> Any: + def 
_construct(self): random_state = gettype(self.module_name, self.class_name)() random_state.set_state(self.children["content"].construct()) return random_state @@ -184,7 +184,7 @@ def __init__( self.children = {"bit_generator_state": state["content"]["bit_generator"]} self.trusted = self._get_trusted(trusted, ["numpy.random.Generator"]) - def _construct(self) -> Any: + def _construct(self): # first restore the state of the bit generator bit_generator = gettype( "numpy.random", self.children["bit_generator_state"]["bit_generator"] @@ -236,7 +236,7 @@ def __init__( # TODO: what should we trust? self.trusted = self._get_trusted(trusted, []) - def _construct(self) -> Any: + def _construct(self): # we use numpy's internal save mechanism to store the dtype by # saving/loading an empty array with that dtype. return self.children["content"].construct().dtype diff --git a/skops/io/_scipy.py b/skops/io/_scipy.py index a348a2fb..0e240b87 100644 --- a/skops/io/_scipy.py +++ b/skops/io/_scipy.py @@ -49,7 +49,7 @@ def __init__( self.children = {"content": io.BytesIO(load_context.src.read(state["file"]))} - def _construct(self) -> Any: + def _construct(self): # scipy load_npz uses numpy.save with allow_pickle=False under the # hood, so we're safe using it return load_npz(self.children["content"]) diff --git a/skops/io/_sklearn.py b/skops/io/_sklearn.py index ee6912de..9da5392a 100644 --- a/skops/io/_sklearn.py +++ b/skops/io/_sklearn.py @@ -103,7 +103,7 @@ def __init__( "constructor": constructor, } - def _construct(self) -> Any: + def _construct(self): args = self.children["args"].construct() constructor = self.children["constructor"] instance = constructor(*args) @@ -207,7 +207,7 @@ def __init__( ), } - def _construct(self) -> Any: + def _construct(self): instance = _DictWithDeprecatedKeys(**self.children["main"].construct()) instance._deprecated_key_to_new_key = self.children[ "_deprecated_key_to_new_key" From 9358f36144680c5dcf9f461dec6640fbe14c0e73 Mon Sep 17 00:00:00 2001 
From: Benjamin Bossan Date: Thu, 1 Dec 2022 10:54:44 +0100 Subject: [PATCH 8/8] Add comment explaining Generator type annotation --- skops/io/_audit.py | 1 + 1 file changed, 1 insertion(+) diff --git a/skops/io/_audit.py b/skops/io/_audit.py index 33d271fc..6063dd21 100644 --- a/skops/io/_audit.py +++ b/skops/io/_audit.py @@ -78,6 +78,7 @@ class UNINITIALIZED: """Sentinel value to indicate that a value has not been initialized yet.""" +# Note: types for Generator mean: YieldType, SendType, ReturnType @contextmanager def temp_setattr(obj: Any, **kwargs: Any) -> Generator[None, None, None]: """Context manager to temporarily set attributes on an object."""